From ca9b21bb30e1d287f72a9e0febc6f007c545268b Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 4 Jul 2020 13:41:19 +0000 Subject: [PATCH 001/207] Add docker-compose to support running test --- test.sh | 14 ++++++++++++++ test_docker_compose.yaml | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 test.sh create mode 100644 test_docker_compose.yaml diff --git a/test.sh b/test.sh new file mode 100644 index 00000000..78928fea --- /dev/null +++ b/test.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +docker-compose -f test_docker_compose.yaml up -d + +export ORM_DRIVER=mysql +export TZ=UTC +export ORM_SOURCE="beego:test@tcp(localhost:13306)/orm_test?charset=utf8" + +go test ./... + +# clear all container +docker-compose -f test_docker_compose.yaml down + + diff --git a/test_docker_compose.yaml b/test_docker_compose.yaml new file mode 100644 index 00000000..54ca4097 --- /dev/null +++ b/test_docker_compose.yaml @@ -0,0 +1,39 @@ +version: "3.8" +services: + redis: + container_name: "beego-redis" + image: redis + environment: + - ALLOW_EMPTY_PASSWORD=yes + ports: + - "6379:6379" + + mysql: + container_name: "beego-mysql" + image: mysql:5.7.30 + ports: + - "13306:3306" + environment: + - MYSQL_ROOT_PASSWORD=1q2w3e + - MYSQL_DATABASE=orm_test + - MYSQL_USER=beego + - MYSQL_PASSWORD=test + + postgresql: + container_name: "beego-postgresql" + image: bitnami/postgresql:latest + ports: + - "5432:5432" + environment: + - ALLOW_EMPTY_PASSWORD=yes + ssdb: + container_name: "beego-ssdb" + image: wendal/ssdb + ports: + - "8888:8888" + memcache: + container_name: "beego-memcache" + image: memcached + ports: + - "11211:11211" + From 7c575585e9e305a28cc138d778f3f0bef5a6909a Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Mon, 6 Jul 2020 15:27:12 +0100 Subject: [PATCH 002/207] added conditional json flag when trying to view healthchecks --- admin.go | 38 ++++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/admin.go b/admin.go index 3e538a0e..c5ae686d 100644 --- a/admin.go +++ b/admin.go @@ -279,9 +279,7 @@ func profIndex(rw http.ResponseWriter, r *http.Request) { http.Error(rw, err.Error(), http.StatusInternalServerError) return } - - rw.Header().Set("Content-Type", "application/json") - rw.Write(dataJSON) + execJSON(rw, dataJSON) return } @@ -295,7 +293,7 @@ func profIndex(rw http.ResponseWriter, r *http.Request) { // Healthcheck is a http.Handler calling health checking and showing the result. // it's in "/healthcheck" pattern in admin module. -func healthcheck(rw http.ResponseWriter, _ *http.Request) { +func healthcheck(rw http.ResponseWriter, r *http.Request) { var ( result []string data = make(map[interface{}]interface{}) @@ -322,12 +320,44 @@ func healthcheck(rw http.ResponseWriter, _ *http.Request) { *resultList = append(*resultList, result) } + queryParams := r.URL.Query() + + if queryParams["json"] != nil { + + type Result map[string]interface{} + + response := make([]Result, len(*resultList)) + + for i, currentResult := range *resultList { + currentResultMap := make(Result) + currentResultMap["name"] = currentResult[0] + currentResultMap["message"] = currentResult[1] + currentResultMap["status"] = currentResult[2] + response[i] = currentResultMap + } + + JSONResponse, err := json.Marshal(response) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } else { + execJSON(rw, JSONResponse) + } + + return + } + content["Data"] = resultList data["Content"] = content data["Title"] = "Health Check" + execTpl(rw, data, healthCheckTpl, defaultScriptsTpl) } +func execJSON(rw http.ResponseWriter, jsonData []byte) { + rw.Header().Set("Content-Type", "application/json") + rw.Write(jsonData) +} + // TaskStatus is a http.Handler with running task status (task name, status and the last execution). // it's in "/task" pattern in admin module. func taskStatus(rw http.ResponseWriter, req *http.Request) { From db547a7c84aa7957b65b84d33f92d79893d3c7ae Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Mon, 6 Jul 2020 16:04:29 +0100 Subject: [PATCH 003/207] added test for execJson --- admin_test.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/admin_test.go b/admin_test.go index 71cc209e..b7a9af9c 100644 --- a/admin_test.go +++ b/admin_test.go @@ -1,7 +1,9 @@ package beego import ( + "encoding/json" "fmt" + "net/http/httptest" "testing" ) @@ -75,3 +77,27 @@ func oldMap() M { m["BConfig.Log.Outputs"] = BConfig.Log.Outputs return m } + +func TestExecJSON(t *testing.T) { + t.Log("Testing the adding of JSON to the response") + + w := httptest.NewRecorder() + originalBody := []int{1, 2, 3} + + res, _ := json.Marshal(originalBody) + + execJSON(w, res) + + decodedBody := []int{} + err := json.NewDecoder(w.Body).Decode(&decodedBody) + + if err != nil { + t.Fatal("Should be able to decode response body into decodedBody slice") + } + + for i := range decodedBody { + if decodedBody[i] != originalBody[i] { + t.Fatalf("Expected %d but got %d in decoded body slice", originalBody[i], decodedBody[i]) + } + } +} From 8d1a9bc92e758e6e174076fcae16ccebb4bc7fba Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Mon, 6 Jul 2020 19:34:48 +0100 Subject: [PATCH 004/207] added tests for health check endpoints --- admin_test.go | 73 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/admin_test.go b/admin_test.go index b7a9af9c..3875a4bb 100644 --- a/admin_test.go +++ b/admin_test.go @@ -2,11 +2,30 @@ package beego import ( "encoding/json" + "errors" "fmt" + "net/http" "net/http/httptest" + "strings" "testing" + + "github.com/astaxie/beego/toolbox" ) +type SampleDatabaseCheck struct { +} + +type SampleCacheCheck struct { +} + +func (dc *SampleDatabaseCheck) Check() error { + return nil +} + +func (cc *SampleCacheCheck) Check() error { + return errors.New("no cache detected") +} + func TestList_01(t *testing.T) { m := make(M) list("BConfig", BConfig, m) @@ -101,3 +120,57 @@ func TestExecJSON(t *testing.T) { } } } + +func TestHealthCheckHandlerDefault(t *testing.T) { + endpointPath := "/healthcheck" + + toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) + toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) + + req, err := http.NewRequest("GET", endpointPath, nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + handler := http.HandlerFunc(healthcheck) + + handler.ServeHTTP(w, req) + + if status := w.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + if !strings.Contains(w.Body.String(), "database") { + t.Errorf("Expected 'database' in generated template.") + } + +} + +func TestHealthCheckHandlerReturnsJSON(t *testing.T) { + + toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) + toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) + + req, err := http.NewRequest("GET", "/healthcheck?json=true", nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + handler := http.HandlerFunc(healthcheck) + + handler.ServeHTTP(w, req) + if status := w.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + expectedResponseBody := `[{"message":"database","name":"success","status":"OK"},{"message":"cache","name":"error","status":"no cache detected"}]` + if w.Body.String() != expectedResponseBody { + t.Errorf("handler returned unexpected body: got %v want %v", + w.Body.String(), expectedResponseBody) + } +} From fc56c562dbadf1a4a35487bd036759d4acb90462 Mon Sep 17 00:00:00 2001 From: Gabriel Cruz Date: Mon, 6 Jul 2020 20:35:56 +0200 Subject: [PATCH 005/207] Fix logger reconnection --- logs/conn.go | 1 - logs/conn_test.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/logs/conn.go b/logs/conn.go index afe0cbb7..5201f30e 100644 --- a/logs/conn.go +++ b/logs/conn.go @@ -101,7 +101,6 @@ func (c *connWriter) connect() error { func (c *connWriter) needToConnectOnMsg() bool { if c.Reconnect { - c.Reconnect = false return true } diff --git a/logs/conn_test.go b/logs/conn_test.go index 747fb890..bb377d41 100644 --- a/logs/conn_test.go +++ b/logs/conn_test.go @@ -15,11 +15,65 @@ package logs import ( + "net" + "os" "testing" ) +// ConnTCPListener takes a TCP listener and accepts n TCP connections +// Returns connections using connChan +func connTCPListener(t *testing.T, n int, ln net.Listener, connChan chan<- net.Conn) { + + // Listen and accept n incoming connections + for i := 0; i < n; i++ { + conn, err := ln.Accept() + if err != nil { + t.Log("Error accepting connection: ", err.Error()) + os.Exit(1) + } + + // Send accepted connection to channel + connChan <- conn + } + ln.Close() + close(connChan) +} + func TestConn(t *testing.T) { log := NewLogger(1000) log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) log.Informational("informational") } + +func TestReconnect(t *testing.T) { + // Setup connection listener + newConns := make(chan net.Conn) + connNum := 2 + ln, err := net.Listen("tcp", ":6002") + if err != nil { + t.Log("Error listening:", err.Error()) + os.Exit(1) + } + go connTCPListener(t, connNum, ln, newConns) + + // Setup logger + log := NewLogger(1000) + log.SetPrefix("test") + log.SetLogger(AdapterConn, `{"net":"tcp","reconnect":true,"level":6,"addr":":6002"}`) + log.Informational("informational 1") + + // Refuse first connection + first := <-newConns + first.Close() + + // Send another log after conn closed + log.Informational("informational 2") + + // Check if there was a second connection attempt + select { + case second := <-newConns: + second.Close() + default: + t.Error("Did not reconnect") + } +} From d8724cb122327327784f182f23cba47817d5e1b8 Mon Sep 17 00:00:00 2001 From: Gabriel Cruz Date: Mon, 6 Jul 2020 21:34:09 +0200 Subject: [PATCH 006/207] Add error returning to writeln --- logs/conn.go | 5 ++++- logs/logger.go | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/logs/conn.go b/logs/conn.go index 5201f30e..74c458ab 100644 --- a/logs/conn.go +++ b/logs/conn.go @@ -63,7 +63,10 @@ func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { defer c.innerWriter.Close() } - c.lg.writeln(when, msg) + _, err := c.lg.writeln(when, msg) + if err != nil { + return err + } return nil } diff --git a/logs/logger.go b/logs/logger.go index c7cf8a56..a28bff6f 100644 --- a/logs/logger.go +++ b/logs/logger.go @@ -30,11 +30,12 @@ func newLogWriter(wr io.Writer) *logWriter { return &logWriter{writer: wr} } -func (lg *logWriter) writeln(when time.Time, msg string) { +func (lg *logWriter) writeln(when time.Time, msg string) (int, error) { lg.Lock() h, _, _ := formatTimeHeader(when) - lg.writer.Write(append(append(h, msg...), '\n')) + n, err := lg.writer.Write(append(append(h, msg...), '\n')) lg.Unlock() + return n, err } const ( From 5a4a082af07dbecb996627e544e09fc32b097521 Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Tue, 7 Jul 2020 14:54:21 +0100 Subject: [PATCH 007/207] renamed functions for clarity --- admin.go | 22 +++++++++++----------- admin_test.go | 4 ++-- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/admin.go b/admin.go index c5ae686d..cc9043e7 100644 --- a/admin.go +++ b/admin.go @@ -71,7 +71,7 @@ func init() { // AdminIndex is the default http.Handler for admin module. // it matches url pattern "/". func adminIndex(rw http.ResponseWriter, _ *http.Request) { - execTpl(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) + writeTemplate(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) } // QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter. @@ -91,7 +91,7 @@ func qpsIndex(rw http.ResponseWriter, _ *http.Request) { } } - execTpl(rw, data, qpsTpl, defaultScriptsTpl) + writeTemplate(rw, data, qpsTpl, defaultScriptsTpl) } // ListConf is the http.Handler of displaying all beego configuration values as key/value pair. @@ -128,7 +128,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) { } data["Content"] = content data["Title"] = "Routers" - execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl) + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) case "filter": var ( content = M{ @@ -171,7 +171,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) { data["Content"] = content data["Title"] = "Filters" - execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl) + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) default: rw.Write([]byte("command not support")) } @@ -279,7 +279,7 @@ func profIndex(rw http.ResponseWriter, r *http.Request) { http.Error(rw, err.Error(), http.StatusInternalServerError) return } - execJSON(rw, dataJSON) + writeJSON(rw, dataJSON) return } @@ -288,7 +288,7 @@ func profIndex(rw http.ResponseWriter, r *http.Request) { if command == "gc summary" { defaultTpl = gcAjaxTpl } - execTpl(rw, data, profillingTpl, defaultTpl) + writeTemplate(rw, data, profillingTpl, defaultTpl) } // Healthcheck is a http.Handler calling health checking and showing the result. @@ -340,7 +340,7 @@ func healthcheck(rw http.ResponseWriter, r *http.Request) { if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) } else { - execJSON(rw, JSONResponse) + writeJSON(rw, JSONResponse) } return @@ -350,10 +350,10 @@ func healthcheck(rw http.ResponseWriter, r *http.Request) { data["Content"] = content data["Title"] = "Health Check" - execTpl(rw, data, healthCheckTpl, defaultScriptsTpl) + writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) } -func execJSON(rw http.ResponseWriter, jsonData []byte) { +func writeJSON(rw http.ResponseWriter, jsonData []byte) { rw.Header().Set("Content-Type", "application/json") rw.Write(jsonData) } @@ -401,10 +401,10 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { content["Data"] = resultList data["Content"] = content data["Title"] = "Tasks" - execTpl(rw, data, tasksTpl, defaultScriptsTpl) + writeTemplate(rw, data, tasksTpl, defaultScriptsTpl) } -func execTpl(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { +func writeTemplate(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) for _, tpl := range tpls { tmpl = template.Must(tmpl.Parse(tpl)) diff --git a/admin_test.go b/admin_test.go index 3875a4bb..1bf8700a 100644 --- a/admin_test.go +++ b/admin_test.go @@ -97,7 +97,7 @@ func oldMap() M { return m } -func TestExecJSON(t *testing.T) { +func TestWriteJSON(t *testing.T) { t.Log("Testing the adding of JSON to the response") w := httptest.NewRecorder() @@ -105,7 +105,7 @@ func TestExecJSON(t *testing.T) { res, _ := json.Marshal(originalBody) - execJSON(w, res) + writeJSON(w, res) decodedBody := []int{} err := json.NewDecoder(w.Body).Decode(&decodedBody) From ca0c64b69e3357ff8ce94a0e52f445bdf471d4b5 Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Tue, 7 Jul 2020 15:21:38 +0100 Subject: [PATCH 008/207] refactored tests for health check endpoint --- admin_test.go | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/admin_test.go b/admin_test.go index 1bf8700a..4a614721 100644 --- a/admin_test.go +++ b/admin_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "reflect" "strings" "testing" @@ -111,7 +112,7 @@ func TestWriteJSON(t *testing.T) { err := json.NewDecoder(w.Body).Decode(&decodedBody) if err != nil { - t.Fatal("Should be able to decode response body into decodedBody slice") + t.Fatal("Could not decode response body into slice.") } for i := range decodedBody { @@ -168,9 +169,31 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { status, http.StatusOK) } - expectedResponseBody := `[{"message":"database","name":"success","status":"OK"},{"message":"cache","name":"error","status":"no cache detected"}]` - if w.Body.String() != expectedResponseBody { + decodedResponseBody := []map[string]interface{}{} + expectedResponseBody := []map[string]interface{}{} + + expectedJSONString := []byte(` + [ + { + "message":"database", + "name":"success", + "status":"OK" + }, + { + "message":"cache", + "name":"error", + "status":"no cache detected" + } + ] + `) + + json.Unmarshal(expectedJSONString, &expectedResponseBody) + + json.Unmarshal(w.Body.Bytes(), &decodedResponseBody) + + if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { t.Errorf("handler returned unexpected body: got %v want %v", - w.Body.String(), expectedResponseBody) + decodedResponseBody, expectedResponseBody) } + } From 469dc7bea9b08add170177cc0a8615dbd0efa944 Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Tue, 7 Jul 2020 16:09:22 +0100 Subject: [PATCH 009/207] refactored the building of healthcheck response map --- admin.go | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/admin.go b/admin.go index cc9043e7..2ee415d4 100644 --- a/admin.go +++ b/admin.go @@ -21,6 +21,7 @@ import ( "net/http" "os" "reflect" + "strconv" "text/template" "time" @@ -321,28 +322,18 @@ func healthcheck(rw http.ResponseWriter, r *http.Request) { } queryParams := r.URL.Query() + jsonFlag := queryParams.Get("json") + shouldReturnJSON, _ := strconv.ParseBool(jsonFlag) - if queryParams["json"] != nil { + if shouldReturnJSON { + responseMap := buildHealthCheckResponseMap(resultList) + jsonResponse, err := json.Marshal(responseMap) - type Result map[string]interface{} - - response := make([]Result, len(*resultList)) - - for i, currentResult := range *resultList { - currentResultMap := make(Result) - currentResultMap["name"] = currentResult[0] - currentResultMap["message"] = currentResult[1] - currentResultMap["status"] = currentResult[2] - response[i] = currentResultMap - } - - JSONResponse, err := json.Marshal(response) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) } else { - writeJSON(rw, JSONResponse) + writeJSON(rw, jsonResponse) } - return } @@ -353,6 +344,23 @@ func healthcheck(rw http.ResponseWriter, r *http.Request) { writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) } +func buildHealthCheckResponseMap(resultList *[][]string) []map[string]interface{} { + response := make([]map[string]interface{}, len(*resultList)) + + for i, currentResult := range *resultList { + currentResultMap := make(map[string]interface{}) + + currentResultMap["name"] = currentResult[0] + currentResultMap["message"] = currentResult[1] + currentResultMap["status"] = currentResult[2] + + response[i] = currentResultMap + } + + return response + +} + func writeJSON(rw http.ResponseWriter, jsonData []byte) { rw.Header().Set("Content-Type", "application/json") rw.Write(jsonData) From e0f8c6832d5477e7b239b60ce25db4b42f01ed49 Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Tue, 7 Jul 2020 16:28:16 +0100 Subject: [PATCH 010/207] added test for buildingHealthCheckResponse --- admin.go | 16 ++++++++-------- admin_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/admin.go b/admin.go index 2ee415d4..db52647e 100644 --- a/admin.go +++ b/admin.go @@ -326,8 +326,8 @@ func healthcheck(rw http.ResponseWriter, r *http.Request) { shouldReturnJSON, _ := strconv.ParseBool(jsonFlag) if shouldReturnJSON { - responseMap := buildHealthCheckResponseMap(resultList) - jsonResponse, err := json.Marshal(responseMap) + response := buildHealthCheckResponseList(resultList) + jsonResponse, err := json.Marshal(response) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) @@ -344,15 +344,15 @@ func healthcheck(rw http.ResponseWriter, r *http.Request) { writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) } -func buildHealthCheckResponseMap(resultList *[][]string) []map[string]interface{} { - response := make([]map[string]interface{}, len(*resultList)) +func buildHealthCheckResponseList(healthCheckResults *[][]string) []map[string]interface{} { + response := make([]map[string]interface{}, len(*healthCheckResults)) - for i, currentResult := range *resultList { + for i, healthCheckResult := range *healthCheckResults { currentResultMap := make(map[string]interface{}) - currentResultMap["name"] = currentResult[0] - currentResultMap["message"] = currentResult[1] - currentResultMap["status"] = currentResult[2] + currentResultMap["name"] = healthCheckResult[0] + currentResultMap["message"] = healthCheckResult[1] + currentResultMap["status"] = healthCheckResult[2] response[i] = currentResultMap } diff --git a/admin_test.go b/admin_test.go index 4a614721..24da3a2b 100644 --- a/admin_test.go +++ b/admin_test.go @@ -149,6 +149,41 @@ func TestHealthCheckHandlerDefault(t *testing.T) { } +func TestBuildHealthCheckResponseList(t *testing.T) { + healthCheckResults := [][]string{ + []string{ + "error", + "Database", + "Error occured whie starting the db", + }, + []string{ + "success", + "Cache", + "Cache started successfully", + }, + } + + responseList := buildHealthCheckResponseList(&healthCheckResults) + + if len(responseList) != len(healthCheckResults) { + t.Errorf("invalid response map length: got %d want %d", + len(responseList), len(healthCheckResults)) + } + + responseFields := []string{"name", "message", "status"} + + for _, response := range responseList { + for _, field := range responseFields { + _, ok := response[field] + if !ok { + t.Errorf("expected %s to be in the response %v", field, response) + } + } + + } + +} + func TestHealthCheckHandlerReturnsJSON(t *testing.T) { toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) From 728bf340064803d5ace628c07f715a2532c1311c Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Tue, 7 Jul 2020 16:46:59 +0100 Subject: [PATCH 011/207] refacted cache health check from toolbox --- admin_test.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/admin_test.go b/admin_test.go index 24da3a2b..d9a66b34 100644 --- a/admin_test.go +++ b/admin_test.go @@ -187,7 +187,6 @@ func TestBuildHealthCheckResponseList(t *testing.T) { func TestHealthCheckHandlerReturnsJSON(t *testing.T) { toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) - toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) req, err := http.NewRequest("GET", "/healthcheck?json=true", nil) if err != nil { @@ -213,11 +212,6 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { "message":"database", "name":"success", "status":"OK" - }, - { - "message":"cache", - "name":"error", - "status":"no cache detected" } ] `) From d7b0d55357dce068c8b216fc49d024071f865f42 Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Tue, 7 Jul 2020 17:23:52 +0100 Subject: [PATCH 012/207] added extra check for same response lengths --- admin_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/admin_test.go b/admin_test.go index d9a66b34..3f3612e4 100644 --- a/admin_test.go +++ b/admin_test.go @@ -187,6 +187,7 @@ func TestBuildHealthCheckResponseList(t *testing.T) { func TestHealthCheckHandlerReturnsJSON(t *testing.T) { toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) + toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) req, err := http.NewRequest("GET", "/healthcheck?json=true", nil) if err != nil { @@ -212,6 +213,11 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { "message":"database", "name":"success", "status":"OK" + }, + { + "message":"cache", + "name":"error", + "status":"no cache detected" } ] `) @@ -220,6 +226,11 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { json.Unmarshal(w.Body.Bytes(), &decodedResponseBody) + if len(expectedResponseBody) != len(decodedResponseBody) { + t.Errorf("invalid response map length: got %d want %d", + len(decodedResponseBody), len(expectedResponseBody)) + } + if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { t.Errorf("handler returned unexpected body: got %v want %v", decodedResponseBody, expectedResponseBody) From 946a42c021f3688ca411315b71fe989152df7f27 Mon Sep 17 00:00:00 2001 From: Chenrui <631807682@qq.com> Date: Wed, 8 Jul 2020 17:14:52 +0800 Subject: [PATCH 013/207] fix: response http 413 when body size larger then MaxMemory. --- router.go | 4 ++++ router_test.go | 29 +++++++++++++++++++++++++---- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/router.go b/router.go index e71366b4..9b257391 100644 --- a/router.go +++ b/router.go @@ -707,6 +707,10 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if r.Method != http.MethodGet && r.Method != http.MethodHead { if BConfig.CopyRequestBody && !context.Input.IsUpload() { + if r.ContentLength > BConfig.MaxMemory { + exception("413", context) + goto Admin + } context.Input.CopyBody(BConfig.MaxMemory) } context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) diff --git a/router_test.go b/router_test.go index 2797b33a..e49f38db 100644 --- a/router_test.go +++ b/router_test.go @@ -15,6 +15,7 @@ package beego import ( + "bytes" "net/http" "net/http/httptest" "strings" @@ -71,7 +72,6 @@ func (tc *TestController) GetEmptyBody() { tc.Ctx.Output.Body(res) } - type JSONController struct { Controller } @@ -656,17 +656,14 @@ func beegoBeforeRouter1(ctx *context.Context) { ctx.WriteString("|BeforeRouter1") } - func beegoBeforeExec1(ctx *context.Context) { ctx.WriteString("|BeforeExec1") } - func beegoAfterExec1(ctx *context.Context) { ctx.WriteString("|AfterExec1") } - func beegoFinishRouter1(ctx *context.Context) { ctx.WriteString("|FinishRouter1") } @@ -709,3 +706,27 @@ func TestYAMLPrepare(t *testing.T) { t.Errorf(w.Body.String()) } } + +func TestRouterEntityTooLargeCopyBody(t *testing.T) { + _MaxMemory := BConfig.MaxMemory + _CopyRequestBody := BConfig.CopyRequestBody + BConfig.CopyRequestBody = true + BConfig.MaxMemory = 20 + + b := bytes.NewBuffer([]byte("barbarbarbarbarbarbarbarbarbar")) + r, _ := http.NewRequest("POST", "/user/123", b) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Post("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + handler.ServeHTTP(w, r) + + BConfig.CopyRequestBody = _CopyRequestBody + BConfig.MaxMemory = _MaxMemory + + if w.Code != 413 { + t.Errorf("TestRouterRequestEntityTooLarge can't run") + } +} From 03f78b2e4a5353173d96cbfbe19a3195cdc032f4 Mon Sep 17 00:00:00 2001 From: Chenrui <631807682@qq.com> Date: Wed, 8 Jul 2020 18:09:01 +0800 Subject: [PATCH 014/207] fix: add error code support --- error.go | 15 ++++++++++++++- hooks.go | 1 + 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/error.go b/error.go index e5e9fd47..0b148974 100644 --- a/error.go +++ b/error.go @@ -28,7 +28,7 @@ import ( ) const ( - errorTypeHandler = iota + errorTypeHandler = iota errorTypeController ) @@ -359,6 +359,19 @@ func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { ) } +// show 413 Payload Too Large +func payloadTooLarge(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 413, + "
The page you have requested is unavailable."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The request entity is larger than limits defined by server"+ + "
    Please change the request entity and try again."+ + "
", + ) +} + func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := M{ diff --git a/hooks.go b/hooks.go index b8671d35..49c42d5a 100644 --- a/hooks.go +++ b/hooks.go @@ -34,6 +34,7 @@ func registerDefaultErrorHandler() error { "504": gatewayTimeout, "417": invalidxsrf, "422": missingxsrf, + "413": payloadTooLarge, } for e, h := range m { if _, ok := ErrorMaps[e]; !ok { From c08b27111ca04110da9e92dd5541c25bf712cc47 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 8 Jul 2020 22:50:03 +0800 Subject: [PATCH 015/207] Fix 4059 --- orm/db_alias.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/orm/db_alias.go b/orm/db_alias.go index cf6a5935..bf6c350c 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -425,7 +425,6 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { type stmtDecorator struct { wg sync.WaitGroup - lastUse int64 stmt *sql.Stmt } @@ -433,9 +432,12 @@ func (s *stmtDecorator) getStmt() *sql.Stmt { return s.stmt } +// acquire will add one +// since this method will be used inside read lock scope, +// so we can not do more things here +// we should think about refactor this func (s *stmtDecorator) acquire() { s.wg.Add(1) - s.lastUse = time.Now().Unix() } func (s *stmtDecorator) release() { @@ -453,7 +455,6 @@ func (s *stmtDecorator) destroy() { func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { return &stmtDecorator{ stmt: sqlStmt, - lastUse: time.Now().Unix(), } } From 2eccb23461ccaa545564c865de6191ac2c9dc9e6 Mon Sep 17 00:00:00 2001 From: Cathal Date: Thu, 2 Jul 2020 20:48:43 +0100 Subject: [PATCH 016/207] Add sleep on reconnect functionality --- httplib/httplib.go | 8 ++++++++ httplib/httplib_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/httplib/httplib.go b/httplib/httplib.go index e094a6a6..60aa4e8b 100644 --- a/httplib/httplib.go +++ b/httplib/httplib.go @@ -144,6 +144,7 @@ type BeegoHTTPSettings struct { Gzip bool DumpBody bool Retries int // if set to -1 means will retry forever + RetryDelay time.Duration } // BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. @@ -202,6 +203,11 @@ func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { return b } +func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { + b.setting.RetryDelay = delay + return b +} + // DumpBody setting whether need to Dump the Body. func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { b.setting.DumpBody = isdump @@ -512,11 +518,13 @@ func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { // retries default value is 0, it will run once. // retries equal to -1, it will run forever until success // retries is setted, it will retries fixed times. + // Sleeps for a 400ms inbetween calls to reduce spam for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ { resp, err = client.Do(b.req) if err == nil { break } + time.Sleep(b.setting.RetryDelay) } return resp, err } diff --git a/httplib/httplib_test.go b/httplib/httplib_test.go index dd2a4f1c..f6be8571 100644 --- a/httplib/httplib_test.go +++ b/httplib/httplib_test.go @@ -15,6 +15,7 @@ package httplib import ( + "errors" "io/ioutil" "net" "net/http" @@ -33,6 +34,34 @@ func TestResponse(t *testing.T) { t.Log(resp) } +func TestDoRequest(t *testing.T) { + req := Get("https://goolnk.com/33BD2j") + retryAmount := 1 + req.Retries(1) + req.RetryDelay(1400 * time.Millisecond) + retryDelay := 1400 * time.Millisecond + + req.setting.CheckRedirect = func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + } + + startTime := time.Now().UnixNano() / int64(time.Millisecond) + + _, err := req.Response() + if err == nil { + t.Fatal("Response should have yielded an error") + } + + endTime := time.Now().UnixNano() / int64(time.Millisecond) + elapsedTime := endTime - startTime + delayedTime := int64(retryAmount) * retryDelay.Milliseconds() + + if elapsedTime < delayedTime { + t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) + } + +} + func TestGet(t *testing.T) { req := Get("http://httpbin.org/get") b, err := req.Bytes() From c3f14a0ad6ee54f0e474a220a8eb81a67a6b6335 Mon Sep 17 00:00:00 2001 From: Chenrui <631807682@qq.com> Date: Thu, 9 Jul 2020 09:45:40 +0800 Subject: [PATCH 017/207] refactor: log error when payload too large --- error.go | 13 +++++++------ router.go | 2 ++ router_test.go | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/error.go b/error.go index 0b148974..f268f723 100644 --- a/error.go +++ b/error.go @@ -363,12 +363,13 @@ func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { func payloadTooLarge(rw http.ResponseWriter, r *http.Request) { responseError(rw, r, 413, - "
The page you have requested is unavailable."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    The request entity is larger than limits defined by server"+ - "
    Please change the request entity and try again."+ - "
", + `
The page you have requested is unavailable. +
Perhaps you are here because:

+
    +
    The request entity is larger than limits defined by server. +
    Please change the request entity and try again. +
+ `, ) } diff --git a/router.go b/router.go index 9b257391..d910deb0 100644 --- a/router.go +++ b/router.go @@ -707,7 +707,9 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if r.Method != http.MethodGet && r.Method != http.MethodHead { if BConfig.CopyRequestBody && !context.Input.IsUpload() { + // connection will close if the incoming data are larger (RFC 7231, 6.5.11) if r.ContentLength > BConfig.MaxMemory { + logs.Error(errors.New("payload too large")) exception("413", context) goto Admin } diff --git a/router_test.go b/router_test.go index e49f38db..8ec7927a 100644 --- a/router_test.go +++ b/router_test.go @@ -726,7 +726,7 @@ func TestRouterEntityTooLargeCopyBody(t *testing.T) { BConfig.CopyRequestBody = _CopyRequestBody BConfig.MaxMemory = _MaxMemory - if w.Code != 413 { + if w.Code != http.StatusRequestEntityTooLarge { t.Errorf("TestRouterRequestEntityTooLarge can't run") } } From 76debb1899c10ac69356497d543f854c6a5dfa9f Mon Sep 17 00:00:00 2001 From: Acmefocus <37472851+Acmefocus@users.noreply.github.com> Date: Thu, 9 Jul 2020 17:18:01 +0800 Subject: [PATCH 018/207] Update README.md --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index 3b414c6f..aacd237e 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,15 @@ It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific feature go get github.com/astaxie/beego +#### Create hello directory + + sudo mkdir hello + cd hello + +#### Init module + + go mod init + #### Create file `hello.go` ```go package main From 40cdc877b618279ba598fb550a64c2b0f0325f53 Mon Sep 17 00:00:00 2001 From: Acmefocus <37472851+Acmefocus@users.noreply.github.com> Date: Thu, 9 Jul 2020 17:18:01 +0800 Subject: [PATCH 019/207] Update README.md Signed-off-by: Acmefocus <107723772@qq.com> --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index 3b414c6f..aacd237e 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,15 @@ It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific feature go get github.com/astaxie/beego +#### Create hello directory + + sudo mkdir hello + cd hello + +#### Init module + + go mod init + #### Create file `hello.go` ```go package main From 25ba78ea7247a6d05356a3012f7f0df7ddc92633 Mon Sep 17 00:00:00 2001 From: Acmefocus <107723772@qq.com> Date: Thu, 9 Jul 2020 18:14:31 +0800 Subject: [PATCH 020/207] update README.md Signed-off-by: Acmefocus <107723772@qq.com> --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index aacd237e..a9acfcc1 100644 --- a/README.md +++ b/README.md @@ -12,11 +12,11 @@ It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific feature go get github.com/astaxie/beego -#### Create hello directory +#### Create `hello` directory, cd `hello` directory - sudo mkdir hello + mkdir hello cd hello - + #### Init module go mod init From 2b9aaa5b0d45a6a0a1009eea5575c3d62e8f4210 Mon Sep 17 00:00:00 2001 From: Acmefocus <107723772@qq.com> Date: Fri, 10 Jul 2020 09:58:06 +0800 Subject: [PATCH 021/207] update README.md Signed-off-by: Acmefocus <107723772@qq.com> --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index a9acfcc1..de8f0063 100644 --- a/README.md +++ b/README.md @@ -8,18 +8,18 @@ It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific feature ## Quick Start -#### Download and install - - go get github.com/astaxie/beego - #### Create `hello` directory, cd `hello` directory - mkdir hello - cd hello + mkdir hello + cd hello #### Init module - go mod init + go mod init + +#### Download and install + + go get github.com/astaxie/beego #### Create file `hello.go` ```go From 5940ae33c2174806189419ec35b90a63454f649e Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 13 Jul 2020 19:14:53 +0800 Subject: [PATCH 022/207] fix `index out of range` when sid len = 1 add unit test for sess_file.go --- session/sess_file.go | 6 +- session/sess_file_test.go | 387 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 392 insertions(+), 1 deletion(-) create mode 100644 session/sess_file_test.go diff --git a/session/sess_file.go b/session/sess_file.go index c6dbf209..f7a739cc 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -15,11 +15,11 @@ package session import ( + "errors" "fmt" "io/ioutil" "net/http" "os" - "errors" "path" "path/filepath" "strings" @@ -180,6 +180,10 @@ func (fp *FileProvider) SessionExist(sid string) bool { filepder.lock.Lock() defer filepder.lock.Unlock() + if len(sid) < 2 { + return false + } + _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) return err == nil } diff --git a/session/sess_file_test.go b/session/sess_file_test.go new file mode 100644 index 00000000..0cf021db --- /dev/null +++ b/session/sess_file_test.go @@ -0,0 +1,387 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "fmt" + "os" + "sync" + "testing" + "time" +) + +const sid = "Session_id" +const sidNew = "Session_id_new" +const sessionPath = "./_session_runtime" + +var ( + mutex sync.Mutex +) + +func TestFileProvider_SessionInit(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + if fp.maxlifetime != 180 { + t.Error() + } + + if fp.savePath != sessionPath { + t.Error() + } +} + +func TestFileProvider_SessionExist(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionExist2(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + if fp.SessionExist("") { + t.Error() + } + + if fp.SessionExist("1") { + t.Error() + } +} + +func TestFileProvider_SessionRead(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + _ = s.Set("sessionValue", 18975) + v := s.Get("sessionValue") + + if v.(int) != 18975 { + t.Error() + } +} + +func TestFileProvider_SessionRead1(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead("") + if err == nil { + t.Error(err) + } + + _, err = fp.SessionRead("1") + if err == nil { + t.Error(err) + } +} + +func TestFileProvider_SessionAll(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 546 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + if fp.SessionAll() != sessionCount { + t.Error() + } +} + +func TestFileProvider_SessionRegenerate(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + _, err = fp.SessionRegenerate(sid, sidNew) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } + + if !fp.SessionExist(sidNew) { + t.Error() + } +} + +func TestFileProvider_SessionDestroy(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + err = fp.SessionDestroy(sid) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionGC(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(1, sessionPath) + + sessionCount := 412 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + time.Sleep(2 * time.Second) + + fp.SessionGC() + if fp.SessionAll() != 0 { + t.Error() + } +} + +func TestFileSessionStore_Set(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + err := s.Set(i, i) + if err != nil { + t.Error(err) + } + } +} + +func TestFileSessionStore_Get(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + + v := s.Get(i) + if v.(int) != i { + t.Error() + } + } +} + +func TestFileSessionStore_Delete(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, _ := fp.SessionRead(sid) + s.Set("1", 1) + + if s.Get("1") == nil { + t.Error() + } + + s.Delete("1") + + if s.Get("1") != nil { + t.Error() + } +} + +func TestFileSessionStore_Flush(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + } + + _ = s.Flush() + + for i := 1; i <= sessionCount; i++ { + if s.Get(i) != nil { + t.Error() + } + } +} + +func TestFileSessionStore_SessionID(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) { + t.Error(err) + } + } +} + +func TestFileSessionStore_SessionRelease(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + filepder.savePath = sessionPath + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + + + s.Set(i,i) + s.SessionRelease(nil) + } + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + + if s.Get(i).(int) != i { + t.Error() + } + } +} \ No newline at end of file From 678b90385b36334cb77cffe50e965b0ff13e7ee7 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 14 Jul 2020 09:57:13 +0800 Subject: [PATCH 023/207] add log --- session/sess_file.go | 1 + 1 file changed, 1 insertion(+) diff --git a/session/sess_file.go b/session/sess_file.go index f7a739cc..47ad54a7 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -181,6 +181,7 @@ func (fp *FileProvider) SessionExist(sid string) bool { defer filepder.lock.Unlock() if len(sid) < 2 { + SLogger.Println("min length of session id is 2", sid) return false } From ffe1d5212009bae069acfe3a2d02954c7acb1756 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 15 Jul 2020 10:04:22 +0800 Subject: [PATCH 024/207] Move orm to pkg/orm --- pkg/orm/README.md | 159 +++ pkg/orm/cmd.go | 283 +++++ pkg/orm/cmd_utils.go | 320 +++++ pkg/orm/db.go | 1902 +++++++++++++++++++++++++++++ pkg/orm/db_alias.go | 466 +++++++ pkg/orm/db_mysql.go | 183 +++ pkg/orm/db_oracle.go | 137 +++ pkg/orm/db_postgres.go | 189 +++ pkg/orm/db_sqlite.go | 161 +++ pkg/orm/db_tables.go | 482 ++++++++ pkg/orm/db_tidb.go | 63 + pkg/orm/db_utils.go | 177 +++ pkg/orm/models.go | 99 ++ pkg/orm/models_boot.go | 347 ++++++ pkg/orm/models_fields.go | 783 ++++++++++++ pkg/orm/models_info_f.go | 473 ++++++++ pkg/orm/models_info_m.go | 148 +++ pkg/orm/models_test.go | 497 ++++++++ pkg/orm/models_utils.go | 227 ++++ pkg/orm/orm.go | 579 +++++++++ pkg/orm/orm_conds.go | 153 +++ pkg/orm/orm_log.go | 222 ++++ pkg/orm/orm_object.go | 87 ++ pkg/orm/orm_querym2m.go | 140 +++ pkg/orm/orm_queryset.go | 300 +++++ pkg/orm/orm_raw.go | 867 +++++++++++++ pkg/orm/orm_test.go | 2494 ++++++++++++++++++++++++++++++++++++++ pkg/orm/qb.go | 62 + pkg/orm/qb_mysql.go | 185 +++ pkg/orm/qb_tidb.go | 182 +++ pkg/orm/types.go | 473 ++++++++ pkg/orm/utils.go | 319 +++++ pkg/orm/utils_test.go | 70 ++ 33 files changed, 13229 insertions(+) create mode 100644 pkg/orm/README.md create mode 100644 pkg/orm/cmd.go create mode 100644 pkg/orm/cmd_utils.go create mode 100644 pkg/orm/db.go create mode 100644 pkg/orm/db_alias.go create mode 100644 pkg/orm/db_mysql.go create mode 100644 pkg/orm/db_oracle.go create mode 100644 pkg/orm/db_postgres.go create mode 100644 pkg/orm/db_sqlite.go create mode 100644 pkg/orm/db_tables.go create mode 100644 pkg/orm/db_tidb.go create mode 100644 pkg/orm/db_utils.go create mode 100644 pkg/orm/models.go create mode 100644 pkg/orm/models_boot.go create mode 100644 pkg/orm/models_fields.go create mode 100644 pkg/orm/models_info_f.go create mode 100644 pkg/orm/models_info_m.go create mode 100644 pkg/orm/models_test.go create mode 100644 pkg/orm/models_utils.go create mode 100644 pkg/orm/orm.go create mode 100644 pkg/orm/orm_conds.go create mode 100644 pkg/orm/orm_log.go create mode 100644 pkg/orm/orm_object.go create mode 100644 pkg/orm/orm_querym2m.go create mode 100644 pkg/orm/orm_queryset.go create mode 100644 pkg/orm/orm_raw.go create mode 100644 pkg/orm/orm_test.go create mode 100644 pkg/orm/qb.go create mode 100644 pkg/orm/qb_mysql.go create mode 100644 pkg/orm/qb_tidb.go create mode 100644 pkg/orm/types.go create mode 100644 pkg/orm/utils.go create mode 100644 pkg/orm/utils_test.go diff --git a/pkg/orm/README.md b/pkg/orm/README.md new file mode 100644 index 00000000..6e808d2a --- /dev/null +++ b/pkg/orm/README.md @@ -0,0 +1,159 @@ +# beego orm + +[![Build Status](https://drone.io/github.com/astaxie/beego/status.png)](https://drone.io/github.com/astaxie/beego/latest) + +A powerful orm framework for go. + +It is heavily influenced by Django ORM, SQLAlchemy. + +**Support Database:** + +* MySQL: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) +* PostgreSQL: [github.com/lib/pq](https://github.com/lib/pq) +* Sqlite3: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) + +Passed all test, but need more feedback. + +**Features:** + +* full go type support +* easy for usage, simple CRUD operation +* auto join with relation table +* cross DataBase compatible query +* Raw SQL query / mapper without orm model +* full test keep stable and strong + +more features please read the docs + +**Install:** + + go get github.com/astaxie/beego/orm + +## Changelog + +* 2013-08-19: support table auto create +* 2013-08-13: update test for database types +* 2013-08-13: go type support, such as int8, uint8, byte, rune +* 2013-08-13: date / datetime timezone support very well + +## Quick Start + +#### Simple Usage + +```go +package main + +import ( + "fmt" + "github.com/astaxie/beego/orm" + _ "github.com/go-sql-driver/mysql" // import your used driver +) + +// Model Struct +type User struct { + Id int `orm:"auto"` + Name string `orm:"size(100)"` +} + +func init() { + // register model + orm.RegisterModel(new(User)) + + // set default database + orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) + + // create table + orm.RunSyncdb("default", false, true) +} + +func main() { + o := orm.NewOrm() + + user := User{Name: "slene"} + + // insert + id, err := o.Insert(&user) + + // update + user.Name = "astaxie" + num, err := o.Update(&user) + + // read one + u := User{Id: user.Id} + err = o.Read(&u) + + // delete + num, err = o.Delete(&u) +} +``` + +#### Next with relation + +```go +type Post struct { + Id int `orm:"auto"` + Title string `orm:"size(100)"` + User *User `orm:"rel(fk)"` +} + +var posts []*Post +qs := o.QueryTable("post") +num, err := qs.Filter("User__Name", "slene").All(&posts) +``` + +#### Use Raw sql + +If you don't like ORM,use Raw SQL to query / mapping without ORM setting + +```go +var maps []Params +num, err := o.Raw("SELECT id FROM user WHERE name = ?", "slene").Values(&maps) +if num > 0 { + fmt.Println(maps[0]["id"]) +} +``` + +#### Transaction + +```go +o.Begin() +... +user := User{Name: "slene"} +id, err := o.Insert(&user) +if err == nil { + o.Commit() +} else { + o.Rollback() +} + +``` + +#### Debug Log Queries + +In development env, you can simple use + +```go +func main() { + orm.Debug = true +... +``` + +enable log queries. + +output include all queries, such as exec / prepare / transaction. + +like this: + +```go +[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Exec / 0.4ms] - [INSERT INTO `user` (`name`) VALUES (?)] - `slene` +... +``` + +note: not recommend use this in product env. + +## Docs + +more details and examples in docs and test + +[documents](http://beego.me/docs/mvc/model/overview.md) + diff --git a/pkg/orm/cmd.go b/pkg/orm/cmd.go new file mode 100644 index 00000000..0ff4dc40 --- /dev/null +++ b/pkg/orm/cmd.go @@ -0,0 +1,283 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "flag" + "fmt" + "os" + "strings" +) + +type commander interface { + Parse([]string) + Run() error +} + +var ( + commands = make(map[string]commander) +) + +// print help. +func printHelp(errs ...string) { + content := `orm command usage: + + syncdb - auto create tables + sqlall - print sql of create tables + help - print this help +` + + if len(errs) > 0 { + fmt.Println(errs[0]) + } + fmt.Println(content) + os.Exit(2) +} + +// RunCommand listen for orm command and then run it if command arguments passed. +func RunCommand() { + if len(os.Args) < 2 || os.Args[1] != "orm" { + return + } + + BootStrap() + + args := argString(os.Args[2:]) + name := args.Get(0) + + if name == "help" { + printHelp() + } + + if cmd, ok := commands[name]; ok { + cmd.Parse(os.Args[3:]) + cmd.Run() + os.Exit(0) + } else { + if name == "" { + printHelp() + } else { + printHelp(fmt.Sprintf("unknown command %s", name)) + } + } +} + +// sync database struct command interface. +type commandSyncDb struct { + al *alias + force bool + verbose bool + noInfo bool + rtOnError bool +} + +// parse orm command line arguments. +func (d *commandSyncDb) Parse(args []string) { + var name string + + flagSet := flag.NewFlagSet("orm command: syncdb", flag.ExitOnError) + flagSet.StringVar(&name, "db", "default", "DataBase alias name") + flagSet.BoolVar(&d.force, "force", false, "drop tables before create") + flagSet.BoolVar(&d.verbose, "v", false, "verbose info") + flagSet.Parse(args) + + d.al = getDbAlias(name) +} + +// run orm line command. +func (d *commandSyncDb) Run() error { + var drops []string + if d.force { + drops = getDbDropSQL(d.al) + } + + db := d.al.DB + + if d.force { + for i, mi := range modelCache.allOrdered() { + query := drops[i] + if !d.noInfo { + fmt.Printf("drop table `%s`\n", mi.table) + } + _, err := db.Exec(query) + if d.verbose { + fmt.Printf(" %s\n\n", query) + } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + } + } + + sqls, indexes := getDbCreateSQL(d.al) + + tables, err := d.al.DbBaser.GetTables(db) + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + + for i, mi := range modelCache.allOrdered() { + if tables[mi.table] { + if !d.noInfo { + fmt.Printf("table `%s` already exists, skip\n", mi.table) + } + + var fields []*fieldInfo + columns, err := d.al.DbBaser.GetColumns(db, mi.table) + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + + for _, fi := range mi.fields.fieldsDB { + if _, ok := columns[fi.column]; !ok { + fields = append(fields, fi) + } + } + + for _, fi := range fields { + query := getColumnAddQuery(d.al, fi) + + if !d.noInfo { + fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table) + } + + _, err := db.Exec(query) + if d.verbose { + fmt.Printf(" %s\n", query) + } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + } + + for _, idx := range indexes[mi.table] { + if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { + if !d.noInfo { + fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) + } + + query := idx.SQL + _, err := db.Exec(query) + if d.verbose { + fmt.Printf(" %s\n", query) + } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + } + } + + continue + } + + if !d.noInfo { + fmt.Printf("create table `%s` \n", mi.table) + } + + queries := []string{sqls[i]} + for _, idx := range indexes[mi.table] { + queries = append(queries, idx.SQL) + } + + for _, query := range queries { + _, err := db.Exec(query) + if d.verbose { + query = " " + strings.Join(strings.Split(query, "\n"), "\n ") + fmt.Println(query) + } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + } + if d.verbose { + fmt.Println("") + } + } + + return nil +} + +// database creation commander interface implement. +type commandSQLAll struct { + al *alias +} + +// parse orm command line arguments. +func (d *commandSQLAll) Parse(args []string) { + var name string + + flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError) + flagSet.StringVar(&name, "db", "default", "DataBase alias name") + flagSet.Parse(args) + + d.al = getDbAlias(name) +} + +// run orm line command. +func (d *commandSQLAll) Run() error { + sqls, indexes := getDbCreateSQL(d.al) + var all []string + for i, mi := range modelCache.allOrdered() { + queries := []string{sqls[i]} + for _, idx := range indexes[mi.table] { + queries = append(queries, idx.SQL) + } + sql := strings.Join(queries, "\n") + all = append(all, sql) + } + fmt.Println(strings.Join(all, "\n\n")) + + return nil +} + +func init() { + commands["syncdb"] = new(commandSyncDb) + commands["sqlall"] = new(commandSQLAll) +} + +// RunSyncdb run syncdb command line. +// name means table's alias name. default is "default". +// force means run next sql if the current is error. +// verbose means show all info when running command or not. +func RunSyncdb(name string, force bool, verbose bool) error { + BootStrap() + + al := getDbAlias(name) + cmd := new(commandSyncDb) + cmd.al = al + cmd.force = force + cmd.noInfo = !verbose + cmd.verbose = verbose + cmd.rtOnError = true + return cmd.Run() +} diff --git a/pkg/orm/cmd_utils.go b/pkg/orm/cmd_utils.go new file mode 100644 index 00000000..61f17346 --- /dev/null +++ b/pkg/orm/cmd_utils.go @@ -0,0 +1,320 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "os" + "strings" +) + +type dbIndex struct { + Table string + Name string + SQL string +} + +// create database drop sql. +func getDbDropSQL(al *alias) (sqls []string) { + if len(modelCache.cache) == 0 { + fmt.Println("no Model found, need register your model") + os.Exit(2) + } + + Q := al.DbBaser.TableQuote() + + for _, mi := range modelCache.allOrdered() { + sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) + } + return sqls +} + +// get database column type string. +func getColumnTyp(al *alias, fi *fieldInfo) (col string) { + T := al.DbBaser.DbTypes() + fieldType := fi.fieldType + fieldSize := fi.size + +checkColumn: + switch fieldType { + case TypeBooleanField: + col = T["bool"] + case TypeVarCharField: + if al.Driver == DRPostgres && fi.toText { + col = T["string-text"] + } else { + col = fmt.Sprintf(T["string"], fieldSize) + } + case TypeCharField: + col = fmt.Sprintf(T["string-char"], fieldSize) + case TypeTextField: + col = T["string-text"] + case TypeTimeField: + col = T["time.Time-clock"] + case TypeDateField: + col = T["time.Time-date"] + case TypeDateTimeField: + col = T["time.Time"] + case TypeBitField: + col = T["int8"] + case TypeSmallIntegerField: + col = T["int16"] + case TypeIntegerField: + col = T["int32"] + case TypeBigIntegerField: + if al.Driver == DRSqlite { + fieldType = TypeIntegerField + goto checkColumn + } + col = T["int64"] + case TypePositiveBitField: + col = T["uint8"] + case TypePositiveSmallIntegerField: + col = T["uint16"] + case TypePositiveIntegerField: + col = T["uint32"] + case TypePositiveBigIntegerField: + col = T["uint64"] + case TypeFloatField: + col = T["float64"] + case TypeDecimalField: + s := T["float64-decimal"] + if !strings.Contains(s, "%d") { + col = s + } else { + col = fmt.Sprintf(s, fi.digits, fi.decimals) + } + case TypeJSONField: + if al.Driver != DRPostgres { + fieldType = TypeVarCharField + goto checkColumn + } + col = T["json"] + case TypeJsonbField: + if al.Driver != DRPostgres { + fieldType = TypeVarCharField + goto checkColumn + } + col = T["jsonb"] + case RelForeignKey, RelOneToOne: + fieldType = fi.relModelInfo.fields.pk.fieldType + fieldSize = fi.relModelInfo.fields.pk.size + goto checkColumn + } + + return +} + +// create alter sql string. +func getColumnAddQuery(al *alias, fi *fieldInfo) string { + Q := al.DbBaser.TableQuote() + typ := getColumnTyp(al, fi) + + if !fi.null { + typ += " " + "NOT NULL" + } + + return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", + Q, fi.mi.table, Q, + Q, fi.column, Q, + typ, getColumnDefault(fi), + ) +} + +// create database creation string. +func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { + if len(modelCache.cache) == 0 { + fmt.Println("no Model found, need register your model") + os.Exit(2) + } + + Q := al.DbBaser.TableQuote() + T := al.DbBaser.DbTypes() + sep := fmt.Sprintf("%s, %s", Q, Q) + + tableIndexes = make(map[string][]dbIndex) + + for _, mi := range modelCache.allOrdered() { + sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) + sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) + sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) + + sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) + + columns := make([]string, 0, len(mi.fields.fieldsDB)) + + sqlIndexes := [][]string{} + + for _, fi := range mi.fields.fieldsDB { + + column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) + col := getColumnTyp(al, fi) + + if fi.auto { + switch al.Driver { + case DRSqlite, DRPostgres: + column += T["auto"] + default: + column += col + " " + T["auto"] + } + } else if fi.pk { + column += col + " " + T["pk"] + } else { + column += col + + if !fi.null { + column += " " + "NOT NULL" + } + + //if fi.initial.String() != "" { + // column += " DEFAULT " + fi.initial.String() + //} + + // Append attribute DEFAULT + column += getColumnDefault(fi) + + if fi.unique { + column += " " + "UNIQUE" + } + + if fi.index { + sqlIndexes = append(sqlIndexes, []string{fi.column}) + } + } + + if strings.Contains(column, "%COL%") { + column = strings.Replace(column, "%COL%", fi.column, -1) + } + + if fi.description != "" && al.Driver!=DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) + } + + columns = append(columns, column) + } + + if mi.model != nil { + allnames := getTableUnique(mi.addrField) + if !mi.manual && len(mi.uniques) > 0 { + allnames = append(allnames, mi.uniques) + } + for _, names := range allnames { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { + cols = append(cols, fi.column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) + } + } + column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) + columns = append(columns, column) + } + } + + sql += strings.Join(columns, ",\n") + sql += "\n)" + + if al.Driver == DRMySQL { + var engine string + if mi.model != nil { + engine = getTableEngine(mi.addrField) + } + if engine == "" { + engine = al.Engine + } + sql += " ENGINE=" + engine + } + + sql += ";" + sqls = append(sqls, sql) + + if mi.model != nil { + for _, names := range getTableIndex(mi.addrField) { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { + cols = append(cols, fi.column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) + } + } + sqlIndexes = append(sqlIndexes, cols) + } + } + + for _, names := range sqlIndexes { + name := mi.table + "_" + strings.Join(names, "_") + cols := strings.Join(names, sep) + sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) + + index := dbIndex{} + index.Table = mi.table + index.Name = name + index.SQL = sql + + tableIndexes[mi.table] = append(tableIndexes[mi.table], index) + } + + } + + return +} + +// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands +func getColumnDefault(fi *fieldInfo) string { + var ( + v, t, d string + ) + + // Skip default attribute if field is in relations + if fi.rel || fi.reverse { + return v + } + + t = " DEFAULT '%s' " + + // These defaults will be useful if there no config value orm:"default" and NOT NULL is on + switch fi.fieldType { + case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField: + return v + + case TypeBitField, TypeSmallIntegerField, TypeIntegerField, + TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, + TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField, + TypeDecimalField: + t = " DEFAULT %s " + d = "0" + case TypeBooleanField: + t = " DEFAULT %s " + d = "FALSE" + case TypeJSONField, TypeJsonbField: + d = "{}" + } + + if fi.colDefault { + if !fi.initial.Exist() { + v = fmt.Sprintf(t, "") + } else { + v = fmt.Sprintf(t, fi.initial.String()) + } + } else { + if !fi.null { + v = fmt.Sprintf(t, d) + } + } + + return v +} diff --git a/pkg/orm/db.go b/pkg/orm/db.go new file mode 100644 index 00000000..9a1827e8 --- /dev/null +++ b/pkg/orm/db.go @@ -0,0 +1,1902 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "time" +) + +const ( + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" +) + +var ( + // ErrMissPK missing pk error + ErrMissPK = errors.New("missed pk value") +) + +var ( + operators = map[string]bool{ + "exact": true, + "iexact": true, + "contains": true, + "icontains": true, + // "regex": true, + // "iregex": true, + "gt": true, + "gte": true, + "lt": true, + "lte": true, + "eq": true, + "nq": true, + "ne": true, + "startswith": true, + "endswith": true, + "istartswith": true, + "iendswith": true, + "in": true, + "between": true, + // "year": true, + // "month": true, + // "day": true, + // "week_day": true, + "isnull": true, + // "search": true, + } +) + +// an instance of dbBaser interface/ +type dbBase struct { + ins dbBaser +} + +// check dbBase implements dbBaser interface. +var _ dbBaser = new(dbBase) + +// get struct columns values as interface slice. +func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { + if names == nil { + ns := make([]string, 0, len(cols)) + names = &ns + } + values = make([]interface{}, 0, len(cols)) + + for _, column := range cols { + var fi *fieldInfo + if fi, _ = mi.fields.GetByAny(column); fi != nil { + column = fi.column + } else { + panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName)) + } + if !fi.dbcol || fi.auto && skipAuto { + continue + } + value, err := d.collectFieldValue(mi, fi, ind, insert, tz) + if err != nil { + return nil, nil, err + } + + // ignore empty value auto field + if insert && fi.auto { + if fi.fieldType&IsPositiveIntegerField > 0 { + if vu, ok := value.(uint64); !ok || vu == 0 { + continue + } + } else { + if vu, ok := value.(int64); !ok || vu == 0 { + continue + } + } + autoFields = append(autoFields, fi.column) + } + + *names, values = append(*names, column), append(values, value) + } + + return +} + +// get one field value in struct column as interface. +func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { + var value interface{} + if fi.pk { + _, value, _ = getExistPk(mi, ind) + } else { + field := ind.FieldByIndex(fi.fieldIndex) + if fi.isFielder { + f := field.Addr().Interface().(Fielder) + value = f.RawValue() + } else { + switch fi.fieldType { + case TypeBooleanField: + if nb, ok := field.Interface().(sql.NullBool); ok { + value = nil + if nb.Valid { + value = nb.Bool + } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Bool() + } + } else { + value = field.Bool() + } + case TypeVarCharField, TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField: + if ns, ok := field.Interface().(sql.NullString); ok { + value = nil + if ns.Valid { + value = ns.String + } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().String() + } + } else { + value = field.String() + } + case TypeFloatField, TypeDecimalField: + if nf, ok := field.Interface().(sql.NullFloat64); ok { + value = nil + if nf.Valid { + value = nf.Float64 + } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Float() + } + } else { + vu := field.Interface() + if _, ok := vu.(float32); ok { + value, _ = StrTo(ToStr(vu)).Float64() + } else { + value = field.Float() + } + } + case TypeTimeField, TypeDateField, TypeDateTimeField: + value = field.Interface() + if t, ok := value.(time.Time); ok { + d.ins.TimeToDB(&t, tz) + if t.IsZero() { + value = nil + } else { + value = t + } + } + default: + switch { + case fi.fieldType&IsPositiveIntegerField > 0: + if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Uint() + } + } else { + value = field.Uint() + } + case fi.fieldType&IsIntegerField > 0: + if ni, ok := field.Interface().(sql.NullInt64); ok { + value = nil + if ni.Valid { + value = ni.Int64 + } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Int() + } + } else { + value = field.Int() + } + case fi.fieldType&IsRelField > 0: + if field.IsNil() { + value = nil + } else { + if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok { + value = vu + } else { + value = nil + } + } + if !fi.null && value == nil { + return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName) + } + } + } + } + switch fi.fieldType { + case TypeTimeField, TypeDateField, TypeDateTimeField: + if fi.autoNow || fi.autoNowAdd && insert { + if insert { + if t, ok := value.(time.Time); ok && !t.IsZero() { + break + } + } + tnow := time.Now() + d.ins.TimeToDB(&tnow, tz) + value = tnow + if fi.isFielder { + f := field.Addr().Interface().(Fielder) + f.SetRaw(tnow.In(DefaultTimeLoc)) + } else if field.Kind() == reflect.Ptr { + v := tnow.In(DefaultTimeLoc) + field.Set(reflect.ValueOf(&v)) + } else { + field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc))) + } + } + case TypeJSONField, TypeJsonbField: + if s, ok := value.(string); (ok && len(s) == 0) || value == nil { + if fi.colDefault && fi.initial.Exist() { + value = fi.initial.String() + } else { + value = nil + } + } + } + } + return value, nil +} + +// create insert sql preparation statement object. +func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { + Q := d.ins.TableQuote() + + dbcols := make([]string, 0, len(mi.fields.dbcols)) + marks := make([]string, 0, len(mi.fields.dbcols)) + for _, fi := range mi.fields.fieldsDB { + if !fi.auto { + dbcols = append(dbcols, fi.column) + marks = append(marks, "?") + } + } + qmarks := strings.Join(marks, ", ") + sep := fmt.Sprintf("%s, %s", Q, Q) + columns := strings.Join(dbcols, sep) + + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + + d.ins.ReplaceMarks(&query) + + d.ins.HasReturningID(mi, &query) + + stmt, err := q.Prepare(query) + return stmt, query, err +} + +// insert struct with prepared statement and given struct reflect value. +func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) + if err != nil { + return 0, err + } + + if d.ins.HasReturningID(mi, nil) { + row := stmt.QueryRow(values...) + var id int64 + err := row.Scan(&id) + return id, err + } + res, err := stmt.Exec(values...) + if err == nil { + return res.LastInsertId() + } + return 0, err +} + +// query sql ,read records and persist in dbBaser. +func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { + var whereCols []string + var args []interface{} + + // if specify cols length > 0, then use it for where condition. + if len(cols) > 0 { + var err error + whereCols = make([]string, 0, len(cols)) + args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) + if err != nil { + return err + } + } else { + // default use pk value as where condtion. + pkColumn, pkValue, ok := getExistPk(mi, ind) + if !ok { + return ErrMissPK + } + whereCols = []string{pkColumn} + args = append(args, pkValue) + } + + Q := d.ins.TableQuote() + + sep := fmt.Sprintf("%s, %s", Q, Q) + sels := strings.Join(mi.fields.dbcols, sep) + colsNum := len(mi.fields.dbcols) + + sep = fmt.Sprintf("%s = ? AND %s", Q, Q) + wheres := strings.Join(whereCols, sep) + + forUpdate := "" + if isForUpdate { + forUpdate = "FOR UPDATE" + } + + query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q, forUpdate) + + refs := make([]interface{}, colsNum) + for i := range refs { + var ref interface{} + refs[i] = &ref + } + + d.ins.ReplaceMarks(&query) + + row := q.QueryRow(query, args...) + if err := row.Scan(refs...); err != nil { + if err == sql.ErrNoRows { + return ErrNoRows + } + return err + } + elm := reflect.New(mi.addrField.Elem().Type()) + mind := reflect.Indirect(elm) + d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz) + ind.Set(mind) + return nil +} + +// execute insert sql dbQuerier with given struct reflect.Value. +func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { + names := make([]string, 0, len(mi.fields.dbcols)) + values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) + if err != nil { + return 0, err + } + + id, err := d.InsertValue(q, mi, false, names, values) + if err != nil { + return 0, err + } + + if len(autoFields) > 0 { + err = d.ins.setval(q, mi, autoFields) + } + return id, err +} + +// multi-insert sql with given slice struct reflect.Value. +func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { + var ( + cnt int64 + nums int + values []interface{} + names []string + ) + + // typ := reflect.Indirect(mi.addrField).Type() + + length, autoFields := sind.Len(), make([]string, 0, 1) + + for i := 1; i <= length; i++ { + + ind := reflect.Indirect(sind.Index(i - 1)) + + // Is this needed ? + // if !ind.Type().AssignableTo(typ) { + // return cnt, ErrArgs + // } + + if i == 1 { + var ( + vus []interface{} + err error + ) + vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) + if err != nil { + return cnt, err + } + values = make([]interface{}, bulk*len(vus)) + nums += copy(values, vus) + } else { + vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) + if err != nil { + return cnt, err + } + + if len(vus) != len(names) { + return cnt, ErrArgs + } + + nums += copy(values[nums:], vus) + } + + if i > 1 && i%bulk == 0 || length == i { + num, err := d.InsertValue(q, mi, true, names, values[:nums]) + if err != nil { + return cnt, err + } + cnt += num + nums = 0 + } + } + + var err error + if len(autoFields) > 0 { + err = d.ins.setval(q, mi, autoFields) + } + + return cnt, err +} + +// execute insert sql with given struct and given values. +// insert the given values, not the field values in struct. +func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { + Q := d.ins.TableQuote() + + marks := make([]string, len(names)) + for i := range marks { + marks[i] = "?" + } + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti && multi > 1 { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + row := q.QueryRow(query, values...) + var id int64 + err := row.Scan(&id) + return id, err +} + +// InsertOrUpdate a row +// If your primary key or unique column conflict will update +// If no will insert +func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { + args0 := "" + iouStr := "" + argsMap := map[string]string{} + switch a.Driver { + case DRMySQL: + iouStr = "ON DUPLICATE KEY UPDATE" + case DRPostgres: + if len(args) == 0 { + return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName) + } + args0 = strings.ToLower(args[0]) + iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0) + default: + return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName) + } + + //Get on the key-value pairs + for _, v := range args { + kv := strings.Split(v, "=") + if len(kv) == 2 { + argsMap[strings.ToLower(kv[0])] = kv[1] + } + } + + isMulti := false + names := make([]string, 0, len(mi.fields.dbcols)-1) + Q := d.ins.TableQuote() + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) + + if err != nil { + return 0, err + } + + marks := make([]string, len(names)) + updateValues := make([]interface{}, 0) + updates := make([]string, len(names)) + var conflitValue interface{} + for i, v := range names { + // identifier in database may not be case-sensitive, so quote it + v = fmt.Sprintf("%s%s%s", Q, v, Q) + marks[i] = "?" + valueStr := argsMap[strings.ToLower(v)] + if v == args0 { + conflitValue = values[i] + } + if valueStr != "" { + switch a.Driver { + case DRMySQL: + updates[i] = v + "=" + valueStr + case DRPostgres: + if conflitValue != nil { + //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0) + updateValues = append(updateValues, conflitValue) + } else { + return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) + } + } + } else { + updates[i] = v + "=?" + updateValues = append(updateValues, values[i]) + } + } + + values = append(values, updateValues...) + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + qupdates := strings.Join(updates, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + //conflitValue maybe is a int,can`t use fmt.Sprintf + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + + row := q.QueryRow(query, values...) + var id int64 + err = row.Scan(&id) + if err != nil && err.Error() == `pq: syntax error at or near "ON"` { + err = fmt.Errorf("postgres version must 9.5 or higher") + } + return id, err +} + +// execute update sql dbQuerier with given struct reflect.Value. +func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { + pkName, pkValue, ok := getExistPk(mi, ind) + if !ok { + return 0, ErrMissPK + } + + var setNames []string + + // if specify cols length is zero, then commit all columns. + if len(cols) == 0 { + cols = mi.fields.dbcols + setNames = make([]string, 0, len(mi.fields.dbcols)-1) + } else { + setNames = make([]string, 0, len(cols)) + } + + setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) + if err != nil { + return 0, err + } + + var findAutoNowAdd, findAutoNow bool + var index int + for i, col := range setNames { + if mi.fields.GetByColumn(col).autoNowAdd { + index = i + findAutoNowAdd = true + } + if mi.fields.GetByColumn(col).autoNow { + findAutoNow = true + } + } + if findAutoNowAdd { + setNames = append(setNames[0:index], setNames[index+1:]...) + setValues = append(setValues[0:index], setValues[index+1:]...) + } + + if !findAutoNow { + for col, info := range mi.fields.columns { + if info.autoNow { + setNames = append(setNames, col) + setValues = append(setValues, time.Now()) + } + } + } + + setValues = append(setValues, pkValue) + + Q := d.ins.TableQuote() + + sep := fmt.Sprintf("%s = ?, %s", Q, Q) + setColumns := strings.Join(setNames, sep) + + query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.table, Q, Q, setColumns, Q, Q, pkName, Q) + + d.ins.ReplaceMarks(&query) + + res, err := q.Exec(query, setValues...) + if err == nil { + return res.RowsAffected() + } + return 0, err +} + +// execute delete sql dbQuerier with given struct reflect.Value. +// delete index is pk. +func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { + var whereCols []string + var args []interface{} + // if specify cols length > 0, then use it for where condition. + if len(cols) > 0 { + var err error + whereCols = make([]string, 0, len(cols)) + args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) + if err != nil { + return 0, err + } + } else { + // default use pk value as where condtion. + pkColumn, pkValue, ok := getExistPk(mi, ind) + if !ok { + return 0, ErrMissPK + } + whereCols = []string{pkColumn} + args = append(args, pkValue) + } + + Q := d.ins.TableQuote() + + sep := fmt.Sprintf("%s = ? AND %s", Q, Q) + wheres := strings.Join(whereCols, sep) + + query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) + + d.ins.ReplaceMarks(&query) + res, err := q.Exec(query, args...) + if err == nil { + num, err := res.RowsAffected() + if err != nil { + return 0, err + } + if num > 0 { + if mi.fields.pk.auto { + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0) + } else { + ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) + } + } + err := d.deleteRels(q, mi, args, tz) + if err != nil { + return num, err + } + } + return num, err + } + return 0, err +} + +// update table-related record by querySet. +// need querySet not struct reflect.Value to update related records. +func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { + columns := make([]string, 0, len(params)) + values := make([]interface{}, 0, len(params)) + for col, val := range params { + if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol { + panic(fmt.Errorf("wrong field/column name `%s`", col)) + } else { + columns = append(columns, fi.column) + values = append(values, val) + } + } + + if len(columns) == 0 { + panic(fmt.Errorf("update params cannot empty")) + } + + tables := newDbTables(mi, d.ins) + if qs != nil { + tables.parseRelated(qs.related, qs.relDepth) + } + + where, args := tables.getCondSQL(cond, false, tz) + + values = append(values, args...) + + join := tables.getJoinSQL() + + var query, T string + + Q := d.ins.TableQuote() + + if d.ins.SupportUpdateJoin() { + T = "T0." + } + + cols := make([]string, 0, len(columns)) + + for i, v := range columns { + col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q) + if c, ok := values[i].(colValue); ok { + switch c.opt { + case ColAdd: + cols = append(cols, col+" = "+col+" + ?") + case ColMinus: + cols = append(cols, col+" = "+col+" - ?") + case ColMultiply: + cols = append(cols, col+" = "+col+" * ?") + case ColExcept: + cols = append(cols, col+" = "+col+" / ?") + case ColBitAnd: + cols = append(cols, col+" = "+col+" & ?") + case ColBitRShift: + cols = append(cols, col+" = "+col+" >> ?") + case ColBitLShift: + cols = append(cols, col+" = "+col+" << ?") + case ColBitXOR: + cols = append(cols, col+" = "+col+" ^ ?") + case ColBitOr: + cols = append(cols, col+" = "+col+" | ?") + } + values[i] = c.value + } else { + cols = append(cols, col+" = ?") + } + } + + sets := strings.Join(cols, ", ") + " " + + if d.ins.SupportUpdateJoin() { + query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET %s%s", Q, mi.table, Q, join, sets, where) + } else { + supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where) + query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery) + } + + d.ins.ReplaceMarks(&query) + var err error + var res sql.Result + if qs != nil && qs.forContext { + res, err = q.ExecContext(qs.ctx, query, values...) + } else { + res, err = q.Exec(query, values...) + } + if err == nil { + return res.RowsAffected() + } + return 0, err +} + +// delete related records. +// do UpdateBanch or DeleteBanch by condition of tables' relationship. +func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { + for _, fi := range mi.fields.fieldsReverse { + fi = fi.reverseFieldInfo + switch fi.onDelete { + case odCascade: + cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) + _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) + if err != nil { + return err + } + case odSetDefault, odSetNULL: + cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) + params := Params{fi.column: nil} + if fi.onDelete == odSetDefault { + params[fi.column] = fi.initial.String() + } + _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) + if err != nil { + return err + } + case odDoNothing: + } + } + return nil +} + +// delete table-related records. +func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { + tables := newDbTables(mi, d.ins) + tables.skipEnd = true + + if qs != nil { + tables.parseRelated(qs.related, qs.relDepth) + } + + if cond == nil || cond.IsEmpty() { + panic(fmt.Errorf("delete operation cannot execute without condition")) + } + + Q := d.ins.TableQuote() + + where, args := tables.getCondSQL(cond, false, tz) + join := tables.getJoinSQL() + + cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) + query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) + + d.ins.ReplaceMarks(&query) + + var rs *sql.Rows + r, err := q.Query(query, args...) + if err != nil { + return 0, err + } + rs = r + defer rs.Close() + + var ref interface{} + args = make([]interface{}, 0) + cnt := 0 + for rs.Next() { + if err := rs.Scan(&ref); err != nil { + return 0, err + } + pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz) + if err != nil { + return 0, err + } + args = append(args, pkValue) + cnt++ + } + + if cnt == 0 { + return 0, nil + } + + marks := make([]string, len(args)) + for i := range marks { + marks[i] = "?" + } + sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) + query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) + + d.ins.ReplaceMarks(&query) + var res sql.Result + if qs != nil && qs.forContext { + res, err = q.ExecContext(qs.ctx, query, args...) + } else { + res, err = q.Exec(query, args...) + } + if err == nil { + num, err := res.RowsAffected() + if err != nil { + return 0, err + } + if num > 0 { + err := d.deleteRels(q, mi, args, tz) + if err != nil { + return num, err + } + } + return num, nil + } + return 0, err +} + +// read related records. +func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { + + val := reflect.ValueOf(container) + ind := reflect.Indirect(val) + + errTyp := true + one := true + isPtr := true + + if val.Kind() == reflect.Ptr { + fn := "" + if ind.Kind() == reflect.Slice { + one = false + typ := ind.Type().Elem() + switch typ.Kind() { + case reflect.Ptr: + fn = getFullName(typ.Elem()) + case reflect.Struct: + isPtr = false + fn = getFullName(typ) + } + } else { + fn = getFullName(ind.Type()) + } + errTyp = fn != mi.fullName + } + + if errTyp { + if one { + panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName)) + } else { + panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName)) + } + } + + rlimit := qs.limit + offset := qs.offset + + Q := d.ins.TableQuote() + + var tCols []string + if len(cols) > 0 { + hasRel := len(qs.related) > 0 || qs.relDepth > 0 + tCols = make([]string, 0, len(cols)) + var maps map[string]bool + if hasRel { + maps = make(map[string]bool) + } + for _, col := range cols { + if fi, ok := mi.fields.GetByAny(col); ok { + tCols = append(tCols, fi.column) + if hasRel { + maps[fi.column] = true + } + } else { + return 0, fmt.Errorf("wrong field/column name `%s`", col) + } + } + if hasRel { + for _, fi := range mi.fields.fieldsDB { + if fi.fieldType&IsRelField > 0 { + if !maps[fi.column] { + tCols = append(tCols, fi.column) + } + } + } + } + } else { + tCols = mi.fields.dbcols + } + + colsNum := len(tCols) + sep := fmt.Sprintf("%s, T0.%s", Q, Q) + sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q) + + tables := newDbTables(mi, d.ins) + tables.parseRelated(qs.related, qs.relDepth) + + where, args := tables.getCondSQL(cond, false, tz) + groupBy := tables.getGroupSQL(qs.groups) + orderBy := tables.getOrderSQL(qs.orders) + limit := tables.getLimitSQL(mi, offset, rlimit) + join := tables.getJoinSQL() + + for _, tbl := range tables.tables { + if tbl.sel { + colsNum += len(tbl.mi.fields.dbcols) + sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q) + sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q) + } + } + + sqlSelect := "SELECT" + if qs.distinct { + sqlSelect += " DISTINCT" + } + query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) + + if qs.forupdate { + query += " FOR UPDATE" + } + + d.ins.ReplaceMarks(&query) + + var rs *sql.Rows + var err error + if qs != nil && qs.forContext { + rs, err = q.QueryContext(qs.ctx, query, args...) + if err != nil { + return 0, err + } + } else { + rs, err = q.Query(query, args...) + if err != nil { + return 0, err + } + } + + refs := make([]interface{}, colsNum) + for i := range refs { + var ref interface{} + refs[i] = &ref + } + + defer rs.Close() + + slice := ind + + var cnt int64 + for rs.Next() { + if one && cnt == 0 || !one { + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + elm := reflect.New(mi.addrField.Elem().Type()) + mind := reflect.Indirect(elm) + + cacheV := make(map[string]*reflect.Value) + cacheM := make(map[string]*modelInfo) + trefs := refs + + d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz) + trefs = refs[len(tCols):] + + for _, tbl := range tables.tables { + // loop selected tables + if tbl.sel { + last := mind + names := "" + mmi := mi + // loop cascade models + for _, name := range tbl.names { + names += name + if val, ok := cacheV[names]; ok { + last = *val + mmi = cacheM[names] + } else { + fi := mmi.fields.GetByName(name) + lastm := mmi + mmi = fi.relModelInfo + field := last + if last.Kind() != reflect.Invalid { + field = reflect.Indirect(last.FieldByIndex(fi.fieldIndex)) + if field.IsValid() { + d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz) + for _, fi := range mmi.fields.fieldsReverse { + if fi.inModel && fi.reverseFieldInfo.mi == lastm { + if fi.reverseFieldInfo != nil { + f := field.FieldByIndex(fi.fieldIndex) + if f.Kind() == reflect.Ptr { + f.Set(last.Addr()) + } + } + } + } + last = field + } + } + cacheV[names] = &field + cacheM[names] = mmi + } + } + trefs = trefs[len(mmi.fields.dbcols):] + } + } + + if one { + ind.Set(mind) + } else { + if cnt == 0 { + // you can use a empty & caped container list + // orm will not replace it + if ind.Len() != 0 { + // if container is not empty + // create a new one + slice = reflect.New(ind.Type()).Elem() + } + } + + if isPtr { + slice = reflect.Append(slice, mind.Addr()) + } else { + slice = reflect.Append(slice, mind) + } + } + } + cnt++ + } + + if !one { + if cnt > 0 { + ind.Set(slice) + } else { + // when a result is empty and container is nil + // to set a empty container + if ind.IsNil() { + ind.Set(reflect.MakeSlice(ind.Type(), 0, 0)) + } + } + } + + return cnt, nil +} + +// excute count sql and return count result int64. +func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { + tables := newDbTables(mi, d.ins) + tables.parseRelated(qs.related, qs.relDepth) + + where, args := tables.getCondSQL(cond, false, tz) + groupBy := tables.getGroupSQL(qs.groups) + tables.getOrderSQL(qs.orders) + join := tables.getJoinSQL() + + Q := d.ins.TableQuote() + + query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s", Q, mi.table, Q, join, where, groupBy) + + if groupBy != "" { + query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query) + } + + d.ins.ReplaceMarks(&query) + + var row *sql.Row + if qs != nil && qs.forContext { + row = q.QueryRowContext(qs.ctx, query, args...) + } else { + row = q.QueryRow(query, args...) + } + err = row.Scan(&cnt) + return +} + +// generate sql with replacing operator string placeholders and replaced values. +func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { + var sql string + params := getFlatParams(fi, args, tz) + + if len(params) == 0 { + panic(fmt.Errorf("operator `%s` need at least one args", operator)) + } + arg := params[0] + + switch operator { + case "in": + marks := make([]string, len(params)) + for i := range marks { + marks[i] = "?" + } + sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) + case "between": + if len(params) != 2 { + panic(fmt.Errorf("operator `%s` need 2 args not %d", operator, len(params))) + } + sql = "BETWEEN ? AND ?" + default: + if len(params) > 1 { + panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params))) + } + sql = d.ins.OperatorSQL(operator) + switch operator { + case "exact": + if arg == nil { + params[0] = "IS NULL" + } + case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith": + param := strings.Replace(ToStr(arg), `%`, `\%`, -1) + switch operator { + case "iexact": + case "contains", "icontains": + param = fmt.Sprintf("%%%s%%", param) + case "startswith", "istartswith": + param = fmt.Sprintf("%s%%", param) + case "endswith", "iendswith": + param = fmt.Sprintf("%%%s", param) + } + params[0] = param + case "isnull": + if b, ok := arg.(bool); ok { + if b { + sql = "IS NULL" + } else { + sql = "IS NOT NULL" + } + params = nil + } else { + panic(fmt.Errorf("operator `%s` need a bool value not `%T`", operator, arg)) + } + } + } + return sql, params +} + +// gernerate sql string with inner function, such as UPPER(text). +func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { + // default not use +} + +// set values to struct column. +func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { + for i, column := range cols { + val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() + + fi := mi.fields.GetByColumn(column) + + field := ind.FieldByIndex(fi.fieldIndex) + + value, err := d.convertValueFromDB(fi, val, tz) + if err != nil { + panic(fmt.Errorf("Raw value: `%v` %s", val, err.Error())) + } + + _, err = d.setFieldValue(fi, value, field) + + if err != nil { + panic(fmt.Errorf("Raw value: `%v` %s", val, err.Error())) + } + } +} + +// convert value from database result to value following in field type. +func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { + if val == nil { + return nil, nil + } + + var value interface{} + var tErr error + + var str *StrTo + switch v := val.(type) { + case []byte: + s := StrTo(string(v)) + str = &s + case string: + s := StrTo(v) + str = &s + } + + fieldType := fi.fieldType + +setValue: + switch { + case fieldType == TypeBooleanField: + if str == nil { + switch v := val.(type) { + case int64: + b := v == 1 + value = b + default: + s := StrTo(ToStr(v)) + str = &s + } + } + if str != nil { + b, err := str.Bool() + if err != nil { + tErr = err + goto end + } + value = b + } + case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: + if str == nil { + value = ToStr(val) + } else { + value = str.String() + } + case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: + if str == nil { + switch t := val.(type) { + case time.Time: + d.ins.TimeFromDB(&t, tz) + value = t + default: + s := StrTo(ToStr(t)) + str = &s + } + } + if str != nil { + s := str.String() + var ( + t time.Time + err error + ) + if len(s) >= 19 { + s = s[:19] + t, err = time.ParseInLocation(formatDateTime, s, tz) + } else if len(s) >= 10 { + if len(s) > 10 { + s = s[:10] + } + t, err = time.ParseInLocation(formatDate, s, tz) + } else if len(s) >= 8 { + if len(s) > 8 { + s = s[:8] + } + t, err = time.ParseInLocation(formatTime, s, tz) + } + t = t.In(DefaultTimeLoc) + + if err != nil && s != "00:00:00" && s != "0000-00-00" && s != "0000-00-00 00:00:00" { + tErr = err + goto end + } + value = t + } + case fieldType&IsIntegerField > 0: + if str == nil { + s := StrTo(ToStr(val)) + str = &s + } + if str != nil { + var err error + switch fieldType { + case TypeBitField: + _, err = str.Int8() + case TypeSmallIntegerField: + _, err = str.Int16() + case TypeIntegerField: + _, err = str.Int32() + case TypeBigIntegerField: + _, err = str.Int64() + case TypePositiveBitField: + _, err = str.Uint8() + case TypePositiveSmallIntegerField: + _, err = str.Uint16() + case TypePositiveIntegerField: + _, err = str.Uint32() + case TypePositiveBigIntegerField: + _, err = str.Uint64() + } + if err != nil { + tErr = err + goto end + } + if fieldType&IsPositiveIntegerField > 0 { + v, _ := str.Uint64() + value = v + } else { + v, _ := str.Int64() + value = v + } + } + case fieldType == TypeFloatField || fieldType == TypeDecimalField: + if str == nil { + switch v := val.(type) { + case float64: + value = v + default: + s := StrTo(ToStr(v)) + str = &s + } + } + if str != nil { + v, err := str.Float64() + if err != nil { + tErr = err + goto end + } + value = v + } + case fieldType&IsRelField > 0: + fi = fi.relModelInfo.fields.pk + fieldType = fi.fieldType + goto setValue + } + +end: + if tErr != nil { + err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.addrValue.Type(), fi.fullName, tErr) + return nil, err + } + + return value, nil + +} + +// set one value to struct column field. +func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { + + fieldType := fi.fieldType + isNative := !fi.isFielder + +setValue: + switch { + case fieldType == TypeBooleanField: + if isNative { + if nb, ok := field.Interface().(sql.NullBool); ok { + if value == nil { + nb.Valid = false + } else { + nb.Bool = value.(bool) + nb.Valid = true + } + field.Set(reflect.ValueOf(nb)) + } else if field.Kind() == reflect.Ptr { + if value != nil { + v := value.(bool) + field.Set(reflect.ValueOf(&v)) + } + } else { + if value == nil { + value = false + } + field.SetBool(value.(bool)) + } + } + case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: + if isNative { + if ns, ok := field.Interface().(sql.NullString); ok { + if value == nil { + ns.Valid = false + } else { + ns.String = value.(string) + ns.Valid = true + } + field.Set(reflect.ValueOf(ns)) + } else if field.Kind() == reflect.Ptr { + if value != nil { + v := value.(string) + field.Set(reflect.ValueOf(&v)) + } + } else { + if value == nil { + value = "" + } + field.SetString(value.(string)) + } + } + case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: + if isNative { + if value == nil { + value = time.Time{} + } else if field.Kind() == reflect.Ptr { + if value != nil { + v := value.(time.Time) + field.Set(reflect.ValueOf(&v)) + } + } else { + field.Set(reflect.ValueOf(value)) + } + } + case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr: + if value != nil { + v := uint8(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypePositiveSmallIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := uint16(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypePositiveIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + if field.Type() == reflect.TypeOf(new(uint)) { + v := uint(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } else { + v := uint32(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } + } + case fieldType == TypePositiveBigIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := value.(uint64) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypeBitField && field.Kind() == reflect.Ptr: + if value != nil { + v := int8(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypeSmallIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := int16(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypeIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + if field.Type() == reflect.TypeOf(new(int)) { + v := int(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } else { + v := int32(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } + } + case fieldType == TypeBigIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := value.(int64) + field.Set(reflect.ValueOf(&v)) + } + case fieldType&IsIntegerField > 0: + if fieldType&IsPositiveIntegerField > 0 { + if isNative { + if value == nil { + value = uint64(0) + } + field.SetUint(value.(uint64)) + } + } else { + if isNative { + if ni, ok := field.Interface().(sql.NullInt64); ok { + if value == nil { + ni.Valid = false + } else { + ni.Int64 = value.(int64) + ni.Valid = true + } + field.Set(reflect.ValueOf(ni)) + } else { + if value == nil { + value = int64(0) + } + field.SetInt(value.(int64)) + } + } + } + case fieldType == TypeFloatField || fieldType == TypeDecimalField: + if isNative { + if nf, ok := field.Interface().(sql.NullFloat64); ok { + if value == nil { + nf.Valid = false + } else { + nf.Float64 = value.(float64) + nf.Valid = true + } + field.Set(reflect.ValueOf(nf)) + } else if field.Kind() == reflect.Ptr { + if value != nil { + if field.Type() == reflect.TypeOf(new(float32)) { + v := float32(value.(float64)) + field.Set(reflect.ValueOf(&v)) + } else { + v := value.(float64) + field.Set(reflect.ValueOf(&v)) + } + } + } else { + + if value == nil { + value = float64(0) + } + field.SetFloat(value.(float64)) + } + } + case fieldType&IsRelField > 0: + if value != nil { + fieldType = fi.relModelInfo.fields.pk.fieldType + mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + field.Set(mf) + f := mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + field = f + goto setValue + } + } + + if !isNative { + fd := field.Addr().Interface().(Fielder) + err := fd.SetRaw(value) + if err != nil { + err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.fullName, err) + return nil, err + } + } + + return value, nil +} + +// query sql, read values , save to *[]ParamList. +func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { + + var ( + maps []Params + lists []ParamsList + list ParamsList + ) + + typ := 0 + switch v := container.(type) { + case *[]Params: + d := *v + if len(d) == 0 { + maps = d + } + typ = 1 + case *[]ParamsList: + d := *v + if len(d) == 0 { + lists = d + } + typ = 2 + case *ParamsList: + d := *v + if len(d) == 0 { + list = d + } + typ = 3 + default: + panic(fmt.Errorf("unsupport read values type `%T`", container)) + } + + tables := newDbTables(mi, d.ins) + + var ( + cols []string + infos []*fieldInfo + ) + + hasExprs := len(exprs) > 0 + + Q := d.ins.TableQuote() + + if hasExprs { + cols = make([]string, 0, len(exprs)) + infos = make([]*fieldInfo, 0, len(exprs)) + for _, ex := range exprs { + index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", ex)) + } + cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) + infos = append(infos, fi) + } + } else { + cols = make([]string, 0, len(mi.fields.dbcols)) + infos = make([]*fieldInfo, 0, len(exprs)) + for _, fi := range mi.fields.fieldsDB { + cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.column, Q, Q, fi.name, Q)) + infos = append(infos, fi) + } + } + + where, args := tables.getCondSQL(cond, false, tz) + groupBy := tables.getGroupSQL(qs.groups) + orderBy := tables.getOrderSQL(qs.orders) + limit := tables.getLimitSQL(mi, qs.offset, qs.limit) + join := tables.getJoinSQL() + + sels := strings.Join(cols, ", ") + + sqlSelect := "SELECT" + if qs.distinct { + sqlSelect += " DISTINCT" + } + query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) + + d.ins.ReplaceMarks(&query) + + rs, err := q.Query(query, args...) + if err != nil { + return 0, err + } + refs := make([]interface{}, len(cols)) + for i := range refs { + var ref interface{} + refs[i] = &ref + } + + defer rs.Close() + + var ( + cnt int64 + columns []string + ) + for rs.Next() { + if cnt == 0 { + cols, err := rs.Columns() + if err != nil { + return 0, err + } + columns = cols + } + + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + switch typ { + case 1: + params := make(Params, len(cols)) + for i, ref := range refs { + fi := infos[i] + + val := reflect.Indirect(reflect.ValueOf(ref)).Interface() + + value, err := d.convertValueFromDB(fi, val, tz) + if err != nil { + panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) + } + + params[columns[i]] = value + } + maps = append(maps, params) + case 2: + params := make(ParamsList, 0, len(cols)) + for i, ref := range refs { + fi := infos[i] + + val := reflect.Indirect(reflect.ValueOf(ref)).Interface() + + value, err := d.convertValueFromDB(fi, val, tz) + if err != nil { + panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) + } + + params = append(params, value) + } + lists = append(lists, params) + case 3: + for i, ref := range refs { + fi := infos[i] + + val := reflect.Indirect(reflect.ValueOf(ref)).Interface() + + value, err := d.convertValueFromDB(fi, val, tz) + if err != nil { + panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) + } + + list = append(list, value) + } + } + + cnt++ + } + + switch v := container.(type) { + case *[]Params: + *v = maps + case *[]ParamsList: + *v = lists + case *ParamsList: + *v = list + } + + return cnt, nil +} + +func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) { + return 0, nil +} + +// flag of update joined record. +func (d *dbBase) SupportUpdateJoin() bool { + return true +} + +func (d *dbBase) MaxLimit() uint64 { + return 18446744073709551615 +} + +// return quote. +func (d *dbBase) TableQuote() string { + return "`" +} + +// replace value placeholder in parametered sql string. +func (d *dbBase) ReplaceMarks(query *string) { + // default use `?` as mark, do nothing +} + +// flag of RETURNING sql. +func (d *dbBase) HasReturningID(*modelInfo, *string) bool { + return false +} + +// sync auto key +func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { + return nil +} + +// convert time from db. +func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { + *t = t.In(tz) +} + +// convert time to db. +func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { + *t = t.In(tz) +} + +// get database types. +func (d *dbBase) DbTypes() map[string]string { + return nil +} + +// gt all tables. +func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { + tables := make(map[string]bool) + query := d.ins.ShowTablesQuery() + rows, err := db.Query(query) + if err != nil { + return tables, err + } + + defer rows.Close() + + for rows.Next() { + var table string + err := rows.Scan(&table) + if err != nil { + return tables, err + } + if table != "" { + tables[table] = true + } + } + + return tables, nil +} + +// get all cloumns in table. +func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { + columns := make(map[string][3]string) + query := d.ins.ShowColumnsQuery(table) + rows, err := db.Query(query) + if err != nil { + return columns, err + } + + defer rows.Close() + + for rows.Next() { + var ( + name string + typ string + null string + ) + err := rows.Scan(&name, &typ, &null) + if err != nil { + return columns, err + } + columns[name] = [3]string{name, typ, null} + } + + return columns, nil +} + +// not implement. +func (d *dbBase) OperatorSQL(operator string) string { + panic(ErrNotImplement) +} + +// not implement. +func (d *dbBase) ShowTablesQuery() string { + panic(ErrNotImplement) +} + +// not implement. +func (d *dbBase) ShowColumnsQuery(table string) string { + panic(ErrNotImplement) +} + +// not implement. +func (d *dbBase) IndexExists(dbQuerier, string, string) bool { + panic(ErrNotImplement) +} diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go new file mode 100644 index 00000000..bf6c350c --- /dev/null +++ b/pkg/orm/db_alias.go @@ -0,0 +1,466 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "fmt" + lru "github.com/hashicorp/golang-lru" + "reflect" + "sync" + "time" +) + +// DriverType database driver constant int. +type DriverType int + +// Enum the Database driver +const ( + _ DriverType = iota // int enum type + DRMySQL // mysql + DRSqlite // sqlite + DROracle // oracle + DRPostgres // pgsql + DRTiDB // TiDB +) + +// database driver string. +type driver string + +// get type constant int of current driver.. +func (d driver) Type() DriverType { + a, _ := dataBaseCache.get(string(d)) + return a.Driver +} + +// get name of current driver +func (d driver) Name() string { + return string(d) +} + +// check driver iis implemented Driver interface or not. +var _ Driver = new(driver) + +var ( + dataBaseCache = &_dbCache{cache: make(map[string]*alias)} + drivers = map[string]DriverType{ + "mysql": DRMySQL, + "postgres": DRPostgres, + "sqlite3": DRSqlite, + "tidb": DRTiDB, + "oracle": DROracle, + "oci8": DROracle, // github.com/mattn/go-oci8 + "ora": DROracle, //https://github.com/rana/ora + } + dbBasers = map[DriverType]dbBaser{ + DRMySQL: newdbBaseMysql(), + DRSqlite: newdbBaseSqlite(), + DROracle: newdbBaseOracle(), + DRPostgres: newdbBasePostgres(), + DRTiDB: newdbBaseTidb(), + } +) + +// database alias cacher. +type _dbCache struct { + mux sync.RWMutex + cache map[string]*alias +} + +// add database alias with original name. +func (ac *_dbCache) add(name string, al *alias) (added bool) { + ac.mux.Lock() + defer ac.mux.Unlock() + if _, ok := ac.cache[name]; !ok { + ac.cache[name] = al + added = true + } + return +} + +// get database alias if cached. +func (ac *_dbCache) get(name string) (al *alias, ok bool) { + ac.mux.RLock() + defer ac.mux.RUnlock() + al, ok = ac.cache[name] + return +} + +// get default alias. +func (ac *_dbCache) getDefault() (al *alias) { + al, _ = ac.get("default") + return +} + +type DB struct { + *sync.RWMutex + DB *sql.DB + stmtDecorators *lru.Cache +} + +func (d *DB) Begin() (*sql.Tx, error) { + return d.DB.Begin() +} + +func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return d.DB.BeginTx(ctx, opts) +} + +//su must call release to release *sql.Stmt after using +func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { + d.RLock() + c, ok := d.stmtDecorators.Get(query) + if ok { + c.(*stmtDecorator).acquire() + d.RUnlock() + return c.(*stmtDecorator), nil + } + d.RUnlock() + + d.Lock() + c, ok = d.stmtDecorators.Get(query) + if ok { + c.(*stmtDecorator).acquire() + d.Unlock() + return c.(*stmtDecorator), nil + } + + stmt, err := d.Prepare(query) + if err != nil { + d.Unlock() + return nil, err + } + sd := newStmtDecorator(stmt) + sd.acquire() + d.stmtDecorators.Add(query, sd) + d.Unlock() + + return sd, nil +} + +func (d *DB) Prepare(query string) (*sql.Stmt, error) { + return d.DB.Prepare(query) +} + +func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return d.DB.PrepareContext(ctx, query) +} + +func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + sd, err := d.getStmtDecorator(query) + if err != nil { + return nil, err + } + stmt := sd.getStmt() + defer sd.release() + return stmt.Exec(args...) +} + +func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + sd, err := d.getStmtDecorator(query) + if err != nil { + return nil, err + } + stmt := sd.getStmt() + defer sd.release() + return stmt.ExecContext(ctx, args...) +} + +func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { + sd, err := d.getStmtDecorator(query) + if err != nil { + return nil, err + } + stmt := sd.getStmt() + defer sd.release() + return stmt.Query(args...) +} + +func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + sd, err := d.getStmtDecorator(query) + if err != nil { + return nil, err + } + stmt := sd.getStmt() + defer sd.release() + return stmt.QueryContext(ctx, args...) +} + +func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { + sd, err := d.getStmtDecorator(query) + if err != nil { + panic(err) + } + stmt := sd.getStmt() + defer sd.release() + return stmt.QueryRow(args...) + +} + +func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + sd, err := d.getStmtDecorator(query) + if err != nil { + panic(err) + } + stmt := sd.getStmt() + defer sd.release() + return stmt.QueryRowContext(ctx, args) +} + +type alias struct { + Name string + Driver DriverType + DriverName string + DataSource string + MaxIdleConns int + MaxOpenConns int + DB *DB + DbBaser dbBaser + TZ *time.Location + Engine string +} + +func detectTZ(al *alias) { + // orm timezone system match database + // default use Local + al.TZ = DefaultTimeLoc + + if al.DriverName == "sphinx" { + return + } + + switch al.Driver { + case DRMySQL: + row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") + var tz string + row.Scan(&tz) + if len(tz) >= 8 { + if tz[0] != '-' { + tz = "+" + tz + } + t, err := time.Parse("-07:00:00", tz) + if err == nil { + if t.Location().String() != "" { + al.TZ = t.Location() + } + } else { + DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) + } + } + + // get default engine from current database + row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'") + var engine string + var tx bool + row.Scan(&engine, &tx) + + if engine != "" { + al.Engine = engine + } else { + al.Engine = "INNODB" + } + + case DRSqlite, DROracle: + al.TZ = time.UTC + + case DRPostgres: + row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") + var tz string + row.Scan(&tz) + loc, err := time.LoadLocation(tz) + if err == nil { + al.TZ = loc + } else { + DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) + } + } +} + +func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { + al := new(alias) + al.Name = aliasName + al.DriverName = driverName + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: newStmtDecoratorLruWithEvict(), + } + + if dr, ok := drivers[driverName]; ok { + al.DbBaser = dbBasers[dr] + al.Driver = dr + } else { + return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + } + + err := db.Ping() + if err != nil { + return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) + } + + if !dataBaseCache.add(aliasName, al) { + return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) + } + + return al, nil +} + +// AddAliasWthDB add a aliasName for the drivename +func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { + _, err := addAliasWthDB(aliasName, driverName, db) + return err +} + +// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. +func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { + var ( + err error + db *sql.DB + al *alias + ) + + db, err = sql.Open(driverName, dataSource) + if err != nil { + err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) + goto end + } + + al, err = addAliasWthDB(aliasName, driverName, db) + if err != nil { + goto end + } + + al.DataSource = dataSource + + detectTZ(al) + + for i, v := range params { + switch i { + case 0: + SetMaxIdleConns(al.Name, v) + case 1: + SetMaxOpenConns(al.Name, v) + } + } + +end: + if err != nil { + if db != nil { + db.Close() + } + DebugLog.Println(err.Error()) + } + + return err +} + +// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. +func RegisterDriver(driverName string, typ DriverType) error { + if t, ok := drivers[driverName]; !ok { + drivers[driverName] = typ + } else { + if t != typ { + return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName) + } + } + return nil +} + +// SetDataBaseTZ Change the database default used timezone +func SetDataBaseTZ(aliasName string, tz *time.Location) error { + if al, ok := dataBaseCache.get(aliasName); ok { + al.TZ = tz + } else { + return fmt.Errorf("DataBase alias name `%s` not registered", aliasName) + } + return nil +} + +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +func SetMaxIdleConns(aliasName string, maxIdleConns int) { + al := getDbAlias(aliasName) + al.MaxIdleConns = maxIdleConns + al.DB.DB.SetMaxIdleConns(maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +func SetMaxOpenConns(aliasName string, maxOpenConns int) { + al := getDbAlias(aliasName) + al.MaxOpenConns = maxOpenConns + al.DB.DB.SetMaxOpenConns(maxOpenConns) + // for tip go 1.2 + if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() { + fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) + } +} + +// GetDB Get *sql.DB from registered database by db alias name. +// Use "default" as alias name if you not set. +func GetDB(aliasNames ...string) (*sql.DB, error) { + var name string + if len(aliasNames) > 0 { + name = aliasNames[0] + } else { + name = "default" + } + al, ok := dataBaseCache.get(name) + if ok { + return al.DB.DB, nil + } + return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) +} + +type stmtDecorator struct { + wg sync.WaitGroup + stmt *sql.Stmt +} + +func (s *stmtDecorator) getStmt() *sql.Stmt { + return s.stmt +} + +// acquire will add one +// since this method will be used inside read lock scope, +// so we can not do more things here +// we should think about refactor this +func (s *stmtDecorator) acquire() { + s.wg.Add(1) +} + +func (s *stmtDecorator) release() { + s.wg.Done() +} + +//garbage recycle for stmt +func (s *stmtDecorator) destroy() { + go func() { + s.wg.Wait() + _ = s.stmt.Close() + }() +} + +func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { + return &stmtDecorator{ + stmt: sqlStmt, + } +} + +func newStmtDecoratorLruWithEvict() *lru.Cache { + cache, _ := lru.NewWithEvict(1000, func(key interface{}, value interface{}) { + value.(*stmtDecorator).destroy() + }) + return cache +} diff --git a/pkg/orm/db_mysql.go b/pkg/orm/db_mysql.go new file mode 100644 index 00000000..6e99058e --- /dev/null +++ b/pkg/orm/db_mysql.go @@ -0,0 +1,183 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "reflect" + "strings" +) + +// mysql operators. +var mysqlOperators = map[string]string{ + "exact": "= ?", + "iexact": "LIKE ?", + "contains": "LIKE BINARY ?", + "icontains": "LIKE ?", + // "regex": "REGEXP BINARY ?", + // "iregex": "REGEXP ?", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "eq": "= ?", + "ne": "!= ?", + "startswith": "LIKE BINARY ?", + "endswith": "LIKE BINARY ?", + "istartswith": "LIKE ?", + "iendswith": "LIKE ?", +} + +// mysql column field types. +var mysqlTypes = map[string]string{ + "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "char(%d)", + "string-text": "longtext", + "time.Time-date": "date", + "time.Time": "datetime", + "int8": "tinyint", + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": "tinyint unsigned", + "uint16": "smallint unsigned", + "uint32": "integer unsigned", + "uint64": "bigint unsigned", + "float64": "double precision", + "float64-decimal": "numeric(%d, %d)", +} + +// mysql dbBaser implementation. +type dbBaseMysql struct { + dbBase +} + +var _ dbBaser = new(dbBaseMysql) + +// get mysql operator. +func (d *dbBaseMysql) OperatorSQL(operator string) string { + return mysqlOperators[operator] +} + +// get mysql table field types. +func (d *dbBaseMysql) DbTypes() map[string]string { + return mysqlTypes +} + +// show table sql for mysql. +func (d *dbBaseMysql) ShowTablesQuery() string { + return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" +} + +// show columns sql of table for mysql. +func (d *dbBaseMysql) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ + "WHERE table_schema = DATABASE() AND table_name = '%s'", table) +} + +// execute sql to check index exist. +func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { + row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ + "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + +// InsertOrUpdate a row +// If your primary key or unique column conflict will update +// If no will insert +// Add "`" for mysql sql building +func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { + var iouStr string + argsMap := map[string]string{} + + iouStr = "ON DUPLICATE KEY UPDATE" + + //Get on the key-value pairs + for _, v := range args { + kv := strings.Split(v, "=") + if len(kv) == 2 { + argsMap[strings.ToLower(kv[0])] = kv[1] + } + } + + isMulti := false + names := make([]string, 0, len(mi.fields.dbcols)-1) + Q := d.ins.TableQuote() + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) + + if err != nil { + return 0, err + } + + marks := make([]string, len(names)) + updateValues := make([]interface{}, 0) + updates := make([]string, len(names)) + + for i, v := range names { + marks[i] = "?" + valueStr := argsMap[strings.ToLower(v)] + if valueStr != "" { + updates[i] = "`" + v + "`" + "=" + valueStr + } else { + updates[i] = "`" + v + "`" + "=?" + updateValues = append(updateValues, values[i]) + } + } + + values = append(values, updateValues...) + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + qupdates := strings.Join(updates, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + //conflitValue maybe is a int,can`t use fmt.Sprintf + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + + row := q.QueryRow(query, values...) + var id int64 + err = row.Scan(&id) + return id, err +} + +// create new mysql dbBaser. +func newdbBaseMysql() dbBaser { + b := new(dbBaseMysql) + b.ins = b + return b +} diff --git a/pkg/orm/db_oracle.go b/pkg/orm/db_oracle.go new file mode 100644 index 00000000..5d121f83 --- /dev/null +++ b/pkg/orm/db_oracle.go @@ -0,0 +1,137 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strings" +) + +// oracle operators. +var oracleOperators = map[string]string{ + "exact": "= ?", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "//iendswith": "LIKE ?", +} + +// oracle column field types. +var oracleTypes = map[string]string{ + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "VARCHAR2(%d)", + "string-char": "CHAR(%d)", + "string-text": "VARCHAR2(%d)", + "time.Time-date": "DATE", + "time.Time": "TIMESTAMP", + "int8": "INTEGER", + "int16": "INTEGER", + "int32": "INTEGER", + "int64": "INTEGER", + "uint8": "INTEGER", + "uint16": "INTEGER", + "uint32": "INTEGER", + "uint64": "INTEGER", + "float64": "NUMBER", + "float64-decimal": "NUMBER(%d, %d)", +} + +// oracle dbBaser +type dbBaseOracle struct { + dbBase +} + +var _ dbBaser = new(dbBaseOracle) + +// create oracle dbBaser. +func newdbBaseOracle() dbBaser { + b := new(dbBaseOracle) + b.ins = b + return b +} + +// OperatorSQL get oracle operator. +func (d *dbBaseOracle) OperatorSQL(operator string) string { + return oracleOperators[operator] +} + +// DbTypes get oracle table field types. +func (d *dbBaseOracle) DbTypes() map[string]string { + return oracleTypes +} + +//ShowTablesQuery show all the tables in database +func (d *dbBaseOracle) ShowTablesQuery() string { + return "SELECT TABLE_NAME FROM USER_TABLES" +} + +// Oracle +func (d *dbBaseOracle) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT COLUMN_NAME FROM ALL_TAB_COLUMNS "+ + "WHERE TABLE_NAME ='%s'", strings.ToUpper(table)) +} + +// check index is exist +func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { + row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ + "WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ + "AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) + + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + +// execute insert sql with given struct and given values. +// insert the given values, not the field values in struct. +func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { + Q := d.ins.TableQuote() + + marks := make([]string, len(names)) + for i := range marks { + marks[i] = ":" + names[i] + } + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + row := q.QueryRow(query, values...) + var id int64 + err := row.Scan(&id) + return id, err +} diff --git a/pkg/orm/db_postgres.go b/pkg/orm/db_postgres.go new file mode 100644 index 00000000..c488fb38 --- /dev/null +++ b/pkg/orm/db_postgres.go @@ -0,0 +1,189 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strconv" +) + +// postgresql operators. +var postgresOperators = map[string]string{ + "exact": "= ?", + "iexact": "= UPPER(?)", + "contains": "LIKE ?", + "icontains": "LIKE UPPER(?)", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "eq": "= ?", + "ne": "!= ?", + "startswith": "LIKE ?", + "endswith": "LIKE ?", + "istartswith": "LIKE UPPER(?)", + "iendswith": "LIKE UPPER(?)", +} + +// postgresql column field types. +var postgresTypes = map[string]string{ + "auto": "serial NOT NULL PRIMARY KEY", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "char(%d)", + "string-text": "text", + "time.Time-date": "date", + "time.Time": "timestamp with time zone", + "int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`, + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`, + "uint16": `integer CHECK("%COL%" >= 0)`, + "uint32": `bigint CHECK("%COL%" >= 0)`, + "uint64": `bigint CHECK("%COL%" >= 0)`, + "float64": "double precision", + "float64-decimal": "numeric(%d, %d)", + "json": "json", + "jsonb": "jsonb", +} + +// postgresql dbBaser. +type dbBasePostgres struct { + dbBase +} + +var _ dbBaser = new(dbBasePostgres) + +// get postgresql operator. +func (d *dbBasePostgres) OperatorSQL(operator string) string { + return postgresOperators[operator] +} + +// generate functioned sql string, such as contains(text). +func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { + switch operator { + case "contains", "startswith", "endswith": + *leftCol = fmt.Sprintf("%s::text", *leftCol) + case "iexact", "icontains", "istartswith", "iendswith": + *leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol) + } +} + +// postgresql unsupports updating joined record. +func (d *dbBasePostgres) SupportUpdateJoin() bool { + return false +} + +func (d *dbBasePostgres) MaxLimit() uint64 { + return 0 +} + +// postgresql quote is ". +func (d *dbBasePostgres) TableQuote() string { + return `"` +} + +// postgresql value placeholder is $n. +// replace default ? to $n. +func (d *dbBasePostgres) ReplaceMarks(query *string) { + q := *query + num := 0 + for _, c := range q { + if c == '?' { + num++ + } + } + if num == 0 { + return + } + data := make([]byte, 0, len(q)+num) + num = 1 + for i := 0; i < len(q); i++ { + c := q[i] + if c == '?' { + data = append(data, '$') + data = append(data, []byte(strconv.Itoa(num))...) + num++ + } else { + data = append(data, c) + } + } + *query = string(data) +} + +// make returning sql support for postgresql. +func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { + fi := mi.fields.pk + if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 { + return false + } + + if query != nil { + *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column) + } + return true +} + +// sync auto key +func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { + if len(autoFields) == 0 { + return nil + } + + Q := d.ins.TableQuote() + for _, name := range autoFields { + query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));", + mi.table, name, + Q, name, Q, + Q, mi.table, Q) + if _, err := db.Exec(query); err != nil { + return err + } + } + return nil +} + +// show table sql for postgresql. +func (d *dbBasePostgres) ShowTablesQuery() string { + return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" +} + +// show table columns sql for postgresql. +func (d *dbBasePostgres) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) +} + +// get column types of postgresql. +func (d *dbBasePostgres) DbTypes() map[string]string { + return postgresTypes +} + +// check index exist in postgresql. +func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { + query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) + row := db.QueryRow(query) + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + +// create new postgresql dbBaser. +func newdbBasePostgres() dbBaser { + b := new(dbBasePostgres) + b.ins = b + return b +} diff --git a/pkg/orm/db_sqlite.go b/pkg/orm/db_sqlite.go new file mode 100644 index 00000000..1d62ee34 --- /dev/null +++ b/pkg/orm/db_sqlite.go @@ -0,0 +1,161 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "fmt" + "reflect" + "time" +) + +// sqlite operators. +var sqliteOperators = map[string]string{ + "exact": "= ?", + "iexact": "LIKE ? ESCAPE '\\'", + "contains": "LIKE ? ESCAPE '\\'", + "icontains": "LIKE ? ESCAPE '\\'", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "eq": "= ?", + "ne": "!= ?", + "startswith": "LIKE ? ESCAPE '\\'", + "endswith": "LIKE ? ESCAPE '\\'", + "istartswith": "LIKE ? ESCAPE '\\'", + "iendswith": "LIKE ? ESCAPE '\\'", +} + +// sqlite column types. +var sqliteTypes = map[string]string{ + "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "character(%d)", + "string-text": "text", + "time.Time-date": "date", + "time.Time": "datetime", + "int8": "tinyint", + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": "tinyint unsigned", + "uint16": "smallint unsigned", + "uint32": "integer unsigned", + "uint64": "bigint unsigned", + "float64": "real", + "float64-decimal": "decimal", +} + +// sqlite dbBaser. +type dbBaseSqlite struct { + dbBase +} + +var _ dbBaser = new(dbBaseSqlite) + +// override base db read for update behavior as SQlite does not support syntax +func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { + if isForUpdate { + DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") + } + return d.dbBase.Read(q, mi, ind, tz, cols, false) +} + +// get sqlite operator. +func (d *dbBaseSqlite) OperatorSQL(operator string) string { + return sqliteOperators[operator] +} + +// generate functioned sql for sqlite. +// only support DATE(text). +func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { + if fi.fieldType == TypeDateField { + *leftCol = fmt.Sprintf("DATE(%s)", *leftCol) + } +} + +// unable updating joined record in sqlite. +func (d *dbBaseSqlite) SupportUpdateJoin() bool { + return false +} + +// max int in sqlite. +func (d *dbBaseSqlite) MaxLimit() uint64 { + return 9223372036854775807 +} + +// get column types in sqlite. +func (d *dbBaseSqlite) DbTypes() map[string]string { + return sqliteTypes +} + +// get show tables sql in sqlite. +func (d *dbBaseSqlite) ShowTablesQuery() string { + return "SELECT name FROM sqlite_master WHERE type = 'table'" +} + +// get columns in sqlite. +func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { + query := d.ins.ShowColumnsQuery(table) + rows, err := db.Query(query) + if err != nil { + return nil, err + } + + columns := make(map[string][3]string) + for rows.Next() { + var tmp, name, typ, null sql.NullString + err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp) + if err != nil { + return nil, err + } + columns[name.String] = [3]string{name.String, typ.String, null.String} + } + + return columns, nil +} + +// get show columns sql in sqlite. +func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { + return fmt.Sprintf("pragma table_info('%s')", table) +} + +// check index exist in sqlite. +func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { + query := fmt.Sprintf("PRAGMA index_list('%s')", table) + rows, err := db.Query(query) + if err != nil { + panic(err) + } + defer rows.Close() + for rows.Next() { + var tmp, index sql.NullString + rows.Scan(&tmp, &index, &tmp, &tmp, &tmp) + if name == index.String { + return true + } + } + return false +} + +// create new sqlite dbBaser. +func newdbBaseSqlite() dbBaser { + b := new(dbBaseSqlite) + b.ins = b + return b +} diff --git a/pkg/orm/db_tables.go b/pkg/orm/db_tables.go new file mode 100644 index 00000000..4b21a6fc --- /dev/null +++ b/pkg/orm/db_tables.go @@ -0,0 +1,482 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strings" + "time" +) + +// table info struct. +type dbTable struct { + id int + index string + name string + names []string + sel bool + inner bool + mi *modelInfo + fi *fieldInfo + jtl *dbTable +} + +// tables collection struct, contains some tables. +type dbTables struct { + tablesM map[string]*dbTable + tables []*dbTable + mi *modelInfo + base dbBaser + skipEnd bool +} + +// set table info to collection. +// if not exist, create new. +func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { + name := strings.Join(names, ExprSep) + if j, ok := t.tablesM[name]; ok { + j.name = name + j.mi = mi + j.fi = fi + j.inner = inner + } else { + i := len(t.tables) + 1 + jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} + t.tablesM[name] = jt + t.tables = append(t.tables, jt) + } + return t.tablesM[name] +} + +// add table info to collection. +func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { + name := strings.Join(names, ExprSep) + if _, ok := t.tablesM[name]; !ok { + i := len(t.tables) + 1 + jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} + t.tablesM[name] = jt + t.tables = append(t.tables, jt) + return jt, true + } + return t.tablesM[name], false +} + +// get table info in collection. +func (t *dbTables) get(name string) (*dbTable, bool) { + j, ok := t.tablesM[name] + return j, ok +} + +// get related fields info in recursive depth loop. +// loop once, depth decreases one. +func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { + if depth < 0 || fi.fieldType == RelManyToMany { + return related + } + + if prefix == "" { + prefix = fi.name + } else { + prefix = prefix + ExprSep + fi.name + } + related = append(related, prefix) + + depth-- + for _, fi := range fi.relModelInfo.fields.fieldsRel { + related = t.loopDepth(depth, prefix, fi, related) + } + + return related +} + +// parse related fields. +func (t *dbTables) parseRelated(rels []string, depth int) { + + relsNum := len(rels) + related := make([]string, relsNum) + copy(related, rels) + + relDepth := depth + + if relsNum != 0 { + relDepth = 0 + } + + relDepth-- + for _, fi := range t.mi.fields.fieldsRel { + related = t.loopDepth(relDepth, "", fi, related) + } + + for i, s := range related { + var ( + exs = strings.Split(s, ExprSep) + names = make([]string, 0, len(exs)) + mmi = t.mi + cancel = true + jtl *dbTable + ) + + inner := true + + for _, ex := range exs { + if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { + names = append(names, fi.name) + mmi = fi.relModelInfo + + if fi.null || t.skipEnd { + inner = false + } + + jt := t.set(names, mmi, fi, inner) + jt.jtl = jtl + + if fi.reverse { + cancel = false + } + + if cancel { + jt.sel = depth > 0 + + if i < relsNum { + jt.sel = true + } + } + + jtl = jt + + } else { + panic(fmt.Errorf("unknown model/table name `%s`", ex)) + } + } + } +} + +// generate join string. +func (t *dbTables) getJoinSQL() (join string) { + Q := t.base.TableQuote() + + for _, jt := range t.tables { + if jt.inner { + join += "INNER JOIN " + } else { + join += "LEFT OUTER JOIN " + } + var ( + table string + t1, t2 string + c1, c2 string + ) + t1 = "T0" + if jt.jtl != nil { + t1 = jt.jtl.index + } + t2 = jt.index + table = jt.mi.table + + switch { + case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: + c1 = jt.fi.mi.fields.pk.column + for _, ffi := range jt.mi.fields.fieldsRel { + if jt.fi.mi == ffi.relModelInfo { + c2 = ffi.column + break + } + } + default: + c1 = jt.fi.column + c2 = jt.fi.relModelInfo.fields.pk.column + + if jt.fi.reverse { + c1 = jt.mi.fields.pk.column + c2 = jt.fi.reverseFieldInfo.column + } + } + + join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2, + t2, Q, c2, Q, t1, Q, c1, Q) + } + return +} + +// parse orm model struct field tag expression. +func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { + var ( + jtl *dbTable + fi *fieldInfo + fiN *fieldInfo + mmi = mi + ) + + num := len(exprs) - 1 + var names []string + + inner := true + +loopFor: + for i, ex := range exprs { + + var ok, okN bool + + if fiN != nil { + fi = fiN + ok = true + fiN = nil + } + + if i == 0 { + fi, ok = mmi.fields.GetByAny(ex) + } + + _ = okN + + if ok { + + isRel := fi.rel || fi.reverse + + names = append(names, fi.name) + + switch { + case fi.rel: + mmi = fi.relModelInfo + if fi.fieldType == RelManyToMany { + mmi = fi.relThroughModelInfo + } + case fi.reverse: + mmi = fi.reverseFieldInfo.mi + } + + if i < num { + fiN, okN = mmi.fields.GetByAny(exprs[i+1]) + } + + if isRel && (!fi.mi.isThrough || num != i) { + if fi.null || t.skipEnd { + inner = false + } + + if t.skipEnd && okN || !t.skipEnd { + if t.skipEnd && okN && fiN.pk { + goto loopEnd + } + + jt, _ := t.add(names, mmi, fi, inner) + jt.jtl = jtl + jtl = jt + } + + } + + if num != i { + continue + } + + loopEnd: + + if i == 0 || jtl == nil { + index = "T0" + } else { + index = jtl.index + } + + info = fi + + if jtl == nil { + name = fi.name + } else { + name = jtl.name + ExprSep + fi.name + } + + switch { + case fi.rel: + + case fi.reverse: + switch fi.reverseFieldInfo.fieldType { + case RelOneToOne, RelForeignKey: + index = jtl.index + info = fi.reverseFieldInfo.mi.fields.pk + name = info.name + } + } + + break loopFor + + } else { + index = "" + name = "" + info = nil + success = false + return + } + } + + success = index != "" && info != nil + return +} + +// generate condition sql. +func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { + if cond == nil || cond.IsEmpty() { + return + } + + Q := t.base.TableQuote() + + mi := t.mi + + for i, p := range cond.params { + if i > 0 { + if p.isOr { + where += "OR " + } else { + where += "AND " + } + } + if p.isNot { + where += "NOT " + } + if p.isCond { + w, ps := t.getCondSQL(p.cond, true, tz) + if w != "" { + w = fmt.Sprintf("( %s) ", w) + } + where += w + params = append(params, ps...) + } else { + exprs := p.exprs + + num := len(exprs) - 1 + operator := "" + if operators[exprs[num]] { + operator = exprs[num] + exprs = exprs[:num] + } + + index, _, fi, suc := t.parseExprs(mi, exprs) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) + } + + if operator == "" { + operator = "exact" + } + + var operSQL string + var args []interface{} + if p.isRaw { + operSQL = p.sql + } else { + operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) + } + + leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) + t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) + + where += fmt.Sprintf("%s %s ", leftCol, operSQL) + params = append(params, args...) + + } + } + + if !sub && where != "" { + where = "WHERE " + where + } + + return +} + +// generate group sql. +func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { + if len(groups) == 0 { + return + } + + Q := t.base.TableQuote() + + groupSqls := make([]string, 0, len(groups)) + for _, group := range groups { + exprs := strings.Split(group, ExprSep) + + index, _, fi, suc := t.parseExprs(t.mi, exprs) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) + } + + groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)) + } + + groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", ")) + return +} + +// generate order sql. +func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { + if len(orders) == 0 { + return + } + + Q := t.base.TableQuote() + + orderSqls := make([]string, 0, len(orders)) + for _, order := range orders { + asc := "ASC" + if order[0] == '-' { + asc = "DESC" + order = order[1:] + } + exprs := strings.Split(order, ExprSep) + + index, _, fi, suc := t.parseExprs(t.mi, exprs) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) + } + + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) + } + + orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) + return +} + +// generate limit sql. +func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) { + if limit == 0 { + limit = int64(DefaultRowsLimit) + } + if limit < 0 { + // no limit + if offset > 0 { + maxLimit := t.base.MaxLimit() + if maxLimit == 0 { + limits = fmt.Sprintf("OFFSET %d", offset) + } else { + limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset) + } + } + } else if offset <= 0 { + limits = fmt.Sprintf("LIMIT %d", limit) + } else { + limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) + } + return +} + +// crete new tables collection. +func newDbTables(mi *modelInfo, base dbBaser) *dbTables { + tables := &dbTables{} + tables.tablesM = make(map[string]*dbTable) + tables.mi = mi + tables.base = base + return tables +} diff --git a/pkg/orm/db_tidb.go b/pkg/orm/db_tidb.go new file mode 100644 index 00000000..6020a488 --- /dev/null +++ b/pkg/orm/db_tidb.go @@ -0,0 +1,63 @@ +// Copyright 2015 TiDB Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" +) + +// mysql dbBaser implementation. +type dbBaseTidb struct { + dbBase +} + +var _ dbBaser = new(dbBaseTidb) + +// get mysql operator. +func (d *dbBaseTidb) OperatorSQL(operator string) string { + return mysqlOperators[operator] +} + +// get mysql table field types. +func (d *dbBaseTidb) DbTypes() map[string]string { + return mysqlTypes +} + +// show table sql for mysql. +func (d *dbBaseTidb) ShowTablesQuery() string { + return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" +} + +// show columns sql of table for mysql. +func (d *dbBaseTidb) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ + "WHERE table_schema = DATABASE() AND table_name = '%s'", table) +} + +// execute sql to check index exist. +func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { + row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ + "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + +// create new mysql dbBaser. +func newdbBaseTidb() dbBaser { + b := new(dbBaseTidb) + b.ins = b + return b +} diff --git a/pkg/orm/db_utils.go b/pkg/orm/db_utils.go new file mode 100644 index 00000000..7ae10ca5 --- /dev/null +++ b/pkg/orm/db_utils.go @@ -0,0 +1,177 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "reflect" + "time" +) + +// get table alias. +func getDbAlias(name string) *alias { + if al, ok := dataBaseCache.get(name); ok { + return al + } + panic(fmt.Errorf("unknown DataBase alias name %s", name)) +} + +// get pk column info. +func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { + fi := mi.fields.pk + + v := ind.FieldByIndex(fi.fieldIndex) + if fi.fieldType&IsPositiveIntegerField > 0 { + vu := v.Uint() + exist = vu > 0 + value = vu + } else if fi.fieldType&IsIntegerField > 0 { + vu := v.Int() + exist = true + value = vu + } else if fi.fieldType&IsRelField > 0 { + _, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v)) + } else { + vu := v.String() + exist = vu != "" + value = vu + } + + column = fi.column + return +} + +// get fields description as flatted string. +func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { + +outFor: + for _, arg := range args { + val := reflect.ValueOf(arg) + + if arg == nil { + params = append(params, arg) + continue + } + + kind := val.Kind() + if kind == reflect.Ptr { + val = val.Elem() + kind = val.Kind() + arg = val.Interface() + } + + switch kind { + case reflect.String: + v := val.String() + if fi != nil { + if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { + var t time.Time + var err error + if len(v) >= 19 { + s := v[:19] + t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc) + } else if len(v) >= 10 { + s := v + if len(v) > 10 { + s = v[:10] + } + t, err = time.ParseInLocation(formatDate, s, tz) + } else { + s := v + if len(s) > 8 { + s = v[:8] + } + t, err = time.ParseInLocation(formatTime, s, tz) + } + if err == nil { + if fi.fieldType == TypeDateField { + v = t.In(tz).Format(formatDate) + } else if fi.fieldType == TypeDateTimeField { + v = t.In(tz).Format(formatDateTime) + } else { + v = t.In(tz).Format(formatTime) + } + } + } + } + arg = v + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + arg = val.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + arg = val.Uint() + case reflect.Float32: + arg, _ = StrTo(ToStr(arg)).Float64() + case reflect.Float64: + arg = val.Float() + case reflect.Bool: + arg = val.Bool() + case reflect.Slice, reflect.Array: + if _, ok := arg.([]byte); ok { + continue outFor + } + + var args []interface{} + for i := 0; i < val.Len(); i++ { + v := val.Index(i) + + var vu interface{} + if v.CanInterface() { + vu = v.Interface() + } + + if vu == nil { + continue + } + + args = append(args, vu) + } + + if len(args) > 0 { + p := getFlatParams(fi, args, tz) + params = append(params, p...) + } + continue outFor + case reflect.Struct: + if v, ok := arg.(time.Time); ok { + if fi != nil && fi.fieldType == TypeDateField { + arg = v.In(tz).Format(formatDate) + } else if fi != nil && fi.fieldType == TypeDateTimeField { + arg = v.In(tz).Format(formatDateTime) + } else if fi != nil && fi.fieldType == TypeTimeField { + arg = v.In(tz).Format(formatTime) + } else { + arg = v.In(tz).Format(formatDateTime) + } + } else { + typ := val.Type() + name := getFullName(typ) + var value interface{} + if mmi, ok := modelCache.getByFullName(name); ok { + if _, vu, exist := getExistPk(mmi, val); exist { + value = vu + } + } + arg = value + + if arg == nil { + panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name)) + } + } + } + + params = append(params, arg) + } + return +} diff --git a/pkg/orm/models.go b/pkg/orm/models.go new file mode 100644 index 00000000..4776bcba --- /dev/null +++ b/pkg/orm/models.go @@ -0,0 +1,99 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "sync" +) + +const ( + odCascade = "cascade" + odSetNULL = "set_null" + odSetDefault = "set_default" + odDoNothing = "do_nothing" + defaultStructTagName = "orm" + defaultStructTagDelim = ";" +) + +var ( + modelCache = &_modelCache{ + cache: make(map[string]*modelInfo), + cacheByFullName: make(map[string]*modelInfo), + } +) + +// model info collection +type _modelCache struct { + sync.RWMutex // only used outsite for bootStrap + orders []string + cache map[string]*modelInfo + cacheByFullName map[string]*modelInfo + done bool +} + +// get all model info +func (mc *_modelCache) all() map[string]*modelInfo { + m := make(map[string]*modelInfo, len(mc.cache)) + for k, v := range mc.cache { + m[k] = v + } + return m +} + +// get ordered model info +func (mc *_modelCache) allOrdered() []*modelInfo { + m := make([]*modelInfo, 0, len(mc.orders)) + for _, table := range mc.orders { + m = append(m, mc.cache[table]) + } + return m +} + +// get model info by table name +func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { + mi, ok = mc.cache[table] + return +} + +// get model info by full name +func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) { + mi, ok = mc.cacheByFullName[name] + return +} + +// set model info to collection +func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { + mii := mc.cache[table] + mc.cache[table] = mi + mc.cacheByFullName[mi.fullName] = mi + if mii == nil { + mc.orders = append(mc.orders, table) + } + return mii +} + +// clean all model info. +func (mc *_modelCache) clean() { + mc.orders = make([]string, 0) + mc.cache = make(map[string]*modelInfo) + mc.cacheByFullName = make(map[string]*modelInfo) + mc.done = false +} + +// ResetModelCache Clean model cache. Then you can re-RegisterModel. +// Common use this api for test case. +func ResetModelCache() { + modelCache.clean() +} diff --git a/pkg/orm/models_boot.go b/pkg/orm/models_boot.go new file mode 100644 index 00000000..8c56b3c4 --- /dev/null +++ b/pkg/orm/models_boot.go @@ -0,0 +1,347 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "os" + "reflect" + "runtime/debug" + "strings" +) + +// register models. +// PrefixOrSuffix means table name prefix or suffix. +// isPrefix whether the prefix is prefix or suffix +func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) { + val := reflect.ValueOf(model) + typ := reflect.Indirect(val).Type() + + if val.Kind() != reflect.Ptr { + panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) + } + // For this case: + // u := &User{} + // registerModel(&u) + if typ.Kind() == reflect.Ptr { + panic(fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)) + } + + table := getTableName(val) + + if PrefixOrSuffix != "" { + if isPrefix { + table = PrefixOrSuffix + table + } else { + table = table + PrefixOrSuffix + } + } + // models's fullname is pkgpath + struct name + name := getFullName(typ) + if _, ok := modelCache.getByFullName(name); ok { + fmt.Printf(" model `%s` repeat register, must be unique\n", name) + os.Exit(2) + } + + if _, ok := modelCache.get(table); ok { + fmt.Printf(" table name `%s` repeat register, must be unique\n", table) + os.Exit(2) + } + + mi := newModelInfo(val) + if mi.fields.pk == nil { + outFor: + for _, fi := range mi.fields.fieldsDB { + if strings.ToLower(fi.name) == "id" { + switch fi.addrValue.Elem().Kind() { + case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + fi.auto = true + fi.pk = true + mi.fields.pk = fi + break outFor + } + } + } + + if mi.fields.pk == nil { + fmt.Printf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) + os.Exit(2) + } + + } + + mi.table = table + mi.pkg = typ.PkgPath() + mi.model = model + mi.manual = true + + modelCache.set(table, mi) +} + +// bootstrap models +func bootStrap() { + if modelCache.done { + return + } + var ( + err error + models map[string]*modelInfo + ) + if dataBaseCache.getDefault() == nil { + err = fmt.Errorf("must have one register DataBase alias named `default`") + goto end + } + + // set rel and reverse model + // RelManyToMany set the relTable + models = modelCache.all() + for _, mi := range models { + for _, fi := range mi.fields.columns { + if fi.rel || fi.reverse { + elm := fi.addrValue.Type().Elem() + if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { + elm = elm.Elem() + } + // check the rel or reverse model already register + name := getFullName(elm) + mii, ok := modelCache.getByFullName(name) + if !ok || mii.pkg != elm.PkgPath() { + err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) + goto end + } + fi.relModelInfo = mii + + switch fi.fieldType { + case RelManyToMany: + if fi.relThrough != "" { + if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { + pn := fi.relThrough[:i] + rmi, ok := modelCache.getByFullName(fi.relThrough) + if !ok || pn != rmi.pkg { + err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) + goto end + } + fi.relThroughModelInfo = rmi + fi.relTable = rmi.table + } else { + err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) + goto end + } + } else { + i := newM2MModelInfo(mi, mii) + if fi.relTable != "" { + i.table = fi.relTable + } + if v := modelCache.set(i.table, i); v != nil { + err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) + goto end + } + fi.relTable = i.table + fi.relThroughModelInfo = i + } + + fi.relThroughModelInfo.isThrough = true + } + } + } + } + + // check the rel filed while the relModelInfo also has filed point to current model + // if not exist, add a new field to the relModelInfo + models = modelCache.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsRel { + switch fi.fieldType { + case RelForeignKey, RelOneToOne, RelManyToMany: + inModel := false + for _, ffi := range fi.relModelInfo.fields.fieldsReverse { + if ffi.relModelInfo == mi { + inModel = true + break + } + } + if !inModel { + rmi := fi.relModelInfo + ffi := new(fieldInfo) + ffi.name = mi.name + ffi.column = ffi.name + ffi.fullName = rmi.fullName + "." + ffi.name + ffi.reverse = true + ffi.relModelInfo = mi + ffi.mi = rmi + if fi.fieldType == RelOneToOne { + ffi.fieldType = RelReverseOne + } else { + ffi.fieldType = RelReverseMany + } + if !rmi.fields.Add(ffi) { + added := false + for cnt := 0; cnt < 5; cnt++ { + ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) + ffi.column = ffi.name + ffi.fullName = rmi.fullName + "." + ffi.name + if added = rmi.fields.Add(ffi); added { + break + } + } + if !added { + panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) + } + } + } + } + } + } + + models = modelCache.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsRel { + switch fi.fieldType { + case RelManyToMany: + for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { + switch ffi.fieldType { + case RelOneToOne, RelForeignKey: + if ffi.relModelInfo == fi.relModelInfo { + fi.reverseFieldInfoTwo = ffi + } + if ffi.relModelInfo == mi { + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + } + } + } + if fi.reverseFieldInfoTwo == nil { + err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", + fi.relThroughModelInfo.fullName) + goto end + } + } + } + } + + models = modelCache.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsReverse { + switch fi.fieldType { + case RelReverseOne: + found := false + mForA: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + + ffi.reverseField = fi.name + ffi.reverseFieldInfo = fi + break mForA + } + } + if !found { + err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + goto end + } + case RelReverseMany: + found := false + mForB: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + + ffi.reverseField = fi.name + ffi.reverseFieldInfo = fi + + break mForB + } + } + if !found { + mForC: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { + conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || + fi.relTable != "" && fi.relTable == ffi.relTable || + fi.relThrough == "" && fi.relTable == "" + if ffi.relModelInfo == mi && conditions { + found = true + + fi.reverseField = ffi.reverseFieldInfoTwo.name + fi.reverseFieldInfo = ffi.reverseFieldInfoTwo + fi.relThroughModelInfo = ffi.relThroughModelInfo + fi.reverseFieldInfoTwo = ffi.reverseFieldInfo + fi.reverseFieldInfoM2M = ffi + ffi.reverseFieldInfoM2M = fi + + break mForC + } + } + } + if !found { + err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + goto end + } + } + } + } + +end: + if err != nil { + fmt.Println(err) + debug.PrintStack() + os.Exit(2) + } +} + +// RegisterModel register models +func RegisterModel(models ...interface{}) { + if modelCache.done { + panic(fmt.Errorf("RegisterModel must be run before BootStrap")) + } + RegisterModelWithPrefix("", models...) +} + +// RegisterModelWithPrefix register models with a prefix +func RegisterModelWithPrefix(prefix string, models ...interface{}) { + if modelCache.done { + panic(fmt.Errorf("RegisterModelWithPrefix must be run before BootStrap")) + } + + for _, model := range models { + registerModel(prefix, model, true) + } +} + +// RegisterModelWithSuffix register models with a suffix +func RegisterModelWithSuffix(suffix string, models ...interface{}) { + if modelCache.done { + panic(fmt.Errorf("RegisterModelWithSuffix must be run before BootStrap")) + } + + for _, model := range models { + registerModel(suffix, model, false) + } +} + +// BootStrap bootstrap models. +// make all model parsed and can not add more models +func BootStrap() { + modelCache.Lock() + defer modelCache.Unlock() + if modelCache.done { + return + } + bootStrap() + modelCache.done = true +} diff --git a/pkg/orm/models_fields.go b/pkg/orm/models_fields.go new file mode 100644 index 00000000..b4fad94f --- /dev/null +++ b/pkg/orm/models_fields.go @@ -0,0 +1,783 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strconv" + "time" +) + +// Define the Type enum +const ( + TypeBooleanField = 1 << iota + TypeVarCharField + TypeCharField + TypeTextField + TypeTimeField + TypeDateField + TypeDateTimeField + TypeBitField + TypeSmallIntegerField + TypeIntegerField + TypeBigIntegerField + TypePositiveBitField + TypePositiveSmallIntegerField + TypePositiveIntegerField + TypePositiveBigIntegerField + TypeFloatField + TypeDecimalField + TypeJSONField + TypeJsonbField + RelForeignKey + RelOneToOne + RelManyToMany + RelReverseOne + RelReverseMany +) + +// Define some logic enum +const ( + IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7 + IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11 + IsRelField = ^-RelReverseMany >> 18 << 19 + IsFieldType = ^-RelReverseMany<<1 + 1 +) + +// BooleanField A true/false field. +type BooleanField bool + +// Value return the BooleanField +func (e BooleanField) Value() bool { + return bool(e) +} + +// Set will set the BooleanField +func (e *BooleanField) Set(d bool) { + *e = BooleanField(d) +} + +// String format the Bool to string +func (e *BooleanField) String() string { + return strconv.FormatBool(e.Value()) +} + +// FieldType return BooleanField the type +func (e *BooleanField) FieldType() int { + return TypeBooleanField +} + +// SetRaw set the interface to bool +func (e *BooleanField) SetRaw(value interface{}) error { + switch d := value.(type) { + case bool: + e.Set(d) + case string: + v, err := StrTo(d).Bool() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the current value +func (e *BooleanField) RawValue() interface{} { + return e.Value() +} + +// verify the BooleanField implement the Fielder interface +var _ Fielder = new(BooleanField) + +// CharField A string field +// required values tag: size +// The size is enforced at the database level and in models’s validation. +// eg: `orm:"size(120)"` +type CharField string + +// Value return the CharField's Value +func (e CharField) Value() string { + return string(e) +} + +// Set CharField value +func (e *CharField) Set(d string) { + *e = CharField(d) +} + +// String return the CharField +func (e *CharField) String() string { + return e.Value() +} + +// FieldType return the enum type +func (e *CharField) FieldType() int { + return TypeVarCharField +} + +// SetRaw set the interface to string +func (e *CharField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + e.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the CharField value +func (e *CharField) RawValue() interface{} { + return e.Value() +} + +// verify CharField implement Fielder +var _ Fielder = new(CharField) + +// TimeField A time, represented in go by a time.Time instance. +// only time values like 10:00:00 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type TimeField time.Time + +// Value return the time.Time +func (e TimeField) Value() time.Time { + return time.Time(e) +} + +// Set set the TimeField's value +func (e *TimeField) Set(d time.Time) { + *e = TimeField(d) +} + +// String convert time to string +func (e *TimeField) String() string { + return e.Value().String() +} + +// FieldType return enum type Date +func (e *TimeField) FieldType() int { + return TypeDateField +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *TimeField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := timeParse(d, formatTime) + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return time value +func (e *TimeField) RawValue() interface{} { + return e.Value() +} + +var _ Fielder = new(TimeField) + +// DateField A date, represented in go by a time.Time instance. +// only date values like 2006-01-02 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type DateField time.Time + +// Value return the time.Time +func (e DateField) Value() time.Time { + return time.Time(e) +} + +// Set set the DateField's value +func (e *DateField) Set(d time.Time) { + *e = DateField(d) +} + +// String convert datetime to string +func (e *DateField) String() string { + return e.Value().String() +} + +// FieldType return enum type Date +func (e *DateField) FieldType() int { + return TypeDateField +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *DateField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := timeParse(d, formatDate) + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return Date value +func (e *DateField) RawValue() interface{} { + return e.Value() +} + +// verify DateField implement fielder interface +var _ Fielder = new(DateField) + +// DateTimeField A date, represented in go by a time.Time instance. +// datetime values like 2006-01-02 15:04:05 +// Takes the same extra arguments as DateField. +type DateTimeField time.Time + +// Value return the datetime value +func (e DateTimeField) Value() time.Time { + return time.Time(e) +} + +// Set set the time.Time to datetime +func (e *DateTimeField) Set(d time.Time) { + *e = DateTimeField(d) +} + +// String return the time's String +func (e *DateTimeField) String() string { + return e.Value().String() +} + +// FieldType return the enum TypeDateTimeField +func (e *DateTimeField) FieldType() int { + return TypeDateTimeField +} + +// SetRaw convert the string or time.Time to DateTimeField +func (e *DateTimeField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := timeParse(d, formatDateTime) + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the datetime value +func (e *DateTimeField) RawValue() interface{} { + return e.Value() +} + +// verify datetime implement fielder +var _ Fielder = new(DateTimeField) + +// FloatField A floating-point number represented in go by a float32 value. +type FloatField float64 + +// Value return the FloatField value +func (e FloatField) Value() float64 { + return float64(e) +} + +// Set the Float64 +func (e *FloatField) Set(d float64) { + *e = FloatField(d) +} + +// String return the string +func (e *FloatField) String() string { + return ToStr(e.Value(), -1, 32) +} + +// FieldType return the enum type +func (e *FloatField) FieldType() int { + return TypeFloatField +} + +// SetRaw converter interface Float64 float32 or string to FloatField +func (e *FloatField) SetRaw(value interface{}) error { + switch d := value.(type) { + case float32: + e.Set(float64(d)) + case float64: + e.Set(d) + case string: + v, err := StrTo(d).Float64() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the FloatField value +func (e *FloatField) RawValue() interface{} { + return e.Value() +} + +// verify FloatField implement Fielder +var _ Fielder = new(FloatField) + +// SmallIntegerField -32768 to 32767 +type SmallIntegerField int16 + +// Value return int16 value +func (e SmallIntegerField) Value() int16 { + return int16(e) +} + +// Set the SmallIntegerField value +func (e *SmallIntegerField) Set(d int16) { + *e = SmallIntegerField(d) +} + +// String convert smallint to string +func (e *SmallIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type SmallIntegerField +func (e *SmallIntegerField) FieldType() int { + return TypeSmallIntegerField +} + +// SetRaw convert interface int16/string to int16 +func (e *SmallIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int16: + e.Set(d) + case string: + v, err := StrTo(d).Int16() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return smallint value +func (e *SmallIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify SmallIntegerField implement Fielder +var _ Fielder = new(SmallIntegerField) + +// IntegerField -2147483648 to 2147483647 +type IntegerField int32 + +// Value return the int32 +func (e IntegerField) Value() int32 { + return int32(e) +} + +// Set IntegerField value +func (e *IntegerField) Set(d int32) { + *e = IntegerField(d) +} + +// String convert Int32 to string +func (e *IntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return the enum type +func (e *IntegerField) FieldType() int { + return TypeIntegerField +} + +// SetRaw convert interface int32/string to int32 +func (e *IntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int32: + e.Set(d) + case string: + v, err := StrTo(d).Int32() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return IntegerField value +func (e *IntegerField) RawValue() interface{} { + return e.Value() +} + +// verify IntegerField implement Fielder +var _ Fielder = new(IntegerField) + +// BigIntegerField -9223372036854775808 to 9223372036854775807. +type BigIntegerField int64 + +// Value return int64 +func (e BigIntegerField) Value() int64 { + return int64(e) +} + +// Set the BigIntegerField value +func (e *BigIntegerField) Set(d int64) { + *e = BigIntegerField(d) +} + +// String convert BigIntegerField to string +func (e *BigIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type +func (e *BigIntegerField) FieldType() int { + return TypeBigIntegerField +} + +// SetRaw convert interface int64/string to int64 +func (e *BigIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int64: + e.Set(d) + case string: + v, err := StrTo(d).Int64() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return BigIntegerField value +func (e *BigIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify BigIntegerField implement Fielder +var _ Fielder = new(BigIntegerField) + +// PositiveSmallIntegerField 0 to 65535 +type PositiveSmallIntegerField uint16 + +// Value return uint16 +func (e PositiveSmallIntegerField) Value() uint16 { + return uint16(e) +} + +// Set PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) Set(d uint16) { + *e = PositiveSmallIntegerField(d) +} + +// String convert uint16 to string +func (e *PositiveSmallIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type +func (e *PositiveSmallIntegerField) FieldType() int { + return TypePositiveSmallIntegerField +} + +// SetRaw convert Interface uint16/string to uint16 +func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint16: + e.Set(d) + case string: + v, err := StrTo(d).Uint16() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue returns PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify PositiveSmallIntegerField implement Fielder +var _ Fielder = new(PositiveSmallIntegerField) + +// PositiveIntegerField 0 to 4294967295 +type PositiveIntegerField uint32 + +// Value return PositiveIntegerField value. Uint32 +func (e PositiveIntegerField) Value() uint32 { + return uint32(e) +} + +// Set the PositiveIntegerField value +func (e *PositiveIntegerField) Set(d uint32) { + *e = PositiveIntegerField(d) +} + +// String convert PositiveIntegerField to string +func (e *PositiveIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type +func (e *PositiveIntegerField) FieldType() int { + return TypePositiveIntegerField +} + +// SetRaw convert interface uint32/string to Uint32 +func (e *PositiveIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint32: + e.Set(d) + case string: + v, err := StrTo(d).Uint32() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the PositiveIntegerField Value +func (e *PositiveIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify PositiveIntegerField implement Fielder +var _ Fielder = new(PositiveIntegerField) + +// PositiveBigIntegerField 0 to 18446744073709551615 +type PositiveBigIntegerField uint64 + +// Value return uint64 +func (e PositiveBigIntegerField) Value() uint64 { + return uint64(e) +} + +// Set PositiveBigIntegerField value +func (e *PositiveBigIntegerField) Set(d uint64) { + *e = PositiveBigIntegerField(d) +} + +// String convert PositiveBigIntegerField to string +func (e *PositiveBigIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type +func (e *PositiveBigIntegerField) FieldType() int { + return TypePositiveIntegerField +} + +// SetRaw convert interface uint64/string to Uint64 +func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint64: + e.Set(d) + case string: + v, err := StrTo(d).Uint64() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return PositiveBigIntegerField value +func (e *PositiveBigIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify PositiveBigIntegerField implement Fielder +var _ Fielder = new(PositiveBigIntegerField) + +// TextField A large text field. +type TextField string + +// Value return TextField value +func (e TextField) Value() string { + return string(e) +} + +// Set the TextField value +func (e *TextField) Set(d string) { + *e = TextField(d) +} + +// String convert TextField to string +func (e *TextField) String() string { + return e.Value() +} + +// FieldType return enum type +func (e *TextField) FieldType() int { + return TypeTextField +} + +// SetRaw convert interface string to string +func (e *TextField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + e.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return TextField value +func (e *TextField) RawValue() interface{} { + return e.Value() +} + +// verify TextField implement Fielder +var _ Fielder = new(TextField) + +// JSONField postgres json field. +type JSONField string + +// Value return JSONField value +func (j JSONField) Value() string { + return string(j) +} + +// Set the JSONField value +func (j *JSONField) Set(d string) { + *j = JSONField(d) +} + +// String convert JSONField to string +func (j *JSONField) String() string { + return j.Value() +} + +// FieldType return enum type +func (j *JSONField) FieldType() int { + return TypeJSONField +} + +// SetRaw convert interface string to string +func (j *JSONField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + j.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return JSONField value +func (j *JSONField) RawValue() interface{} { + return j.Value() +} + +// verify JSONField implement Fielder +var _ Fielder = new(JSONField) + +// JsonbField postgres json field. +type JsonbField string + +// Value return JsonbField value +func (j JsonbField) Value() string { + return string(j) +} + +// Set the JsonbField value +func (j *JsonbField) Set(d string) { + *j = JsonbField(d) +} + +// String convert JsonbField to string +func (j *JsonbField) String() string { + return j.Value() +} + +// FieldType return enum type +func (j *JsonbField) FieldType() int { + return TypeJsonbField +} + +// SetRaw convert interface string to string +func (j *JsonbField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + j.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return JsonbField value +func (j *JsonbField) RawValue() interface{} { + return j.Value() +} + +// verify JsonbField implement Fielder +var _ Fielder = new(JsonbField) diff --git a/pkg/orm/models_info_f.go b/pkg/orm/models_info_f.go new file mode 100644 index 00000000..7044b0bd --- /dev/null +++ b/pkg/orm/models_info_f.go @@ -0,0 +1,473 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +var errSkipField = errors.New("skip field") + +// field info collection +type fields struct { + pk *fieldInfo + columns map[string]*fieldInfo + fields map[string]*fieldInfo + fieldsLow map[string]*fieldInfo + fieldsByType map[int][]*fieldInfo + fieldsRel []*fieldInfo + fieldsReverse []*fieldInfo + fieldsDB []*fieldInfo + rels []*fieldInfo + orders []string + dbcols []string +} + +// add field info +func (f *fields) Add(fi *fieldInfo) (added bool) { + if f.fields[fi.name] == nil && f.columns[fi.column] == nil { + f.columns[fi.column] = fi + f.fields[fi.name] = fi + f.fieldsLow[strings.ToLower(fi.name)] = fi + } else { + return + } + if _, ok := f.fieldsByType[fi.fieldType]; !ok { + f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0) + } + f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi) + f.orders = append(f.orders, fi.column) + if fi.dbcol { + f.dbcols = append(f.dbcols, fi.column) + f.fieldsDB = append(f.fieldsDB, fi) + } + if fi.rel { + f.fieldsRel = append(f.fieldsRel, fi) + } + if fi.reverse { + f.fieldsReverse = append(f.fieldsReverse, fi) + } + return true +} + +// get field info by name +func (f *fields) GetByName(name string) *fieldInfo { + return f.fields[name] +} + +// get field info by column name +func (f *fields) GetByColumn(column string) *fieldInfo { + return f.columns[column] +} + +// get field info by string, name is prior +func (f *fields) GetByAny(name string) (*fieldInfo, bool) { + if fi, ok := f.fields[name]; ok { + return fi, ok + } + if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok { + return fi, ok + } + if fi, ok := f.columns[name]; ok { + return fi, ok + } + return nil, false +} + +// create new field info collection +func newFields() *fields { + f := new(fields) + f.fields = make(map[string]*fieldInfo) + f.fieldsLow = make(map[string]*fieldInfo) + f.columns = make(map[string]*fieldInfo) + f.fieldsByType = make(map[int][]*fieldInfo) + return f +} + +// single field info +type fieldInfo struct { + mi *modelInfo + fieldIndex []int + fieldType int + dbcol bool // table column fk and onetoone + inModel bool + name string + fullName string + column string + addrValue reflect.Value + sf reflect.StructField + auto bool + pk bool + null bool + index bool + unique bool + colDefault bool // whether has default tag + initial StrTo // store the default value + size int + toText bool + autoNow bool + autoNowAdd bool + rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true + reverse bool + reverseField string + reverseFieldInfo *fieldInfo + reverseFieldInfoTwo *fieldInfo + reverseFieldInfoM2M *fieldInfo + relTable string + relThrough string + relThroughModelInfo *modelInfo + relModelInfo *modelInfo + digits int + decimals int + isFielder bool // implement Fielder interface + onDelete string + description string +} + +// new field info +func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *fieldInfo, err error) { + var ( + tag string + tagValue string + initial StrTo // store the default value + fieldType int + attrs map[string]bool + tags map[string]string + addrField reflect.Value + ) + + fi = new(fieldInfo) + + // if field which CanAddr is the follow type + // A value is addressable if it is an element of a slice, + // an element of an addressable array, a field of an + // addressable struct, or the result of dereferencing a pointer. + addrField = field + if field.CanAddr() && field.Kind() != reflect.Ptr { + addrField = field.Addr() + if _, ok := addrField.Interface().(Fielder); !ok { + if field.Kind() == reflect.Slice { + addrField = field + } + } + } + + attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName)) + + if _, ok := attrs["-"]; ok { + return nil, errSkipField + } + + digits := tags["digits"] + decimals := tags["decimals"] + size := tags["size"] + onDelete := tags["on_delete"] + + initial.Clear() + if v, ok := tags["default"]; ok { + initial.Set(v) + } + +checkType: + switch f := addrField.Interface().(type) { + case Fielder: + fi.isFielder = true + if field.Kind() == reflect.Ptr { + err = fmt.Errorf("the model Fielder can not be use ptr") + goto end + } + fieldType = f.FieldType() + if fieldType&IsRelField > 0 { + err = fmt.Errorf("unsupport type custom field, please refer to https://github.com/astaxie/beego/blob/master/orm/models_fields.go#L24-L42") + goto end + } + default: + tag = "rel" + tagValue = tags[tag] + if tagValue != "" { + switch tagValue { + case "fk": + fieldType = RelForeignKey + break checkType + case "one": + fieldType = RelOneToOne + break checkType + case "m2m": + fieldType = RelManyToMany + if tv := tags["rel_table"]; tv != "" { + fi.relTable = tv + } else if tv := tags["rel_through"]; tv != "" { + fi.relThrough = tv + } + break checkType + default: + err = fmt.Errorf("rel only allow these value: fk, one, m2m") + goto wrongTag + } + } + tag = "reverse" + tagValue = tags[tag] + if tagValue != "" { + switch tagValue { + case "one": + fieldType = RelReverseOne + break checkType + case "many": + fieldType = RelReverseMany + if tv := tags["rel_table"]; tv != "" { + fi.relTable = tv + } else if tv := tags["rel_through"]; tv != "" { + fi.relThrough = tv + } + break checkType + default: + err = fmt.Errorf("reverse only allow these value: one, many") + goto wrongTag + } + } + + fieldType, err = getFieldType(addrField) + if err != nil { + goto end + } + if fieldType == TypeVarCharField { + switch tags["type"] { + case "char": + fieldType = TypeCharField + case "text": + fieldType = TypeTextField + case "json": + fieldType = TypeJSONField + case "jsonb": + fieldType = TypeJsonbField + } + } + if fieldType == TypeFloatField && (digits != "" || decimals != "") { + fieldType = TypeDecimalField + } + if fieldType == TypeDateTimeField && tags["type"] == "date" { + fieldType = TypeDateField + } + if fieldType == TypeTimeField && tags["type"] == "time" { + fieldType = TypeTimeField + } + } + + // check the rel and reverse type + // rel should Ptr + // reverse should slice []*struct + switch fieldType { + case RelForeignKey, RelOneToOne, RelReverseOne: + if field.Kind() != reflect.Ptr { + err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name()) + goto end + } + case RelManyToMany, RelReverseMany: + if field.Kind() != reflect.Slice { + err = fmt.Errorf("rel/reverse:many field must be slice") + goto end + } else { + if field.Type().Elem().Kind() != reflect.Ptr { + err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name()) + goto end + } + } + } + + if fieldType&IsFieldType == 0 { + err = fmt.Errorf("wrong field type") + goto end + } + + fi.fieldType = fieldType + fi.name = sf.Name + fi.column = getColumnName(fieldType, addrField, sf, tags["column"]) + fi.addrValue = addrField + fi.sf = sf + fi.fullName = mi.fullName + mName + "." + sf.Name + + fi.description = tags["description"] + fi.null = attrs["null"] + fi.index = attrs["index"] + fi.auto = attrs["auto"] + fi.pk = attrs["pk"] + fi.unique = attrs["unique"] + + // Mark object property if there is attribute "default" in the orm configuration + if _, ok := tags["default"]; ok { + fi.colDefault = true + } + + switch fieldType { + case RelManyToMany, RelReverseMany, RelReverseOne: + fi.null = false + fi.index = false + fi.auto = false + fi.pk = false + fi.unique = false + default: + fi.dbcol = true + } + + switch fieldType { + case RelForeignKey, RelOneToOne, RelManyToMany: + fi.rel = true + if fieldType == RelOneToOne { + fi.unique = true + } + case RelReverseMany, RelReverseOne: + fi.reverse = true + } + + if fi.rel && fi.dbcol { + switch onDelete { + case odCascade, odDoNothing: + case odSetDefault: + if !initial.Exist() { + err = errors.New("on_delete: set_default need set field a default value") + goto end + } + case odSetNULL: + if !fi.null { + err = errors.New("on_delete: set_null need set field null") + goto end + } + default: + if onDelete == "" { + onDelete = odCascade + } else { + err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete) + goto end + } + } + + fi.onDelete = onDelete + } + + switch fieldType { + case TypeBooleanField: + case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField: + if size != "" { + v, e := StrTo(size).Int32() + if e != nil { + err = fmt.Errorf("wrong size value `%s`", size) + } else { + fi.size = int(v) + } + } else { + fi.size = 255 + fi.toText = true + } + case TypeTextField: + fi.index = false + fi.unique = false + case TypeTimeField, TypeDateField, TypeDateTimeField: + if attrs["auto_now"] { + fi.autoNow = true + } else if attrs["auto_now_add"] { + fi.autoNowAdd = true + } + case TypeFloatField: + case TypeDecimalField: + d1 := digits + d2 := decimals + v1, er1 := StrTo(d1).Int8() + v2, er2 := StrTo(d2).Int8() + if er1 != nil || er2 != nil { + err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1) + goto end + } + fi.digits = int(v1) + fi.decimals = int(v2) + default: + switch { + case fieldType&IsIntegerField > 0: + case fieldType&IsRelField > 0: + } + } + + if fieldType&IsIntegerField == 0 { + if fi.auto { + err = fmt.Errorf("non-integer type cannot set auto") + goto end + } + } + + if fi.auto || fi.pk { + if fi.auto { + switch addrField.Elem().Kind() { + case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + default: + err = fmt.Errorf("auto primary key only support int, int32, int64, uint, uint32, uint64 but found `%s`", addrField.Elem().Kind()) + goto end + } + fi.pk = true + } + fi.null = false + fi.index = false + fi.unique = false + } + + if fi.unique { + fi.index = false + } + + // can not set default for these type + if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField { + initial.Clear() + } + + if initial.Exist() { + v := initial + switch fieldType { + case TypeBooleanField: + _, err = v.Bool() + case TypeFloatField, TypeDecimalField: + _, err = v.Float64() + case TypeBitField: + _, err = v.Int8() + case TypeSmallIntegerField: + _, err = v.Int16() + case TypeIntegerField: + _, err = v.Int32() + case TypeBigIntegerField: + _, err = v.Int64() + case TypePositiveBitField: + _, err = v.Uint8() + case TypePositiveSmallIntegerField: + _, err = v.Uint16() + case TypePositiveIntegerField: + _, err = v.Uint32() + case TypePositiveBigIntegerField: + _, err = v.Uint64() + } + if err != nil { + tag, tagValue = "default", tags["default"] + goto wrongTag + } + } + + fi.initial = initial +end: + if err != nil { + return nil, err + } + return +wrongTag: + return nil, fmt.Errorf("wrong tag format: `%s:\"%s\"`, %s", tag, tagValue, err) +} diff --git a/pkg/orm/models_info_m.go b/pkg/orm/models_info_m.go new file mode 100644 index 00000000..a4d733b6 --- /dev/null +++ b/pkg/orm/models_info_m.go @@ -0,0 +1,148 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "os" + "reflect" +) + +// single model info +type modelInfo struct { + pkg string + name string + fullName string + table string + model interface{} + fields *fields + manual bool + addrField reflect.Value //store the original struct value + uniques []string + isThrough bool +} + +// new model info +func newModelInfo(val reflect.Value) (mi *modelInfo) { + mi = &modelInfo{} + mi.fields = newFields() + ind := reflect.Indirect(val) + mi.addrField = val + mi.name = ind.Type().Name() + mi.fullName = getFullName(ind.Type()) + addModelFields(mi, ind, "", []int{}) + return +} + +// index: FieldByIndex returns the nested field corresponding to index +func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) { + var ( + err error + fi *fieldInfo + sf reflect.StructField + ) + + for i := 0; i < ind.NumField(); i++ { + field := ind.Field(i) + sf = ind.Type().Field(i) + // if the field is unexported skip + if sf.PkgPath != "" { + continue + } + // add anonymous struct fields + if sf.Anonymous { + addModelFields(mi, field, mName+"."+sf.Name, append(index, i)) + continue + } + + fi, err = newFieldInfo(mi, field, sf, mName) + if err == errSkipField { + err = nil + continue + } else if err != nil { + break + } + //record current field index + fi.fieldIndex = append(fi.fieldIndex, index...) + fi.fieldIndex = append(fi.fieldIndex, i) + fi.mi = mi + fi.inModel = true + if !mi.fields.Add(fi) { + err = fmt.Errorf("duplicate column name: %s", fi.column) + break + } + if fi.pk { + if mi.fields.pk != nil { + err = fmt.Errorf("one model must have one pk field only") + break + } else { + mi.fields.pk = fi + } + } + } + + if err != nil { + fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err)) + os.Exit(2) + } +} + +// combine related model info to new model info. +// prepare for relation models query. +func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) { + mi = new(modelInfo) + mi.fields = newFields() + mi.table = m1.table + "_" + m2.table + "s" + mi.name = camelString(mi.table) + mi.fullName = m1.pkg + "." + mi.name + + fa := new(fieldInfo) // pk + f1 := new(fieldInfo) // m1 table RelForeignKey + f2 := new(fieldInfo) // m2 table RelForeignKey + fa.fieldType = TypeBigIntegerField + fa.auto = true + fa.pk = true + fa.dbcol = true + fa.name = "Id" + fa.column = "id" + fa.fullName = mi.fullName + "." + fa.name + + f1.dbcol = true + f2.dbcol = true + f1.fieldType = RelForeignKey + f2.fieldType = RelForeignKey + f1.name = camelString(m1.table) + f2.name = camelString(m2.table) + f1.fullName = mi.fullName + "." + f1.name + f2.fullName = mi.fullName + "." + f2.name + f1.column = m1.table + "_id" + f2.column = m2.table + "_id" + f1.rel = true + f2.rel = true + f1.relTable = m1.table + f2.relTable = m2.table + f1.relModelInfo = m1 + f2.relModelInfo = m2 + f1.mi = mi + f2.mi = mi + + mi.fields.Add(fa) + mi.fields.Add(f1) + mi.fields.Add(f2) + mi.fields.pk = fa + + mi.uniques = []string{f1.column, f2.column} + return +} diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go new file mode 100644 index 00000000..e3a635f2 --- /dev/null +++ b/pkg/orm/models_test.go @@ -0,0 +1,497 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "encoding/json" + "fmt" + "os" + "strings" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + // As tidb can't use go get, so disable the tidb testing now + // _ "github.com/pingcap/tidb" +) + +// A slice string field. +type SliceStringField []string + +func (e SliceStringField) Value() []string { + return []string(e) +} + +func (e *SliceStringField) Set(d []string) { + *e = SliceStringField(d) +} + +func (e *SliceStringField) Add(v string) { + *e = append(*e, v) +} + +func (e *SliceStringField) String() string { + return strings.Join(e.Value(), ",") +} + +func (e *SliceStringField) FieldType() int { + return TypeVarCharField +} + +func (e *SliceStringField) SetRaw(value interface{}) error { + switch d := value.(type) { + case []string: + e.Set(d) + case string: + if len(d) > 0 { + parts := strings.Split(d, ",") + v := make([]string, 0, len(parts)) + for _, p := range parts { + v = append(v, strings.TrimSpace(p)) + } + e.Set(v) + } + default: + return fmt.Errorf(" unknown value `%v`", value) + } + return nil +} + +func (e *SliceStringField) RawValue() interface{} { + return e.String() +} + +var _ Fielder = new(SliceStringField) + +// A json field. +type JSONFieldTest struct { + Name string + Data string +} + +func (e *JSONFieldTest) String() string { + data, _ := json.Marshal(e) + return string(data) +} + +func (e *JSONFieldTest) FieldType() int { + return TypeTextField +} + +func (e *JSONFieldTest) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + return json.Unmarshal([]byte(d), e) + default: + return fmt.Errorf(" unknown value `%v`", value) + } +} + +func (e *JSONFieldTest) RawValue() interface{} { + return e.String() +} + +var _ Fielder = new(JSONFieldTest) + +type Data struct { + ID int `orm:"column(id)"` + Boolean bool + Char string `orm:"size(50)"` + Text string `orm:"type(text)"` + JSON string `orm:"type(json);default({\"name\":\"json\"})"` + Jsonb string `orm:"type(jsonb)"` + Time time.Time `orm:"type(time)"` + Date time.Time `orm:"type(date)"` + DateTime time.Time `orm:"column(datetime)"` + Byte byte + Rune rune + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint uint + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Float32 float32 + Float64 float64 + Decimal float64 `orm:"digits(8);decimals(4)"` +} + +type DataNull struct { + ID int `orm:"column(id)"` + Boolean bool `orm:"null"` + Char string `orm:"null;size(50)"` + Text string `orm:"null;type(text)"` + JSON string `orm:"type(json);null"` + Jsonb string `orm:"type(jsonb);null"` + Time time.Time `orm:"null;type(time)"` + Date time.Time `orm:"null;type(date)"` + DateTime time.Time `orm:"null;column(datetime)"` + Byte byte `orm:"null"` + Rune rune `orm:"null"` + Int int `orm:"null"` + Int8 int8 `orm:"null"` + Int16 int16 `orm:"null"` + Int32 int32 `orm:"null"` + Int64 int64 `orm:"null"` + Uint uint `orm:"null"` + Uint8 uint8 `orm:"null"` + Uint16 uint16 `orm:"null"` + Uint32 uint32 `orm:"null"` + Uint64 uint64 `orm:"null"` + Float32 float32 `orm:"null"` + Float64 float64 `orm:"null"` + Decimal float64 `orm:"digits(8);decimals(4);null"` + NullString sql.NullString `orm:"null"` + NullBool sql.NullBool `orm:"null"` + NullFloat64 sql.NullFloat64 `orm:"null"` + NullInt64 sql.NullInt64 `orm:"null"` + BooleanPtr *bool `orm:"null"` + CharPtr *string `orm:"null;size(50)"` + TextPtr *string `orm:"null;type(text)"` + BytePtr *byte `orm:"null"` + RunePtr *rune `orm:"null"` + IntPtr *int `orm:"null"` + Int8Ptr *int8 `orm:"null"` + Int16Ptr *int16 `orm:"null"` + Int32Ptr *int32 `orm:"null"` + Int64Ptr *int64 `orm:"null"` + UintPtr *uint `orm:"null"` + Uint8Ptr *uint8 `orm:"null"` + Uint16Ptr *uint16 `orm:"null"` + Uint32Ptr *uint32 `orm:"null"` + Uint64Ptr *uint64 `orm:"null"` + Float32Ptr *float32 `orm:"null"` + Float64Ptr *float64 `orm:"null"` + DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` + TimePtr *time.Time `orm:"null;type(time)"` + DatePtr *time.Time `orm:"null;type(date)"` + DateTimePtr *time.Time `orm:"null"` +} + +type String string +type Boolean bool +type Byte byte +type Rune rune +type Int int +type Int8 int8 +type Int16 int16 +type Int32 int32 +type Int64 int64 +type Uint uint +type Uint8 uint8 +type Uint16 uint16 +type Uint32 uint32 +type Uint64 uint64 +type Float32 float64 +type Float64 float64 + +type DataCustom struct { + ID int `orm:"column(id)"` + Boolean Boolean + Char string `orm:"size(50)"` + Text string `orm:"type(text)"` + Byte Byte + Rune Rune + Int Int + Int8 Int8 + Int16 Int16 + Int32 Int32 + Int64 Int64 + Uint Uint + Uint8 Uint8 + Uint16 Uint16 + Uint32 Uint32 + Uint64 Uint64 + Float32 Float32 + Float64 Float64 + Decimal Float64 `orm:"digits(8);decimals(4)"` +} + +// only for mysql +type UserBig struct { + ID uint64 `orm:"column(id)"` + Name string +} + +type User struct { + ID int `orm:"column(id)"` + UserName string `orm:"size(30);unique"` + Email string `orm:"size(100)"` + Password string `orm:"size(100)"` + Status int16 `orm:"column(Status)"` + IsStaff bool + IsActive bool `orm:"default(true)"` + Created time.Time `orm:"auto_now_add;type(date)"` + Updated time.Time `orm:"auto_now"` + Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` + Posts []*Post `orm:"reverse(many)" json:"-"` + ShouldSkip string `orm:"-"` + Nums int + Langs SliceStringField `orm:"size(100)"` + Extra JSONFieldTest `orm:"type(text)"` + unexport bool `orm:"-"` + unexportBool bool +} + +func (u *User) TableIndex() [][]string { + return [][]string{ + {"Id", "UserName"}, + {"Id", "Created"}, + } +} + +func (u *User) TableUnique() [][]string { + return [][]string{ + {"UserName", "Email"}, + } +} + +func NewUser() *User { + obj := new(User) + return obj +} + +type Profile struct { + ID int `orm:"column(id)"` + Age int16 + Money float64 + User *User `orm:"reverse(one)" json:"-"` + BestPost *Post `orm:"rel(one);null"` +} + +func (u *Profile) TableName() string { + return "user_profile" +} + +func NewProfile() *Profile { + obj := new(Profile) + return obj +} + +type Post struct { + ID int `orm:"column(id)"` + User *User `orm:"rel(fk)"` + Title string `orm:"size(60)"` + Content string `orm:"type(text)"` + Created time.Time `orm:"auto_now_add"` + Updated time.Time `orm:"auto_now"` + Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.PostTags)"` +} + +func (u *Post) TableIndex() [][]string { + return [][]string{ + {"Id", "Created"}, + } +} + +func NewPost() *Post { + obj := new(Post) + return obj +} + +type Tag struct { + ID int `orm:"column(id)"` + Name string `orm:"size(30)"` + BestPost *Post `orm:"rel(one);null"` + Posts []*Post `orm:"reverse(many)" json:"-"` +} + +func NewTag() *Tag { + obj := new(Tag) + return obj +} + +type PostTags struct { + ID int `orm:"column(id)"` + Post *Post `orm:"rel(fk)"` + Tag *Tag `orm:"rel(fk)"` +} + +func (m *PostTags) TableName() string { + return "prefix_post_tags" +} + +type Comment struct { + ID int `orm:"column(id)"` + Post *Post `orm:"rel(fk);column(post)"` + Content string `orm:"type(text)"` + Parent *Comment `orm:"null;rel(fk)"` + Created time.Time `orm:"auto_now_add"` +} + +func NewComment() *Comment { + obj := new(Comment) + return obj +} + +type Group struct { + ID int `orm:"column(gid);size(32)"` + Name string + Permissions []*Permission `orm:"reverse(many)" json:"-"` +} + +type Permission struct { + ID int `orm:"column(id)"` + Name string + Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"` +} + +type GroupPermissions struct { + ID int `orm:"column(id)"` + Group *Group `orm:"rel(fk)"` + Permission *Permission `orm:"rel(fk)"` +} + +type ModelID struct { + ID int64 +} + +type ModelBase struct { + ModelID + + Created time.Time `orm:"auto_now_add;type(datetime)"` + Updated time.Time `orm:"auto_now;type(datetime)"` +} + +type InLine struct { + // Common Fields + ModelBase + + // Other Fields + Name string `orm:"unique"` + Email string +} + +func NewInLine() *InLine { + return new(InLine) +} + +type InLineOneToOne struct { + // Common Fields + ModelBase + + Note string + InLine *InLine `orm:"rel(fk);column(inline)"` +} + +func NewInLineOneToOne() *InLineOneToOne { + return new(InLineOneToOne) +} + +type IntegerPk struct { + ID int64 `orm:"pk"` + Value string +} + +type UintPk struct { + ID uint32 `orm:"pk"` + Name string +} + +type PtrPk struct { + ID *IntegerPk `orm:"pk;rel(one)"` + Positive bool +} + +var DBARGS = struct { + Driver string + Source string + Debug string +}{ + os.Getenv("ORM_DRIVER"), + os.Getenv("ORM_SOURCE"), + os.Getenv("ORM_DEBUG"), +} + +var ( + IsMysql = DBARGS.Driver == "mysql" + IsSqlite = DBARGS.Driver == "sqlite3" + IsPostgres = DBARGS.Driver == "postgres" + IsTidb = DBARGS.Driver == "tidb" +) + +var ( + dORM Ormer + dDbBaser dbBaser +) + +var ( + helpinfo = `need driver and source! + + Default DB Drivers. + + driver: url + mysql: https://github.com/go-sql-driver/mysql + sqlite3: https://github.com/mattn/go-sqlite3 + postgres: https://github.com/lib/pq + tidb: https://github.com/pingcap/tidb + + usage: + + go get -u github.com/astaxie/beego/orm + go get -u github.com/go-sql-driver/mysql + go get -u github.com/mattn/go-sqlite3 + go get -u github.com/lib/pq + go get -u github.com/pingcap/tidb + + #### MySQL + mysql -u root -e 'create database orm_test;' + export ORM_DRIVER=mysql + export ORM_SOURCE="root:@/orm_test?charset=utf8" + go test -v github.com/astaxie/beego/orm + + + #### Sqlite3 + export ORM_DRIVER=sqlite3 + export ORM_SOURCE='file:memory_test?mode=memory' + go test -v github.com/astaxie/beego/orm + + + #### PostgreSQL + psql -c 'create database orm_test;' -U postgres + export ORM_DRIVER=postgres + export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" + go test -v github.com/astaxie/beego/orm + + #### TiDB + export ORM_DRIVER=tidb + export ORM_SOURCE='memory://test/test' + go test -v github.com/astaxie/beego/orm + + ` +) + +func init() { + Debug, _ = StrTo(DBARGS.Debug).Bool() + + if DBARGS.Driver == "" || DBARGS.Source == "" { + fmt.Println(helpinfo) + os.Exit(2) + } + + RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) + + alias := getDbAlias("default") + if alias.Driver == DRMySQL { + alias.Engine = "INNODB" + } + +} diff --git a/pkg/orm/models_utils.go b/pkg/orm/models_utils.go new file mode 100644 index 00000000..71127a6b --- /dev/null +++ b/pkg/orm/models_utils.go @@ -0,0 +1,227 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "fmt" + "reflect" + "strings" + "time" +) + +// 1 is attr +// 2 is tag +var supportTag = map[string]int{ + "-": 1, + "null": 1, + "index": 1, + "unique": 1, + "pk": 1, + "auto": 1, + "auto_now": 1, + "auto_now_add": 1, + "size": 2, + "column": 2, + "default": 2, + "rel": 2, + "reverse": 2, + "rel_table": 2, + "rel_through": 2, + "digits": 2, + "decimals": 2, + "on_delete": 2, + "type": 2, + "description": 2, +} + +// get reflect.Type name with package path. +func getFullName(typ reflect.Type) string { + return typ.PkgPath() + "." + typ.Name() +} + +// getTableName get struct table name. +// If the struct implement the TableName, then get the result as tablename +// else use the struct name which will apply snakeString. +func getTableName(val reflect.Value) string { + if fun := val.MethodByName("TableName"); fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + // has return and the first val is string + if len(vals) > 0 && vals[0].Kind() == reflect.String { + return vals[0].String() + } + } + return snakeString(reflect.Indirect(val).Type().Name()) +} + +// get table engine, myisam or innodb. +func getTableEngine(val reflect.Value) string { + fun := val.MethodByName("TableEngine") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + if len(vals) > 0 && vals[0].Kind() == reflect.String { + return vals[0].String() + } + } + return "" +} + +// get table index from method. +func getTableIndex(val reflect.Value) [][]string { + fun := val.MethodByName("TableIndex") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + if len(vals) > 0 && vals[0].CanInterface() { + if d, ok := vals[0].Interface().([][]string); ok { + return d + } + } + } + return nil +} + +// get table unique from method +func getTableUnique(val reflect.Value) [][]string { + fun := val.MethodByName("TableUnique") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + if len(vals) > 0 && vals[0].CanInterface() { + if d, ok := vals[0].Interface().([][]string); ok { + return d + } + } + } + return nil +} + +// get snaked column name +func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { + column := col + if col == "" { + column = nameStrategyMap[nameStrategy](sf.Name) + } + switch ft { + case RelForeignKey, RelOneToOne: + if len(col) == 0 { + column = column + "_id" + } + case RelManyToMany, RelReverseMany, RelReverseOne: + column = sf.Name + } + return column +} + +// return field type as type constant from reflect.Value +func getFieldType(val reflect.Value) (ft int, err error) { + switch val.Type() { + case reflect.TypeOf(new(int8)): + ft = TypeBitField + case reflect.TypeOf(new(int16)): + ft = TypeSmallIntegerField + case reflect.TypeOf(new(int32)), + reflect.TypeOf(new(int)): + ft = TypeIntegerField + case reflect.TypeOf(new(int64)): + ft = TypeBigIntegerField + case reflect.TypeOf(new(uint8)): + ft = TypePositiveBitField + case reflect.TypeOf(new(uint16)): + ft = TypePositiveSmallIntegerField + case reflect.TypeOf(new(uint32)), + reflect.TypeOf(new(uint)): + ft = TypePositiveIntegerField + case reflect.TypeOf(new(uint64)): + ft = TypePositiveBigIntegerField + case reflect.TypeOf(new(float32)), + reflect.TypeOf(new(float64)): + ft = TypeFloatField + case reflect.TypeOf(new(bool)): + ft = TypeBooleanField + case reflect.TypeOf(new(string)): + ft = TypeVarCharField + case reflect.TypeOf(new(time.Time)): + ft = TypeDateTimeField + default: + elm := reflect.Indirect(val) + switch elm.Kind() { + case reflect.Int8: + ft = TypeBitField + case reflect.Int16: + ft = TypeSmallIntegerField + case reflect.Int32, reflect.Int: + ft = TypeIntegerField + case reflect.Int64: + ft = TypeBigIntegerField + case reflect.Uint8: + ft = TypePositiveBitField + case reflect.Uint16: + ft = TypePositiveSmallIntegerField + case reflect.Uint32, reflect.Uint: + ft = TypePositiveIntegerField + case reflect.Uint64: + ft = TypePositiveBigIntegerField + case reflect.Float32, reflect.Float64: + ft = TypeFloatField + case reflect.Bool: + ft = TypeBooleanField + case reflect.String: + ft = TypeVarCharField + default: + if elm.Interface() == nil { + panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val)) + } + switch elm.Interface().(type) { + case sql.NullInt64: + ft = TypeBigIntegerField + case sql.NullFloat64: + ft = TypeFloatField + case sql.NullBool: + ft = TypeBooleanField + case sql.NullString: + ft = TypeVarCharField + case time.Time: + ft = TypeDateTimeField + } + } + } + if ft&IsFieldType == 0 { + err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val) + } + return +} + +// parse struct tag string +func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) { + attrs = make(map[string]bool) + tags = make(map[string]string) + for _, v := range strings.Split(data, defaultStructTagDelim) { + if v == "" { + continue + } + v = strings.TrimSpace(v) + if t := strings.ToLower(v); supportTag[t] == 1 { + attrs[t] = true + } else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 { + name := t[:i] + if supportTag[name] == 2 { + v = v[i+1 : len(v)-1] + tags[name] = v + } + } else { + DebugLog.Println("unsupport orm tag", v) + } + } + return +} diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go new file mode 100644 index 00000000..0551b1cd --- /dev/null +++ b/pkg/orm/orm.go @@ -0,0 +1,579 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +// Package orm provide ORM for MySQL/PostgreSQL/sqlite +// Simple Usage +// +// package main +// +// import ( +// "fmt" +// "github.com/astaxie/beego/orm" +// _ "github.com/go-sql-driver/mysql" // import your used driver +// ) +// +// // Model Struct +// type User struct { +// Id int `orm:"auto"` +// Name string `orm:"size(100)"` +// } +// +// func init() { +// orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) +// } +// +// func main() { +// o := orm.NewOrm() +// user := User{Name: "slene"} +// // insert +// id, err := o.Insert(&user) +// // update +// user.Name = "astaxie" +// num, err := o.Update(&user) +// // read one +// u := User{Id: user.Id} +// err = o.Read(&u) +// // delete +// num, err = o.Delete(&u) +// } +// +// more docs: http://beego.me/docs/mvc/model/overview.md +package orm + +import ( + "context" + "database/sql" + "errors" + "fmt" + "os" + "reflect" + "sync" + "time" +) + +// DebugQueries define the debug +const ( + DebugQueries = iota +) + +// Define common vars +var ( + Debug = false + DebugLog = NewLog(os.Stdout) + DefaultRowsLimit = -1 + DefaultRelsDepth = 2 + DefaultTimeLoc = time.Local + ErrTxHasBegan = errors.New(" transaction already begin") + ErrTxDone = errors.New(" transaction not begin") + ErrMultiRows = errors.New(" return multi rows") + ErrNoRows = errors.New(" no row found") + ErrStmtClosed = errors.New(" stmt already closed") + ErrArgs = errors.New(" args error may be empty") + ErrNotImplement = errors.New("have not implement") +) + +// Params stores the Params +type Params map[string]interface{} + +// ParamsList stores paramslist +type ParamsList []interface{} + +type orm struct { + alias *alias + db dbQuerier + isTx bool +} + +var _ Ormer = new(orm) + +// get model info and model reflect value +func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { + val := reflect.ValueOf(md) + ind = reflect.Indirect(val) + typ := ind.Type() + if needPtr && val.Kind() != reflect.Ptr { + panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) + } + name := getFullName(typ) + if mi, ok := modelCache.getByFullName(name); ok { + return mi, ind + } + panic(fmt.Errorf(" table: `%s` not found, make sure it was registered with `RegisterModel()`", name)) +} + +// get field info from model info by given field name +func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { + fi, ok := mi.fields.GetByAny(name) + if !ok { + panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.fullName)) + } + return fi +} + +// read data to model +func (o *orm) Read(md interface{}, cols ...string) error { + mi, ind := o.getMiInd(md, true) + return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) +} + +// read data to model, like Read(), but use "SELECT FOR UPDATE" form +func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { + mi, ind := o.getMiInd(md, true) + return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) +} + +// Try to read a row from the database, or insert one if it doesn't exist +func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + cols = append([]string{col1}, cols...) + mi, ind := o.getMiInd(md, true) + err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) + if err == ErrNoRows { + // Create + id, err := o.Insert(md) + return (err == nil), id, err + } + + id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + id = int64(vid.Uint()) + } else if mi.fields.pk.rel { + return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) + } else { + id = vid.Int() + } + + return false, id, err +} + +// insert model data to database +func (o *orm) Insert(md interface{}) (int64, error) { + mi, ind := o.getMiInd(md, true) + id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + if err != nil { + return id, err + } + + o.setPk(mi, ind, id) + + return id, nil +} + +// set auto pk field +func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { + if mi.fields.pk.auto { + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) + } else { + ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id) + } + } +} + +// insert some models to database +func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { + var cnt int64 + + sind := reflect.Indirect(reflect.ValueOf(mds)) + + switch sind.Kind() { + case reflect.Array, reflect.Slice: + if sind.Len() == 0 { + return cnt, ErrArgs + } + default: + return cnt, ErrArgs + } + + if bulk <= 1 { + for i := 0; i < sind.Len(); i++ { + ind := reflect.Indirect(sind.Index(i)) + mi, _ := o.getMiInd(ind.Interface(), false) + id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + if err != nil { + return cnt, err + } + + o.setPk(mi, ind, id) + + cnt++ + } + } else { + mi, _ := o.getMiInd(sind.Index(0).Interface(), false) + return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) + } + return cnt, nil +} + +// InsertOrUpdate data to database +func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { + mi, ind := o.getMiInd(md, true) + id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) + if err != nil { + return id, err + } + + o.setPk(mi, ind, id) + + return id, nil +} + +// update model to database. +// cols set the columns those want to update. +func (o *orm) Update(md interface{}, cols ...string) (int64, error) { + mi, ind := o.getMiInd(md, true) + return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) +} + +// delete model in database +// cols shows the delete conditions values read from. default is pk +func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { + mi, ind := o.getMiInd(md, true) + num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) + if err != nil { + return num, err + } + if num > 0 { + o.setPk(mi, ind, 0) + } + return num, nil +} + +// create a models to models queryer +func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { + mi, ind := o.getMiInd(md, true) + fi := o.getFieldInfo(mi, name) + + switch { + case fi.fieldType == RelManyToMany: + case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough: + default: + panic(fmt.Errorf(" model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName)) + } + + return newQueryM2M(md, o, mi, fi, ind) +} + +// load related models to md model. +// args are limit, offset int and order string. +// +// example: +// orm.LoadRelated(post,"Tags") +// for _,tag := range post.Tags{...} +// +// make sure the relation is defined in model struct tags. +func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { + _, fi, ind, qseter := o.queryRelated(md, name) + + qs := qseter.(*querySet) + + var relDepth int + var limit, offset int64 + var order string + for i, arg := range args { + switch i { + case 0: + if v, ok := arg.(bool); ok { + if v { + relDepth = DefaultRelsDepth + } + } else if v, ok := arg.(int); ok { + relDepth = v + } + case 1: + limit = ToInt64(arg) + case 2: + offset = ToInt64(arg) + case 3: + order, _ = arg.(string) + } + } + + switch fi.fieldType { + case RelOneToOne, RelForeignKey, RelReverseOne: + limit = 1 + offset = 0 + } + + qs.limit = limit + qs.offset = offset + qs.relDepth = relDepth + + if len(order) > 0 { + qs.orders = []string{order} + } + + find := ind.FieldByIndex(fi.fieldIndex) + + var nums int64 + var err error + switch fi.fieldType { + case RelOneToOne, RelForeignKey, RelReverseOne: + val := reflect.New(find.Type().Elem()) + container := val.Interface() + err = qs.One(container) + if err == nil { + find.Set(val) + nums = 1 + } + default: + nums, err = qs.All(find.Addr().Interface()) + } + + return nums, err +} + +// return a QuerySeter for related models to md model. +// it can do all, update, delete in QuerySeter. +// example: +// qs := orm.QueryRelated(post,"Tag") +// qs.All(&[]*Tag{}) +// +func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { + // is this api needed ? + _, _, _, qs := o.queryRelated(md, name) + return qs +} + +// get QuerySeter for related models to md model +func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { + mi, ind := o.getMiInd(md, true) + fi := o.getFieldInfo(mi, name) + + _, _, exist := getExistPk(mi, ind) + if !exist { + panic(ErrMissPK) + } + + var qs *querySet + + switch fi.fieldType { + case RelOneToOne, RelForeignKey, RelManyToMany: + if !fi.inModel { + break + } + qs = o.getRelQs(md, mi, fi) + case RelReverseOne, RelReverseMany: + if !fi.inModel { + break + } + qs = o.getReverseQs(md, mi, fi) + } + + if qs == nil { + panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel/reverse field", md, name)) + } + + return mi, fi, ind, qs +} + +// get reverse relation QuerySeter +func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { + switch fi.fieldType { + case RelReverseOne, RelReverseMany: + default: + panic(fmt.Errorf(" name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName)) + } + + var q *querySet + + if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough { + q = newQuerySet(o, fi.relModelInfo).(*querySet) + q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) + } else { + q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet) + q.cond = NewCondition().And(fi.reverseFieldInfo.column, md) + } + + return q +} + +// get relation QuerySeter +func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { + switch fi.fieldType { + case RelOneToOne, RelForeignKey, RelManyToMany: + default: + panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName)) + } + + q := newQuerySet(o, fi.relModelInfo).(*querySet) + q.cond = NewCondition() + + if fi.fieldType == RelManyToMany { + q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) + } else { + q.cond = q.cond.And(fi.reverseFieldInfo.column, md) + } + + return q +} + +// return a QuerySeter for table operations. +// table name can be string or struct. +// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), +func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { + var name string + if table, ok := ptrStructOrTableName.(string); ok { + name = nameStrategyMap[defaultNameStrategy](table) + if mi, ok := modelCache.get(name); ok { + qs = newQuerySet(o, mi) + } + } else { + name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) + if mi, ok := modelCache.getByFullName(name); ok { + qs = newQuerySet(o, mi) + } + } + if qs == nil { + panic(fmt.Errorf(" table name: `%s` not exists", name)) + } + return +} + +// switch to another registered database driver by given name. +func (o *orm) Using(name string) error { + if o.isTx { + panic(fmt.Errorf(" transaction has been start, cannot change db")) + } + if al, ok := dataBaseCache.get(name); ok { + o.alias = al + if Debug { + o.db = newDbQueryLog(al, al.DB) + } else { + o.db = al.DB + } + } else { + return fmt.Errorf(" unknown db alias name `%s`", name) + } + return nil +} + +// begin transaction +func (o *orm) Begin() error { + return o.BeginTx(context.Background(), nil) +} + +func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error { + if o.isTx { + return ErrTxHasBegan + } + var tx *sql.Tx + tx, err := o.db.(txer).BeginTx(ctx, opts) + if err != nil { + return err + } + o.isTx = true + if Debug { + o.db.(*dbQueryLog).SetDB(tx) + } else { + o.db = tx + } + return nil +} + +// commit transaction +func (o *orm) Commit() error { + if !o.isTx { + return ErrTxDone + } + err := o.db.(txEnder).Commit() + if err == nil { + o.isTx = false + o.Using(o.alias.Name) + } else if err == sql.ErrTxDone { + return ErrTxDone + } + return err +} + +// rollback transaction +func (o *orm) Rollback() error { + if !o.isTx { + return ErrTxDone + } + err := o.db.(txEnder).Rollback() + if err == nil { + o.isTx = false + o.Using(o.alias.Name) + } else if err == sql.ErrTxDone { + return ErrTxDone + } + return err +} + +// return a raw query seter for raw sql string. +func (o *orm) Raw(query string, args ...interface{}) RawSeter { + return newRawSet(o, query, args) +} + +// return current using database Driver +func (o *orm) Driver() Driver { + return driver(o.alias.Name) +} + +// return sql.DBStats for current database +func (o *orm) DBStats() *sql.DBStats { + if o.alias != nil && o.alias.DB != nil { + stats := o.alias.DB.DB.Stats() + return &stats + } + return nil +} + +// NewOrm create new orm +func NewOrm() Ormer { + BootStrap() // execute only once + + o := new(orm) + err := o.Using("default") + if err != nil { + panic(err) + } + return o +} + +// NewOrmWithDB create a new ormer object with specify *sql.DB for query +func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { + var al *alias + + if dr, ok := drivers[driverName]; ok { + al = new(alias) + al.DbBaser = dbBasers[dr] + al.Driver = dr + } else { + return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + } + + al.Name = aliasName + al.DriverName = driverName + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: newStmtDecoratorLruWithEvict(), + } + + detectTZ(al) + + o := new(orm) + o.alias = al + + if Debug { + o.db = newDbQueryLog(o.alias, db) + } else { + o.db = db + } + + return o, nil +} diff --git a/pkg/orm/orm_conds.go b/pkg/orm/orm_conds.go new file mode 100644 index 00000000..f3fd66f0 --- /dev/null +++ b/pkg/orm/orm_conds.go @@ -0,0 +1,153 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strings" +) + +// ExprSep define the expression separation +const ( + ExprSep = "__" +) + +type condValue struct { + exprs []string + args []interface{} + cond *Condition + isOr bool + isNot bool + isCond bool + isRaw bool + sql string +} + +// Condition struct. +// work for WHERE conditions. +type Condition struct { + params []condValue +} + +// NewCondition return new condition struct +func NewCondition() *Condition { + c := &Condition{} + return c +} + +// Raw add raw sql to condition +func (c Condition) Raw(expr string, sql string) *Condition { + if len(sql) == 0 { + panic(fmt.Errorf(" sql cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), sql: sql, isRaw: true}) + return &c +} + +// And add expression to condition +func (c Condition) And(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(fmt.Errorf(" args cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args}) + return &c +} + +// AndNot add NOT expression to condition +func (c Condition) AndNot(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(fmt.Errorf(" args cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true}) + return &c +} + +// AndCond combine a condition to current condition +func (c *Condition) AndCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true}) + } + return c +} + +// AndNotCond combine a AND NOT condition to current condition +func (c *Condition) AndNotCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true}) + } + return c +} + +// Or add OR expression to condition +func (c Condition) Or(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(fmt.Errorf(" args cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true}) + return &c +} + +// OrNot add OR NOT expression to condition +func (c Condition) OrNot(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(fmt.Errorf(" args cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true}) + return &c +} + +// OrCond combine a OR condition to current condition +func (c *Condition) OrCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true, isOr: true}) + } + return c +} + +// OrNotCond combine a OR NOT condition to current condition +func (c *Condition) OrNotCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true, isOr: true}) + } + return c +} + +// IsEmpty check the condition arguments are empty or not. +func (c *Condition) IsEmpty() bool { + return len(c.params) == 0 +} + +// clone clone a condition +func (c Condition) clone() *Condition { + return &c +} diff --git a/pkg/orm/orm_log.go b/pkg/orm/orm_log.go new file mode 100644 index 00000000..f107bb59 --- /dev/null +++ b/pkg/orm/orm_log.go @@ -0,0 +1,222 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "fmt" + "io" + "log" + "strings" + "time" +) + +// Log implement the log.Logger +type Log struct { + *log.Logger +} + +//costomer log func +var LogFunc func(query map[string]interface{}) + +// NewLog set io.Writer to create a Logger. +func NewLog(out io.Writer) *Log { + d := new(Log) + d.Logger = log.New(out, "[ORM]", log.LstdFlags) + return d +} + +func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { + var logMap = make(map[string]interface{}) + sub := time.Now().Sub(t) / 1e5 + elsp := float64(int(sub)) / 10.0 + logMap["cost_time"] = elsp + flag := " OK" + if err != nil { + flag = "FAIL" + } + logMap["flag"] = flag + con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query) + cons := make([]string, 0, len(args)) + for _, arg := range args { + cons = append(cons, fmt.Sprintf("%v", arg)) + } + if len(cons) > 0 { + con += fmt.Sprintf(" - `%s`", strings.Join(cons, "`, `")) + } + if err != nil { + con += " - " + err.Error() + } + logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `")) + if LogFunc != nil{ + LogFunc(logMap) + } + DebugLog.Println(con) +} + +// statement query logger struct. +// if dev mode, use stmtQueryLog, or use stmtQuerier. +type stmtQueryLog struct { + alias *alias + query string + stmt stmtQuerier +} + +var _ stmtQuerier = new(stmtQueryLog) + +func (d *stmtQueryLog) Close() error { + a := time.Now() + err := d.stmt.Close() + debugLogQueies(d.alias, "st.Close", d.query, a, err) + return err +} + +func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) { + a := time.Now() + res, err := d.stmt.Exec(args...) + debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...) + return res, err +} + +func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) { + a := time.Now() + res, err := d.stmt.Query(args...) + debugLogQueies(d.alias, "st.Query", d.query, a, err, args...) + return res, err +} + +func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row { + a := time.Now() + res := d.stmt.QueryRow(args...) + debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...) + return res +} + +func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier { + d := new(stmtQueryLog) + d.stmt = stmt + d.alias = alias + d.query = query + return d +} + +// database query logger struct. +// if dev mode, use dbQueryLog, or use dbQuerier. +type dbQueryLog struct { + alias *alias + db dbQuerier + tx txer + txe txEnder +} + +var _ dbQuerier = new(dbQueryLog) +var _ txer = new(dbQueryLog) +var _ txEnder = new(dbQueryLog) + +func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) { + a := time.Now() + stmt, err := d.db.Prepare(query) + debugLogQueies(d.alias, "db.Prepare", query, a, err) + return stmt, err +} + +func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + a := time.Now() + stmt, err := d.db.PrepareContext(ctx, query) + debugLogQueies(d.alias, "db.Prepare", query, a, err) + return stmt, err +} + +func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { + a := time.Now() + res, err := d.db.Exec(query, args...) + debugLogQueies(d.alias, "db.Exec", query, a, err, args...) + return res, err +} + +func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + a := time.Now() + res, err := d.db.ExecContext(ctx, query, args...) + debugLogQueies(d.alias, "db.Exec", query, a, err, args...) + return res, err +} + +func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { + a := time.Now() + res, err := d.db.Query(query, args...) + debugLogQueies(d.alias, "db.Query", query, a, err, args...) + return res, err +} + +func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + a := time.Now() + res, err := d.db.QueryContext(ctx, query, args...) + debugLogQueies(d.alias, "db.Query", query, a, err, args...) + return res, err +} + +func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { + a := time.Now() + res := d.db.QueryRow(query, args...) + debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) + return res +} + +func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + a := time.Now() + res := d.db.QueryRowContext(ctx, query, args...) + debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) + return res +} + +func (d *dbQueryLog) Begin() (*sql.Tx, error) { + a := time.Now() + tx, err := d.db.(txer).Begin() + debugLogQueies(d.alias, "db.Begin", "START TRANSACTION", a, err) + return tx, err +} + +func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + a := time.Now() + tx, err := d.db.(txer).BeginTx(ctx, opts) + debugLogQueies(d.alias, "db.BeginTx", "START TRANSACTION", a, err) + return tx, err +} + +func (d *dbQueryLog) Commit() error { + a := time.Now() + err := d.db.(txEnder).Commit() + debugLogQueies(d.alias, "tx.Commit", "COMMIT", a, err) + return err +} + +func (d *dbQueryLog) Rollback() error { + a := time.Now() + err := d.db.(txEnder).Rollback() + debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err) + return err +} + +func (d *dbQueryLog) SetDB(db dbQuerier) { + d.db = db +} + +func newDbQueryLog(alias *alias, db dbQuerier) dbQuerier { + d := new(dbQueryLog) + d.alias = alias + d.db = db + return d +} diff --git a/pkg/orm/orm_object.go b/pkg/orm/orm_object.go new file mode 100644 index 00000000..de3181ce --- /dev/null +++ b/pkg/orm/orm_object.go @@ -0,0 +1,87 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "reflect" +) + +// an insert queryer struct +type insertSet struct { + mi *modelInfo + orm *orm + stmt stmtQuerier + closed bool +} + +var _ Inserter = new(insertSet) + +// insert model ignore it's registered or not. +func (o *insertSet) Insert(md interface{}) (int64, error) { + if o.closed { + return 0, ErrStmtClosed + } + val := reflect.ValueOf(md) + ind := reflect.Indirect(val) + typ := ind.Type() + name := getFullName(typ) + if val.Kind() != reflect.Ptr { + panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", name)) + } + if name != o.mi.fullName { + panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.fullName, name)) + } + id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) + if err != nil { + return id, err + } + if id > 0 { + if o.mi.fields.pk.auto { + if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id)) + } else { + ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id) + } + } + } + return id, nil +} + +// close insert queryer statement +func (o *insertSet) Close() error { + if o.closed { + return ErrStmtClosed + } + o.closed = true + return o.stmt.Close() +} + +// create new insert queryer. +func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { + bi := new(insertSet) + bi.orm = orm + bi.mi = mi + st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) + if err != nil { + return nil, err + } + if Debug { + bi.stmt = newStmtQueryLog(orm.alias, st, query) + } else { + bi.stmt = st + } + return bi, nil +} diff --git a/pkg/orm/orm_querym2m.go b/pkg/orm/orm_querym2m.go new file mode 100644 index 00000000..6a270a0d --- /dev/null +++ b/pkg/orm/orm_querym2m.go @@ -0,0 +1,140 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import "reflect" + +// model to model struct +type queryM2M struct { + md interface{} + mi *modelInfo + fi *fieldInfo + qs *querySet + ind reflect.Value +} + +// add models to origin models when creating queryM2M. +// example: +// m2m := orm.QueryM2M(post,"Tag") +// m2m.Add(&Tag1{},&Tag2{}) +// for _,tag := range post.Tags{} +// +// make sure the relation is defined in post model struct tag. +func (o *queryM2M) Add(mds ...interface{}) (int64, error) { + fi := o.fi + mi := fi.relThroughModelInfo + mfi := fi.reverseFieldInfo + rfi := fi.reverseFieldInfoTwo + + orm := o.qs.orm + dbase := orm.alias.DbBaser + + var models []interface{} + var otherValues []interface{} + var otherNames []string + + for _, colname := range mi.fields.dbcols { + if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column && + mi.fields.columns[colname] != mi.fields.pk { + otherNames = append(otherNames, colname) + } + } + for i, md := range mds { + if reflect.Indirect(reflect.ValueOf(md)).Kind() != reflect.Struct && i > 0 { + otherValues = append(otherValues, md) + mds = append(mds[:i], mds[i+1:]...) + } + } + for _, md := range mds { + val := reflect.ValueOf(md) + if val.Kind() == reflect.Slice || val.Kind() == reflect.Array { + for i := 0; i < val.Len(); i++ { + v := val.Index(i) + if v.CanInterface() { + models = append(models, v.Interface()) + } + } + } else { + models = append(models, md) + } + } + + _, v1, exist := getExistPk(o.mi, o.ind) + if !exist { + panic(ErrMissPK) + } + + names := []string{mfi.column, rfi.column} + + values := make([]interface{}, 0, len(models)*2) + for _, md := range models { + + ind := reflect.Indirect(reflect.ValueOf(md)) + var v2 interface{} + if ind.Kind() != reflect.Struct { + v2 = ind.Interface() + } else { + _, v2, exist = getExistPk(fi.relModelInfo, ind) + if !exist { + panic(ErrMissPK) + } + } + values = append(values, v1, v2) + + } + names = append(names, otherNames...) + values = append(values, otherValues...) + return dbase.InsertValue(orm.db, mi, true, names, values) +} + +// remove models following the origin model relationship +func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { + fi := o.fi + qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) + + return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete() +} + +// check model is existed in relationship of origin model +func (o *queryM2M) Exist(md interface{}) bool { + fi := o.fi + return o.qs.Filter(fi.reverseFieldInfo.name, o.md). + Filter(fi.reverseFieldInfoTwo.name, md).Exist() +} + +// clean all models in related of origin model +func (o *queryM2M) Clear() (int64, error) { + fi := o.fi + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() +} + +// count all related models of origin model +func (o *queryM2M) Count() (int64, error) { + fi := o.fi + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() +} + +var _ QueryM2Mer = new(queryM2M) + +// create new M2M queryer. +func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { + qm2m := new(queryM2M) + qm2m.md = md + qm2m.mi = mi + qm2m.fi = fi + qm2m.ind = ind + qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet) + return qm2m +} diff --git a/pkg/orm/orm_queryset.go b/pkg/orm/orm_queryset.go new file mode 100644 index 00000000..878b836b --- /dev/null +++ b/pkg/orm/orm_queryset.go @@ -0,0 +1,300 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "fmt" +) + +type colValue struct { + value int64 + opt operator +} + +type operator int + +// define Col operations +const ( + ColAdd operator = iota + ColMinus + ColMultiply + ColExcept + ColBitAnd + ColBitRShift + ColBitLShift + ColBitXOR + ColBitOr +) + +// ColValue do the field raw changes. e.g Nums = Nums + 10. usage: +// Params{ +// "Nums": ColValue(Col_Add, 10), +// } +func ColValue(opt operator, value interface{}) interface{} { + switch opt { + case ColAdd, ColMinus, ColMultiply, ColExcept, ColBitAnd, ColBitRShift, + ColBitLShift, ColBitXOR, ColBitOr: + default: + panic(fmt.Errorf("orm.ColValue wrong operator")) + } + v, err := StrTo(ToStr(value)).Int64() + if err != nil { + panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err)) + } + var val colValue + val.value = v + val.opt = opt + return val +} + +// real query struct +type querySet struct { + mi *modelInfo + cond *Condition + related []string + relDepth int + limit int64 + offset int64 + groups []string + orders []string + distinct bool + forupdate bool + orm *orm + ctx context.Context + forContext bool +} + +var _ QuerySeter = new(querySet) + +// add condition expression to QuerySeter. +func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond = o.cond.And(expr, args...) + return &o +} + +// add raw sql to querySeter. +func (o querySet) FilterRaw(expr string, sql string) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond = o.cond.Raw(expr, sql) + return &o +} + +// add NOT condition to querySeter. +func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond = o.cond.AndNot(expr, args...) + return &o +} + +// set offset number +func (o *querySet) setOffset(num interface{}) { + o.offset = ToInt64(num) +} + +// add LIMIT value. +// args[0] means offset, e.g. LIMIT num,offset. +func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { + o.limit = ToInt64(limit) + if len(args) > 0 { + o.setOffset(args[0]) + } + return &o +} + +// add OFFSET value +func (o querySet) Offset(offset interface{}) QuerySeter { + o.setOffset(offset) + return &o +} + +// add GROUP expression +func (o querySet) GroupBy(exprs ...string) QuerySeter { + o.groups = exprs + return &o +} + +// add ORDER expression. +// "column" means ASC, "-column" means DESC. +func (o querySet) OrderBy(exprs ...string) QuerySeter { + o.orders = exprs + return &o +} + +// add DISTINCT to SELECT +func (o querySet) Distinct() QuerySeter { + o.distinct = true + return &o +} + +// add FOR UPDATE to SELECT +func (o querySet) ForUpdate() QuerySeter { + o.forupdate = true + return &o +} + +// set relation model to query together. +// it will query relation models and assign to parent model. +func (o querySet) RelatedSel(params ...interface{}) QuerySeter { + if len(params) == 0 { + o.relDepth = DefaultRelsDepth + } else { + for _, p := range params { + switch val := p.(type) { + case string: + o.related = append(o.related, val) + case int: + o.relDepth = val + default: + panic(fmt.Errorf(" wrong param kind: %v", val)) + } + } + } + return &o +} + +// set condition to QuerySeter. +func (o querySet) SetCond(cond *Condition) QuerySeter { + o.cond = cond + return &o +} + +// get condition from QuerySeter +func (o querySet) GetCond() *Condition { + return o.cond +} + +// return QuerySeter execution result number +func (o *querySet) Count() (int64, error) { + return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) +} + +// check result empty or not after QuerySeter executed +func (o *querySet) Exist() bool { + cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return cnt > 0 +} + +// execute update with parameters +func (o *querySet) Update(values Params) (int64, error) { + return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) +} + +// execute delete +func (o *querySet) Delete() (int64, error) { + return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) +} + +// return a insert queryer. +// it can be used in times. +// example: +// i,err := sq.PrepareInsert() +// i.Add(&user1{},&user2{}) +func (o *querySet) PrepareInsert() (Inserter, error) { + return newInsertSet(o.orm, o.mi) +} + +// query all data and map to containers. +// cols means the columns when querying. +func (o *querySet) All(container interface{}, cols ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) +} + +// query one row data and map to containers. +// cols means the columns when querying. +func (o *querySet) One(container interface{}, cols ...string) error { + o.limit = 1 + num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + if err != nil { + return err + } + if num == 0 { + return ErrNoRows + } + + if num > 1 { + return ErrMultiRows + } + return nil +} + +// query all data and map to []map[string]interface. +// expres means condition expression. +// it converts data to []map[column]value. +func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) +} + +// query all data and map to [][]interface +// it converts data to [][column_index]value +func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) +} + +// query all data and map to []interface. +// it's designed for one row record set, auto change to []value, not [][column]value. +func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) +} + +// query all rows into map[string]interface with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to map[string]interface{}{ +// "total": 100, +// "found": 200, +// } +func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { + panic(ErrNotImplement) +} + +// query all rows into struct with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to struct { +// Total int +// Found int +// } +func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + panic(ErrNotImplement) +} + +// set context to QuerySeter. +func (o querySet) WithContext(ctx context.Context) QuerySeter { + o.ctx = ctx + o.forContext = true + return &o +} + +// create new QuerySeter. +func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { + o := new(querySet) + o.mi = mi + o.orm = orm + return o +} diff --git a/pkg/orm/orm_raw.go b/pkg/orm/orm_raw.go new file mode 100644 index 00000000..3325a7ea --- /dev/null +++ b/pkg/orm/orm_raw.go @@ -0,0 +1,867 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "fmt" + "reflect" + "time" +) + +// raw sql string prepared statement +type rawPrepare struct { + rs *rawSet + stmt stmtQuerier + closed bool +} + +func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) { + if o.closed { + return nil, ErrStmtClosed + } + return o.stmt.Exec(args...) +} + +func (o *rawPrepare) Close() error { + o.closed = true + return o.stmt.Close() +} + +func newRawPreparer(rs *rawSet) (RawPreparer, error) { + o := new(rawPrepare) + o.rs = rs + + query := rs.query + rs.orm.alias.DbBaser.ReplaceMarks(&query) + + st, err := rs.orm.db.Prepare(query) + if err != nil { + return nil, err + } + if Debug { + o.stmt = newStmtQueryLog(rs.orm.alias, st, query) + } else { + o.stmt = st + } + return o, nil +} + +// raw query seter +type rawSet struct { + query string + args []interface{} + orm *orm +} + +var _ RawSeter = new(rawSet) + +// set args for every query +func (o rawSet) SetArgs(args ...interface{}) RawSeter { + o.args = args + return &o +} + +// execute raw sql and return sql.Result +func (o *rawSet) Exec() (sql.Result, error) { + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + return o.orm.db.Exec(query, args...) +} + +// set field value to row container +func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { + switch ind.Kind() { + case reflect.Bool: + if value == nil { + ind.SetBool(false) + } else if v, ok := value.(bool); ok { + ind.SetBool(v) + } else { + v, _ := StrTo(ToStr(value)).Bool() + ind.SetBool(v) + } + + case reflect.String: + if value == nil { + ind.SetString("") + } else { + ind.SetString(ToStr(value)) + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if value == nil { + ind.SetInt(0) + } else { + val := reflect.ValueOf(value) + switch val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + ind.SetInt(val.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + ind.SetInt(int64(val.Uint())) + default: + v, _ := StrTo(ToStr(value)).Int64() + ind.SetInt(v) + } + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if value == nil { + ind.SetUint(0) + } else { + val := reflect.ValueOf(value) + switch val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + ind.SetUint(uint64(val.Int())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + ind.SetUint(val.Uint()) + default: + v, _ := StrTo(ToStr(value)).Uint64() + ind.SetUint(v) + } + } + case reflect.Float64, reflect.Float32: + if value == nil { + ind.SetFloat(0) + } else { + val := reflect.ValueOf(value) + switch val.Kind() { + case reflect.Float64: + ind.SetFloat(val.Float()) + default: + v, _ := StrTo(ToStr(value)).Float64() + ind.SetFloat(v) + } + } + + case reflect.Struct: + if value == nil { + ind.Set(reflect.Zero(ind.Type())) + return + } + switch ind.Interface().(type) { + case time.Time: + var str string + switch d := value.(type) { + case time.Time: + o.orm.alias.DbBaser.TimeFromDB(&d, o.orm.alias.TZ) + ind.Set(reflect.ValueOf(d)) + case []byte: + str = string(d) + case string: + str = d + } + if str != "" { + if len(str) >= 19 { + str = str[:19] + t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ) + if err == nil { + t = t.In(DefaultTimeLoc) + ind.Set(reflect.ValueOf(t)) + } + } else if len(str) >= 10 { + str = str[:10] + t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc) + if err == nil { + ind.Set(reflect.ValueOf(t)) + } + } + } + case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool: + indi := reflect.New(ind.Type()).Interface() + sc, ok := indi.(sql.Scanner) + if !ok { + return + } + err := sc.Scan(value) + if err == nil { + ind.Set(reflect.Indirect(reflect.ValueOf(sc))) + } + } + + case reflect.Ptr: + if value == nil { + ind.Set(reflect.Zero(ind.Type())) + break + } + ind.Set(reflect.New(ind.Type().Elem())) + o.setFieldValue(reflect.Indirect(ind), value) + } +} + +// set field value in loop for slice container +func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) { + nInds := *nIndsPtr + + cur := 0 + for i := 0; i < len(sInds); i++ { + sInd := sInds[i] + eTyp := eTyps[i] + + typ := eTyp + isPtr := false + if typ.Kind() == reflect.Ptr { + isPtr = true + typ = typ.Elem() + } + if typ.Kind() == reflect.Ptr { + isPtr = true + typ = typ.Elem() + } + + var nInd reflect.Value + if init { + nInd = reflect.New(sInd.Type()).Elem() + } else { + nInd = nInds[i] + } + + val := reflect.New(typ) + ind := val.Elem() + + tpName := ind.Type().String() + + if ind.Kind() == reflect.Struct { + if tpName == "time.Time" { + value := reflect.ValueOf(refs[cur]).Elem().Interface() + if isPtr && value == nil { + val = reflect.New(val.Type()).Elem() + } else { + o.setFieldValue(ind, value) + } + cur++ + } + + } else { + value := reflect.ValueOf(refs[cur]).Elem().Interface() + if isPtr && value == nil { + val = reflect.New(val.Type()).Elem() + } else { + o.setFieldValue(ind, value) + } + cur++ + } + + if nInd.Kind() == reflect.Slice { + if isPtr { + nInd = reflect.Append(nInd, val) + } else { + nInd = reflect.Append(nInd, ind) + } + } else { + if isPtr { + nInd.Set(val) + } else { + nInd.Set(ind) + } + } + + nInds[i] = nInd + } +} + +// query data and map to container +func (o *rawSet) QueryRow(containers ...interface{}) error { + var ( + refs = make([]interface{}, 0, len(containers)) + sInds []reflect.Value + eTyps []reflect.Type + sMi *modelInfo + ) + structMode := false + for _, container := range containers { + val := reflect.ValueOf(container) + ind := reflect.Indirect(val) + + if val.Kind() != reflect.Ptr { + panic(fmt.Errorf(" all args must be use ptr")) + } + + etyp := ind.Type() + typ := etyp + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + sInds = append(sInds, ind) + eTyps = append(eTyps, etyp) + + if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { + if len(containers) > 1 { + panic(fmt.Errorf(" now support one struct only. see #384")) + } + + structMode = true + fn := getFullName(typ) + if mi, ok := modelCache.getByFullName(fn); ok { + sMi = mi + } + } else { + var ref interface{} + refs = append(refs, &ref) + } + } + + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + rows, err := o.orm.db.Query(query, args...) + if err != nil { + if err == sql.ErrNoRows { + return ErrNoRows + } + return err + } + + defer rows.Close() + + if rows.Next() { + if structMode { + columns, err := rows.Columns() + if err != nil { + return err + } + + columnsMp := make(map[string]interface{}, len(columns)) + + refs = make([]interface{}, 0, len(columns)) + for _, col := range columns { + var ref interface{} + columnsMp[col] = &ref + refs = append(refs, &ref) + } + + if err := rows.Scan(refs...); err != nil { + return err + } + + ind := sInds[0] + + if ind.Kind() == reflect.Ptr { + if ind.IsNil() || !ind.IsValid() { + ind.Set(reflect.New(eTyps[0].Elem())) + } + ind = ind.Elem() + } + + if sMi != nil { + for _, col := range columns { + if fi := sMi.fields.GetByColumn(col); fi != nil { + value := reflect.ValueOf(columnsMp[col]).Elem().Interface() + field := ind.FieldByIndex(fi.fieldIndex) + if fi.fieldType&IsRelField > 0 { + mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + field.Set(mf) + field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + } + o.setFieldValue(field, value) + } + } + } else { + for i := 0; i < ind.NumField(); i++ { + f := ind.Field(i) + fe := ind.Type().Field(i) + _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + var col string + if col = tags["column"]; col == "" { + col = nameStrategyMap[nameStrategy](fe.Name) + } + if v, ok := columnsMp[col]; ok { + value := reflect.ValueOf(v).Elem().Interface() + o.setFieldValue(f, value) + } + } + } + + } else { + if err := rows.Scan(refs...); err != nil { + return err + } + + nInds := make([]reflect.Value, len(sInds)) + o.loopSetRefs(refs, sInds, &nInds, eTyps, true) + for i, sInd := range sInds { + nInd := nInds[i] + sInd.Set(nInd) + } + } + + } else { + return ErrNoRows + } + + return nil +} + +// query data rows and map to container +func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { + var ( + refs = make([]interface{}, 0, len(containers)) + sInds []reflect.Value + eTyps []reflect.Type + sMi *modelInfo + ) + structMode := false + for _, container := range containers { + val := reflect.ValueOf(container) + sInd := reflect.Indirect(val) + if val.Kind() != reflect.Ptr || sInd.Kind() != reflect.Slice { + panic(fmt.Errorf(" all args must be use ptr slice")) + } + + etyp := sInd.Type().Elem() + typ := etyp + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + sInds = append(sInds, sInd) + eTyps = append(eTyps, etyp) + + if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { + if len(containers) > 1 { + panic(fmt.Errorf(" now support one struct only. see #384")) + } + + structMode = true + fn := getFullName(typ) + if mi, ok := modelCache.getByFullName(fn); ok { + sMi = mi + } + } else { + var ref interface{} + refs = append(refs, &ref) + } + } + + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + rows, err := o.orm.db.Query(query, args...) + if err != nil { + return 0, err + } + + defer rows.Close() + + var cnt int64 + nInds := make([]reflect.Value, len(sInds)) + sInd := sInds[0] + + for rows.Next() { + + if structMode { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + columnsMp := make(map[string]interface{}, len(columns)) + + refs = make([]interface{}, 0, len(columns)) + for _, col := range columns { + var ref interface{} + columnsMp[col] = &ref + refs = append(refs, &ref) + } + + if err := rows.Scan(refs...); err != nil { + return 0, err + } + + if cnt == 0 && !sInd.IsNil() { + sInd.Set(reflect.New(sInd.Type()).Elem()) + } + + var ind reflect.Value + if eTyps[0].Kind() == reflect.Ptr { + ind = reflect.New(eTyps[0].Elem()) + } else { + ind = reflect.New(eTyps[0]) + } + + if ind.Kind() == reflect.Ptr { + ind = ind.Elem() + } + + if sMi != nil { + for _, col := range columns { + if fi := sMi.fields.GetByColumn(col); fi != nil { + value := reflect.ValueOf(columnsMp[col]).Elem().Interface() + field := ind.FieldByIndex(fi.fieldIndex) + if fi.fieldType&IsRelField > 0 { + mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + field.Set(mf) + field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + } + o.setFieldValue(field, value) + } + } + } else { + // define recursive function + var recursiveSetField func(rv reflect.Value) + recursiveSetField = func(rv reflect.Value) { + for i := 0; i < rv.NumField(); i++ { + f := rv.Field(i) + fe := rv.Type().Field(i) + + // check if the field is a Struct + // recursive the Struct type + if fe.Type.Kind() == reflect.Struct { + recursiveSetField(f) + } + + _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + var col string + if col = tags["column"]; col == "" { + col = nameStrategyMap[nameStrategy](fe.Name) + } + if v, ok := columnsMp[col]; ok { + value := reflect.ValueOf(v).Elem().Interface() + o.setFieldValue(f, value) + } + } + } + + // init call the recursive function + recursiveSetField(ind) + } + + if eTyps[0].Kind() == reflect.Ptr { + ind = ind.Addr() + } + + sInd = reflect.Append(sInd, ind) + + } else { + if err := rows.Scan(refs...); err != nil { + return 0, err + } + + o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0) + } + + cnt++ + } + + if cnt > 0 { + + if structMode { + sInds[0].Set(sInd) + } else { + for i, sInd := range sInds { + nInd := nInds[i] + sInd.Set(nInd) + } + } + } + + return cnt, nil +} + +func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) { + var ( + maps []Params + lists []ParamsList + list ParamsList + ) + + typ := 0 + switch container.(type) { + case *[]Params: + typ = 1 + case *[]ParamsList: + typ = 2 + case *ParamsList: + typ = 3 + default: + panic(fmt.Errorf(" unsupport read values type `%T`", container)) + } + + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + + var rs *sql.Rows + rs, err := o.orm.db.Query(query, args...) + if err != nil { + return 0, err + } + + defer rs.Close() + + var ( + refs []interface{} + cnt int64 + cols []string + indexs []int + ) + + for rs.Next() { + if cnt == 0 { + columns, err := rs.Columns() + if err != nil { + return 0, err + } + if len(needCols) > 0 { + indexs = make([]int, 0, len(needCols)) + } else { + indexs = make([]int, 0, len(columns)) + } + + cols = columns + refs = make([]interface{}, len(cols)) + for i := range refs { + var ref sql.NullString + refs[i] = &ref + + if len(needCols) > 0 { + for _, c := range needCols { + if c == cols[i] { + indexs = append(indexs, i) + } + } + } else { + indexs = append(indexs, i) + } + } + } + + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + switch typ { + case 1: + params := make(Params, len(cols)) + for _, i := range indexs { + ref := refs[i] + value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) + if value.Valid { + params[cols[i]] = value.String + } else { + params[cols[i]] = nil + } + } + maps = append(maps, params) + case 2: + params := make(ParamsList, 0, len(cols)) + for _, i := range indexs { + ref := refs[i] + value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) + if value.Valid { + params = append(params, value.String) + } else { + params = append(params, nil) + } + } + lists = append(lists, params) + case 3: + for _, i := range indexs { + ref := refs[i] + value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) + if value.Valid { + list = append(list, value.String) + } else { + list = append(list, nil) + } + } + } + + cnt++ + } + + switch v := container.(type) { + case *[]Params: + *v = maps + case *[]ParamsList: + *v = lists + case *ParamsList: + *v = list + } + + return cnt, nil +} + +func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) { + var ( + maps Params + ind *reflect.Value + ) + + var typ int + switch container.(type) { + case *Params: + typ = 1 + default: + typ = 2 + vl := reflect.ValueOf(container) + id := reflect.Indirect(vl) + if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct { + panic(fmt.Errorf(" RowsTo unsupport type `%T` need ptr struct", container)) + } + + ind = &id + } + + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + + rs, err := o.orm.db.Query(query, args...) + if err != nil { + return 0, err + } + + defer rs.Close() + + var ( + refs []interface{} + cnt int64 + cols []string + ) + + var ( + keyIndex = -1 + valueIndex = -1 + ) + + for rs.Next() { + if cnt == 0 { + columns, err := rs.Columns() + if err != nil { + return 0, err + } + cols = columns + refs = make([]interface{}, len(cols)) + for i := range refs { + if keyCol == cols[i] { + keyIndex = i + } + if typ == 1 || keyIndex == i { + var ref sql.NullString + refs[i] = &ref + } else { + var ref interface{} + refs[i] = &ref + } + if valueCol == cols[i] { + valueIndex = i + } + } + if keyIndex == -1 || valueIndex == -1 { + panic(fmt.Errorf(" RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol)) + } + } + + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + if cnt == 0 { + switch typ { + case 1: + maps = make(Params) + } + } + + key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String + + switch typ { + case 1: + value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString) + if value.Valid { + maps[key] = value.String + } else { + maps[key] = nil + } + + default: + if id := ind.FieldByName(camelString(key)); id.IsValid() { + o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface()) + } + } + + cnt++ + } + + if typ == 1 { + v, _ := container.(*Params) + *v = maps + } + + return cnt, nil +} + +// query data to []map[string]interface +func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) { + return o.readValues(container, cols) +} + +// query data to [][]interface +func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) { + return o.readValues(container, cols) +} + +// query data to []interface +func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) { + return o.readValues(container, cols) +} + +// query all rows into map[string]interface with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to map[string]interface{}{ +// "total": 100, +// "found": 200, +// } +func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { + return o.queryRowsTo(result, keyCol, valueCol) +} + +// query all rows into struct with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to struct { +// Total int +// Found int +// } +func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + return o.queryRowsTo(ptrStruct, keyCol, valueCol) +} + +// return prepared raw statement for used in times. +func (o *rawSet) Prepare() (RawPreparer, error) { + return newRawPreparer(o) +} + +func newRawSet(orm *orm, query string, args []interface{}) RawSeter { + o := new(rawSet) + o.query = query + o.args = args + o.orm = orm + return o +} diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go new file mode 100644 index 00000000..bdb430b6 --- /dev/null +++ b/pkg/orm/orm_test.go @@ -0,0 +1,2494 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +package orm + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "io/ioutil" + "math" + "os" + "path/filepath" + "reflect" + "runtime" + "strings" + "testing" + "time" +) + +var _ = os.PathSeparator + +var ( + testDate = formatDate + " -0700" + testDateTime = formatDateTime + " -0700" + testTime = formatTime + " -0700" +) + +type argAny []interface{} + +// get interface by index from interface slice +func (a argAny) Get(i int, args ...interface{}) (r interface{}) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) { + if len(args) == 0 { + return false, fmt.Errorf("miss args") + } + b := args[0] + arg := argAny(args) + + switch v := a.(type) { + case reflect.Kind: + ok = reflect.ValueOf(b).Kind() == v + case time.Time: + if v2, vo := b.(time.Time); vo { + if arg.Get(1) != nil { + format := ToStr(arg.Get(1)) + a = v.Format(format) + b = v2.Format(format) + ok = a == b + } else { + err = fmt.Errorf("compare datetime miss format") + goto wrongArg + } + } + default: + ok = ToStr(a) == ToStr(b) + } + ok = is && ok || !is && !ok + if !ok { + if is { + err = fmt.Errorf("expected: `%v`, get `%v`", b, a) + } else { + err = fmt.Errorf("expected: `%v`, get `%v`", b, a) + } + } + +wrongArg: + if err != nil { + return false, err + } + + return true, nil +} + +func AssertIs(a interface{}, args ...interface{}) error { + if ok, err := ValuesCompare(true, a, args...); !ok { + return err + } + return nil +} + +func AssertNot(a interface{}, args ...interface{}) error { + if ok, err := ValuesCompare(false, a, args...); !ok { + return err + } + return nil +} + +func getCaller(skip int) string { + pc, file, line, _ := runtime.Caller(skip) + fun := runtime.FuncForPC(pc) + _, fn := filepath.Split(file) + data, err := ioutil.ReadFile(file) + var codes []string + if err == nil { + lines := bytes.Split(data, []byte{'\n'}) + n := 10 + for i := 0; i < n; i++ { + o := line - n + if o < 0 { + continue + } + cur := o + i + 1 + flag := " " + if cur == line { + flag = ">>" + } + code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1)) + if code != "" { + codes = append(codes, code) + } + } + } + funName := fun.Name() + if i := strings.LastIndex(funName, "."); i > -1 { + funName = funName[i+1:] + } + return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) +} + +func throwFail(t *testing.T, err error, args ...interface{}) { + if err != nil { + con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) + if len(args) > 0 { + parts := make([]string, 0, len(args)) + for _, arg := range args { + parts = append(parts, fmt.Sprintf("%v", arg)) + } + con += " " + strings.Join(parts, ", ") + } + t.Error(con) + t.Fail() + } +} + +func throwFailNow(t *testing.T, err error, args ...interface{}) { + if err != nil { + con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) + if len(args) > 0 { + parts := make([]string, 0, len(args)) + for _, arg := range args { + parts = append(parts, fmt.Sprintf("%v", arg)) + } + con += " " + strings.Join(parts, ", ") + } + t.Error(con) + t.FailNow() + } +} + +func TestGetDB(t *testing.T) { + if db, err := GetDB(); err != nil { + throwFailNow(t, err) + } else { + err = db.Ping() + throwFailNow(t, err) + } +} + +func TestSyncDb(t *testing.T) { + RegisterModel(new(Data), new(DataNull), new(DataCustom)) + RegisterModel(new(User)) + RegisterModel(new(Profile)) + RegisterModel(new(Post)) + RegisterModel(new(Tag)) + RegisterModel(new(Comment)) + RegisterModel(new(UserBig)) + RegisterModel(new(PostTags)) + RegisterModel(new(Group)) + RegisterModel(new(Permission)) + RegisterModel(new(GroupPermissions)) + RegisterModel(new(InLine)) + RegisterModel(new(InLineOneToOne)) + RegisterModel(new(IntegerPk)) + RegisterModel(new(UintPk)) + RegisterModel(new(PtrPk)) + + err := RunSyncdb("default", true, Debug) + throwFail(t, err) + + modelCache.clean() +} + +func TestRegisterModels(t *testing.T) { + RegisterModel(new(Data), new(DataNull), new(DataCustom)) + RegisterModel(new(User)) + RegisterModel(new(Profile)) + RegisterModel(new(Post)) + RegisterModel(new(Tag)) + RegisterModel(new(Comment)) + RegisterModel(new(UserBig)) + RegisterModel(new(PostTags)) + RegisterModel(new(Group)) + RegisterModel(new(Permission)) + RegisterModel(new(GroupPermissions)) + RegisterModel(new(InLine)) + RegisterModel(new(InLineOneToOne)) + RegisterModel(new(IntegerPk)) + RegisterModel(new(UintPk)) + RegisterModel(new(PtrPk)) + + BootStrap() + + dORM = NewOrm() + dDbBaser = getDbAlias("default").DbBaser +} + +func TestModelSyntax(t *testing.T) { + user := &User{} + ind := reflect.ValueOf(user).Elem() + fn := getFullName(ind.Type()) + mi, ok := modelCache.getByFullName(fn) + throwFail(t, AssertIs(ok, true)) + + mi, ok = modelCache.get("user") + throwFail(t, AssertIs(ok, true)) + if ok { + throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) + } +} + +var DataValues = map[string]interface{}{ + "Boolean": true, + "Char": "char", + "Text": "text", + "JSON": `{"name":"json"}`, + "Jsonb": `{"name": "jsonb"}`, + "Time": time.Now(), + "Date": time.Now(), + "DateTime": time.Now(), + "Byte": byte(1<<8 - 1), + "Rune": rune(1<<31 - 1), + "Int": int(1<<31 - 1), + "Int8": int8(1<<7 - 1), + "Int16": int16(1<<15 - 1), + "Int32": int32(1<<31 - 1), + "Int64": int64(1<<63 - 1), + "Uint": uint(1<<32 - 1), + "Uint8": uint8(1<<8 - 1), + "Uint16": uint16(1<<16 - 1), + "Uint32": uint32(1<<32 - 1), + "Uint64": uint64(1<<63 - 1), // uint64 values with high bit set are not supported + "Float32": float32(100.1234), + "Float64": float64(100.1234), + "Decimal": float64(100.1234), +} + +func TestDataTypes(t *testing.T) { + d := Data{} + ind := reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + if name == "JSON" { + continue + } + e := ind.FieldByName(name) + e.Set(reflect.ValueOf(value)) + } + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + d = Data{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + ind = reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } +} + +func TestNullDataTypes(t *testing.T) { + d := DataNull{} + + if IsPostgres { + // can removed when this fixed + // https://github.com/lib/pq/pull/125 + d.DateTime = time.Now() + } + + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + data := `{"ok":1,"data":{"arr":[1,2],"msg":"gopher"}}` + d = DataNull{ID: 1, JSON: data} + num, err := dORM.Update(&d) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + d = DataNull{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + throwFail(t, AssertIs(d.JSON, data)) + + throwFail(t, AssertIs(d.NullBool.Valid, false)) + throwFail(t, AssertIs(d.NullString.Valid, false)) + throwFail(t, AssertIs(d.NullInt64.Valid, false)) + throwFail(t, AssertIs(d.NullFloat64.Valid, false)) + + throwFail(t, AssertIs(d.BooleanPtr, nil)) + throwFail(t, AssertIs(d.CharPtr, nil)) + throwFail(t, AssertIs(d.TextPtr, nil)) + throwFail(t, AssertIs(d.BytePtr, nil)) + throwFail(t, AssertIs(d.RunePtr, nil)) + throwFail(t, AssertIs(d.IntPtr, nil)) + throwFail(t, AssertIs(d.Int8Ptr, nil)) + throwFail(t, AssertIs(d.Int16Ptr, nil)) + throwFail(t, AssertIs(d.Int32Ptr, nil)) + throwFail(t, AssertIs(d.Int64Ptr, nil)) + throwFail(t, AssertIs(d.UintPtr, nil)) + throwFail(t, AssertIs(d.Uint8Ptr, nil)) + throwFail(t, AssertIs(d.Uint16Ptr, nil)) + throwFail(t, AssertIs(d.Uint32Ptr, nil)) + throwFail(t, AssertIs(d.Uint64Ptr, nil)) + throwFail(t, AssertIs(d.Float32Ptr, nil)) + throwFail(t, AssertIs(d.Float64Ptr, nil)) + throwFail(t, AssertIs(d.DecimalPtr, nil)) + throwFail(t, AssertIs(d.TimePtr, nil)) + throwFail(t, AssertIs(d.DatePtr, nil)) + throwFail(t, AssertIs(d.DateTimePtr, nil)) + + _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() + throwFail(t, err) + + d = DataNull{ID: 2} + err = dORM.Read(&d) + throwFail(t, err) + + booleanPtr := true + charPtr := string("test") + textPtr := string("test") + bytePtr := byte('t') + runePtr := rune('t') + intPtr := int(42) + int8Ptr := int8(42) + int16Ptr := int16(42) + int32Ptr := int32(42) + int64Ptr := int64(42) + uintPtr := uint(42) + uint8Ptr := uint8(42) + uint16Ptr := uint16(42) + uint32Ptr := uint32(42) + uint64Ptr := uint64(42) + float32Ptr := float32(42.0) + float64Ptr := float64(42.0) + decimalPtr := float64(42.0) + timePtr := time.Now() + datePtr := time.Now() + dateTimePtr := time.Now() + + d = DataNull{ + DateTime: time.Now(), + NullString: sql.NullString{String: "test", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + BooleanPtr: &booleanPtr, + CharPtr: &charPtr, + TextPtr: &textPtr, + BytePtr: &bytePtr, + RunePtr: &runePtr, + IntPtr: &intPtr, + Int8Ptr: &int8Ptr, + Int16Ptr: &int16Ptr, + Int32Ptr: &int32Ptr, + Int64Ptr: &int64Ptr, + UintPtr: &uintPtr, + Uint8Ptr: &uint8Ptr, + Uint16Ptr: &uint16Ptr, + Uint32Ptr: &uint32Ptr, + Uint64Ptr: &uint64Ptr, + Float32Ptr: &float32Ptr, + Float64Ptr: &float64Ptr, + DecimalPtr: &decimalPtr, + TimePtr: &timePtr, + DatePtr: &datePtr, + DateTimePtr: &dateTimePtr, + } + + id, err = dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + d = DataNull{ID: 3} + err = dORM.Read(&d) + throwFail(t, err) + + throwFail(t, AssertIs(d.NullBool.Valid, true)) + throwFail(t, AssertIs(d.NullBool.Bool, true)) + + throwFail(t, AssertIs(d.NullString.Valid, true)) + throwFail(t, AssertIs(d.NullString.String, "test")) + + throwFail(t, AssertIs(d.NullInt64.Valid, true)) + throwFail(t, AssertIs(d.NullInt64.Int64, 42)) + + throwFail(t, AssertIs(d.NullFloat64.Valid, true)) + throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) + + throwFail(t, AssertIs(*d.BooleanPtr, booleanPtr)) + throwFail(t, AssertIs(*d.CharPtr, charPtr)) + throwFail(t, AssertIs(*d.TextPtr, textPtr)) + throwFail(t, AssertIs(*d.BytePtr, bytePtr)) + throwFail(t, AssertIs(*d.RunePtr, runePtr)) + throwFail(t, AssertIs(*d.IntPtr, intPtr)) + throwFail(t, AssertIs(*d.Int8Ptr, int8Ptr)) + throwFail(t, AssertIs(*d.Int16Ptr, int16Ptr)) + throwFail(t, AssertIs(*d.Int32Ptr, int32Ptr)) + throwFail(t, AssertIs(*d.Int64Ptr, int64Ptr)) + throwFail(t, AssertIs(*d.UintPtr, uintPtr)) + throwFail(t, AssertIs(*d.Uint8Ptr, uint8Ptr)) + throwFail(t, AssertIs(*d.Uint16Ptr, uint16Ptr)) + throwFail(t, AssertIs(*d.Uint32Ptr, uint32Ptr)) + throwFail(t, AssertIs(*d.Uint64Ptr, uint64Ptr)) + throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) + throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) + throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) + throwFail(t, AssertIs((*d.TimePtr).UTC().Format(testTime), timePtr.UTC().Format(testTime))) + throwFail(t, AssertIs((*d.DatePtr).UTC().Format(testDate), datePtr.UTC().Format(testDate))) + throwFail(t, AssertIs((*d.DateTimePtr).UTC().Format(testDateTime), dateTimePtr.UTC().Format(testDateTime))) + + // test support for pointer fields using RawSeter.QueryRows() + var dnList []*DataNull + Q := dDbBaser.TableQuote() + num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + equal := reflect.DeepEqual(*dnList[0], d) + throwFailNow(t, AssertIs(equal, true)) +} + +func TestDataCustomTypes(t *testing.T) { + d := DataCustom{} + ind := reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + if !e.IsValid() { + continue + } + e.Set(reflect.ValueOf(value).Convert(e.Type())) + } + + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + d = DataCustom{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + ind = reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + if !e.IsValid() { + continue + } + vu := e.Interface() + value = reflect.ValueOf(value).Convert(e.Type()).Interface() + throwFail(t, AssertIs(vu == value, true), value, vu) + } +} + +func TestCRUD(t *testing.T) { + profile := NewProfile() + profile.Age = 30 + profile.Money = 1234.12 + id, err := dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + user := NewUser() + user.UserName = "slene" + user.Email = "vslene@gmail.com" + user.Password = "pass" + user.Status = 3 + user.IsStaff = true + user.IsActive = true + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + u := &User{ID: user.ID} + err = dORM.Read(u) + throwFail(t, err) + + throwFail(t, AssertIs(u.UserName, "slene")) + throwFail(t, AssertIs(u.Email, "vslene@gmail.com")) + throwFail(t, AssertIs(u.Password, "pass")) + throwFail(t, AssertIs(u.Status, 3)) + throwFail(t, AssertIs(u.IsStaff, true)) + throwFail(t, AssertIs(u.IsActive, true)) + throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime)) + + user.UserName = "astaxie" + user.Profile = profile + num, err := dORM.Update(user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFailNow(t, err) + throwFail(t, AssertIs(u.UserName, "astaxie")) + throwFail(t, AssertIs(u.Profile.ID, profile.ID)) + + u = &User{UserName: "astaxie", Password: "pass"} + err = dORM.Read(u, "UserName") + throwFailNow(t, err) + throwFailNow(t, AssertIs(id, 1)) + + u.UserName = "QQ" + u.Password = "111" + num, err = dORM.Update(u, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFailNow(t, err) + throwFail(t, AssertIs(u.UserName, "QQ")) + throwFail(t, AssertIs(u.Password, "pass")) + + num, err = dORM.Delete(profile) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFail(t, err) + throwFail(t, AssertIs(true, u.Profile == nil)) + + num, err = dORM.Delete(user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: 100} + err = dORM.Read(u) + throwFail(t, AssertIs(err, ErrNoRows)) + + ub := UserBig{} + ub.Name = "name" + id, err = dORM.Insert(&ub) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + ub = UserBig{ID: 1} + err = dORM.Read(&ub) + throwFail(t, err) + throwFail(t, AssertIs(ub.Name, "name")) + + num, err = dORM.Delete(&ub, "name") + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestInsertTestData(t *testing.T) { + var users []*User + + profile := NewProfile() + profile.Age = 28 + profile.Money = 1234.12 + + id, err := dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + user := NewUser() + user.UserName = "slene" + user.Email = "vslene@gmail.com" + user.Password = "pass" + user.Status = 1 + user.IsStaff = false + user.IsActive = true + user.Profile = profile + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + profile = NewProfile() + profile.Age = 30 + profile.Money = 4321.09 + + id, err = dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + user = NewUser() + user.UserName = "astaxie" + user.Email = "astaxie@gmail.com" + user.Password = "password" + user.Status = 2 + user.IsStaff = true + user.IsActive = false + user.Profile = profile + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + user = NewUser() + user.UserName = "nobody" + user.Email = "nobody@gmail.com" + user.Password = "nobody" + user.Status = 3 + user.IsStaff = false + user.IsActive = false + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 4)) + + tags := []*Tag{ + {Name: "golang", BestPost: &Post{ID: 2}}, + {Name: "example"}, + {Name: "format"}, + {Name: "c++"}, + } + + posts := []*Post{ + {User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory result—Java programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand. +This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, all of which you should read first.`}, + {User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`}, + {User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide. +With Go we take an unusual approach and let the machine take care of most formatting issues. The gofmt program (also available as go fmt, which operates at the package level rather than source file level) reads a Go program and emits the source in a standard style of indentation and vertical alignment, retaining and if necessary reformatting comments. If you want to know how to handle some new layout situation, run gofmt; if the answer doesn't seem right, rearrange your program (or file a bug about gofmt), don't work around it.`}, + {User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code. +The program—and web server—godoc processes Go source files to extract documentation about the contents of the package. Comments that appear before top-level declarations, with no intervening newlines, are extracted along with the declaration to serve as explanatory text for the item. The nature and style of these comments determines the quality of the documentation godoc produces.`}, + } + + comments := []*Comment{ + {Post: posts[0], Content: "a comment"}, + {Post: posts[1], Content: "yes"}, + {Post: posts[1]}, + {Post: posts[1]}, + {Post: posts[2]}, + {Post: posts[2]}, + } + + for _, tag := range tags { + id, err := dORM.Insert(tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + for _, post := range posts { + id, err := dORM.Insert(post) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num := len(post.Tags) + if num > 0 { + nums, err := dORM.QueryM2M(post, "tags").Add(post.Tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(nums, num)) + } + } + + for _, comment := range comments { + id, err := dORM.Insert(comment) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + permissions := []*Permission{ + {Name: "writePosts"}, + {Name: "readComments"}, + {Name: "readPosts"}, + } + + groups := []*Group{ + { + Name: "admins", + Permissions: []*Permission{permissions[0], permissions[1], permissions[2]}, + }, + { + Name: "users", + Permissions: []*Permission{permissions[1], permissions[2]}, + }, + } + + for _, permission := range permissions { + id, err := dORM.Insert(permission) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + for _, group := range groups { + _, err := dORM.Insert(group) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num := len(group.Permissions) + if num > 0 { + nums, err := dORM.QueryM2M(group, "permissions").Add(group.Permissions) + throwFailNow(t, err) + throwFailNow(t, AssertIs(nums, num)) + } + } + +} + +func TestCustomField(t *testing.T) { + user := User{ID: 2} + err := dORM.Read(&user) + throwFailNow(t, err) + + user.Langs = append(user.Langs, "zh-CN", "en-US") + user.Extra.Name = "beego" + user.Extra.Data = "orm" + _, err = dORM.Update(&user, "Langs", "Extra") + throwFailNow(t, err) + + user = User{ID: 2} + err = dORM.Read(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(len(user.Langs), 2)) + throwFailNow(t, AssertIs(user.Langs[0], "zh-CN")) + throwFailNow(t, AssertIs(user.Langs[1], "en-US")) + + throwFailNow(t, AssertIs(user.Extra.Name, "beego")) + throwFailNow(t, AssertIs(user.Extra.Data, "orm")) +} + +func TestExpr(t *testing.T) { + user := &User{} + qs := dORM.QueryTable(user) + qs = dORM.QueryTable((*User)(nil)) + qs = dORM.QueryTable("User") + qs = dORM.QueryTable("user") + num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("created", time.Now()).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + // num, err = qs.Filter("created", time.Now().Format(format_Date)).Count() + // throwFail(t, err) + // throwFail(t, AssertIs(num, 3)) +} + +func TestOperators(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.Filter("user_name", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__exact", String("slene")).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__exact", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__iexact", "Slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__contains", "e").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + var shouldNum int + + if IsSqlite || IsTidb { + shouldNum = 2 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__contains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__icontains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("user_name__icontains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__gt", 1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__gte", 1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + num, err = qs.Filter("status__lt", Uint(3)).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__lte", Int(3)).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + num, err = qs.Filter("user_name__startswith", "s").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + if IsSqlite || IsTidb { + shouldNum = 1 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__startswith", "S").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__istartswith", "S").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__endswith", "e").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + if IsSqlite || IsTidb { + shouldNum = 2 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__endswith", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__iendswith", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("profile__isnull", true).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("status__in", 1, 2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__in", []int{1, 2}).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + n1, n2 := 1, 2 + num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("id__between", 2, 3).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("id__between", []int{2, 3}).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.FilterRaw("user_name", "= 'slene'").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.FilterRaw("status", "IN (1, 2)").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.FilterRaw("profile_id", "IN (SELECT id FROM user_profile WHERE age=30)").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestSetCond(t *testing.T) { + cond := NewCondition() + cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) + + qs := dORM.QueryTable("user") + num, err := qs.SetCond(cond1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + cond2 := cond.AndCond(cond1).OrCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + cond3 := cond.AndNotCond(cond.And("status__in", 1)) + num, err = qs.SetCond(cond3).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + cond4 := cond.And("user_name", "slene").OrNotCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond4).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + cond5 := cond.Raw("user_name", "= 'slene'").OrNotCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond5).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) +} + +func TestLimit(t *testing.T) { + var posts []*Post + qs := dORM.QueryTable("post") + num, err := qs.Limit(1).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Limit(-1).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 4)) + + num, err = qs.Limit(-1, 2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Limit(0, 2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) +} + +func TestOffset(t *testing.T) { + var posts []*Post + qs := dORM.QueryTable("post") + num, err := qs.Limit(1).Offset(2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Offset(2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) +} + +func TestOrderBy(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderBy("status").Filter("user_name", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestAll(t *testing.T) { + var users []*User + qs := dORM.QueryTable("user") + num, err := qs.OrderBy("Id").All(&users) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + + throwFail(t, AssertIs(users[0].UserName, "slene")) + throwFail(t, AssertIs(users[1].UserName, "astaxie")) + throwFail(t, AssertIs(users[2].UserName, "nobody")) + + var users2 []User + qs = dORM.QueryTable("user") + num, err = qs.OrderBy("Id").All(&users2) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + + throwFailNow(t, AssertIs(users2[0].UserName, "slene")) + throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) + + qs = dORM.QueryTable("user") + num, err = qs.OrderBy("Id").RelatedSel().All(&users2, "UserName") + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(len(users2), 3)) + throwFailNow(t, AssertIs(users2[0].UserName, "slene")) + throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) + throwFailNow(t, AssertIs(users2[0].ID, 0)) + throwFailNow(t, AssertIs(users2[1].ID, 0)) + throwFailNow(t, AssertIs(users2[2].ID, 0)) + throwFailNow(t, AssertIs(users2[0].Profile == nil, false)) + throwFailNow(t, AssertIs(users2[1].Profile == nil, false)) + throwFailNow(t, AssertIs(users2[2].Profile == nil, true)) + + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "nothing").All(&users) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + var users3 []*User + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "nothing").All(&users3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + throwFailNow(t, AssertIs(users3 == nil, false)) +} + +func TestOne(t *testing.T) { + var user User + qs := dORM.QueryTable("user") + err := qs.One(&user) + throwFail(t, err) + + user = User{} + err = qs.OrderBy("Id").Limit(1).One(&user) + throwFailNow(t, err) + throwFail(t, AssertIs(user.UserName, "slene")) + throwFail(t, AssertNot(err, ErrMultiRows)) + + user = User{} + err = qs.OrderBy("-Id").Limit(100).One(&user) + throwFailNow(t, err) + throwFail(t, AssertIs(user.UserName, "nobody")) + throwFail(t, AssertNot(err, ErrMultiRows)) + + err = qs.Filter("user_name", "nothing").One(&user) + throwFail(t, AssertIs(err, ErrNoRows)) + +} + +func TestValues(t *testing.T) { + var maps []Params + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("Id").Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[2]["Profile"], nil)) + } + + num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[0]["Profile__Age"], 28)) + throwFail(t, AssertIs(maps[2]["Profile__Age"], nil)) + } + + num, err = qs.Filter("UserName", "slene").Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestValuesList(t *testing.T) { + var list []ParamsList + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("Id").ValuesList(&list) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0][1], "slene")) + throwFail(t, AssertIs(list[2][9], nil)) + } + + num, err = qs.OrderBy("Id").ValuesList(&list, "UserName", "Profile__Age") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0][0], "slene")) + throwFail(t, AssertIs(list[0][1], 28)) + throwFail(t, AssertIs(list[2][1], nil)) + } +} + +func TestValuesFlat(t *testing.T) { + var list ParamsList + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("id").ValuesFlat(&list, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0], "slene")) + throwFail(t, AssertIs(list[1], "astaxie")) + throwFail(t, AssertIs(list[2], "nobody")) + } +} + +func TestRelatedSel(t *testing.T) { + if IsTidb { + // Skip it. TiDB does not support relation now. + return + } + qs := dORM.QueryTable("user") + num, err := qs.Filter("profile__age", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("profile__age__gt", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("profile__user__profile__age__gt", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + var user User + err = qs.Filter("user_name", "slene").RelatedSel("profile").One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertNot(user.Profile, nil)) + if user.Profile != nil { + throwFail(t, AssertIs(user.Profile.Age, 28)) + } + + err = qs.Filter("user_name", "slene").RelatedSel().One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertNot(user.Profile, nil)) + if user.Profile != nil { + throwFail(t, AssertIs(user.Profile.Age, 28)) + } + + err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(user.Profile, nil)) + + qs = dORM.QueryTable("user_profile") + num, err = qs.Filter("user__username", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + var posts []*Post + qs = dORM.QueryTable("post") + num, err = qs.RelatedSel().All(&posts) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 4)) + + throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) + throwFailNow(t, AssertIs(posts[1].User.UserName, "astaxie")) + throwFailNow(t, AssertIs(posts[2].User.UserName, "astaxie")) + throwFailNow(t, AssertIs(posts[3].User.UserName, "nobody")) +} + +func TestReverseQuery(t *testing.T) { + var profile Profile + err := dORM.QueryTable("user_profile").Filter("User", 3).One(&profile) + throwFailNow(t, err) + throwFailNow(t, AssertIs(profile.Age, 30)) + + profile = Profile{} + err = dORM.QueryTable("user_profile").Filter("User__UserName", "astaxie").One(&profile) + throwFailNow(t, err) + throwFailNow(t, AssertIs(profile.Age, 30)) + + var user User + err = dORM.QueryTable("user").Filter("Posts__Title", "Examples").One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + + user = User{} + err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").Limit(1).One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + + user = User{} + err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").RelatedSel().Limit(1).One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + + var posts []*Post + num, err := dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) + + posts = []*Post{} + num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").Filter("User__UserName", "slene").All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) + + posts = []*Post{} + num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang"). + Filter("User__UserName", "slene").RelatedSel().All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(posts[0].User == nil, false)) + throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) + + var tags []*Tag + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction").All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + tags = []*Tag{} + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). + Filter("BestPost__User__UserName", "astaxie").All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + tags = []*Tag{} + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). + Filter("BestPost__User__UserName", "astaxie").RelatedSel().All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + throwFailNow(t, AssertIs(tags[0].BestPost == nil, false)) + throwFailNow(t, AssertIs(tags[0].BestPost.Title, "Examples")) + throwFailNow(t, AssertIs(tags[0].BestPost.User == nil, false)) + throwFailNow(t, AssertIs(tags[0].BestPost.User.UserName, "astaxie")) +} + +func TestLoadRelated(t *testing.T) { + // load reverse foreign key + user := User{ID: 3} + + err := dORM.Read(&user) + throwFailNow(t, err) + + num, err := dORM.LoadRelated(&user, "Posts") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) + + num, err = dORM.LoadRelated(&user, "Posts", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&user, "Posts", true, 1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(user.Posts), 1)) + + num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) + + num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(user.Posts), 1)) + throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) + + // load reverse one to one + profile := Profile{ID: 3} + profile.BestPost = &Post{ID: 2} + num, err = dORM.Update(&profile, "BestPost") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + err = dORM.Read(&profile) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&profile, "User") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(profile.User == nil, false)) + throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&profile, "User", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(profile.User == nil, false)) + throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) + throwFailNow(t, AssertIs(profile.User.Profile.Age, profile.Age)) + + // load rel one to one + err = dORM.Read(&user) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&user, "Profile") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + + num, err = dORM.LoadRelated(&user, "Profile", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false)) + throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples")) + + post := Post{ID: 2} + + // load rel foreign key + err = dORM.Read(&post) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&post, "User") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(post.User == nil, false)) + throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&post, "User", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(post.User == nil, false)) + throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) + throwFailNow(t, AssertIs(post.User.Profile == nil, false)) + throwFailNow(t, AssertIs(post.User.Profile.Age, 30)) + + // load rel m2m + post = Post{ID: 2} + + err = dORM.Read(&post) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&post, "Tags") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(post.Tags), 2)) + throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) + + num, err = dORM.LoadRelated(&post, "Tags", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(post.Tags), 2)) + throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) + throwFailNow(t, AssertIs(post.Tags[0].BestPost == nil, false)) + throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie")) + + // load reverse m2m + tag := Tag{ID: 1} + + err = dORM.Read(&tag) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&tag, "Posts") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) + throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) + throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) + + num, err = dORM.LoadRelated(&tag, "Posts", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) + throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) + throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene")) +} + +func TestQueryM2M(t *testing.T) { + post := Post{ID: 4} + m2m := dORM.QueryM2M(&post, "Tags") + + tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}} + tag2 := &Tag{Name: "TestTag3"} + tag3 := []interface{}{&Tag{Name: "TestTag4"}} + + tags := []interface{}{tag1[0], tag1[1], tag2, tag3[0]} + + for _, tag := range tags { + _, err := dORM.Insert(tag) + throwFailNow(t, err) + } + + num, err := m2m.Add(tag1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Add(tag2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Add(tag3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 5)) + + num, err = m2m.Remove(tag3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 4)) + + exist := m2m.Exist(tag2) + throwFailNow(t, AssertIs(exist, true)) + + num, err = m2m.Remove(tag2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + exist = m2m.Exist(tag2) + throwFailNow(t, AssertIs(exist, false)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + num, err = m2m.Clear() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + tag := Tag{Name: "test"} + _, err = dORM.Insert(&tag) + throwFailNow(t, err) + + m2m = dORM.QueryM2M(&tag, "Posts") + + post1 := []*Post{{Title: "TestPost1"}, {Title: "TestPost2"}} + post2 := &Post{Title: "TestPost3"} + post3 := []interface{}{&Post{Title: "TestPost4"}} + + posts := []interface{}{post1[0], post1[1], post2, post3[0]} + + for _, post := range posts { + p := post.(*Post) + p.User = &User{ID: 1} + _, err := dORM.Insert(post) + throwFailNow(t, err) + } + + num, err = m2m.Add(post1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Add(post2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Add(post3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 4)) + + num, err = m2m.Remove(post3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + exist = m2m.Exist(post2) + throwFailNow(t, AssertIs(exist, true)) + + num, err = m2m.Remove(post2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + exist = m2m.Exist(post2) + throwFailNow(t, AssertIs(exist, false)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Clear() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + num, err = dORM.Delete(&tag) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) +} + +func TestQueryRelate(t *testing.T) { + // post := &Post{Id: 2} + + // qs := dORM.QueryRelate(post, "Tags") + // num, err := qs.Count() + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) + + // var tags []*Tag + // num, err = qs.All(&tags) + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) + // throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + // num, err = dORM.QueryTable("Tag").Filter("Posts__Post", 2).Count() + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) +} + +func TestPkManyRelated(t *testing.T) { + permission := &Permission{Name: "readPosts"} + err := dORM.Read(permission, "Name") + throwFailNow(t, err) + + var groups []*Group + qs := dORM.QueryTable("Group") + num, err := qs.Filter("Permissions__Permission", permission.ID).All(&groups) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) +} + +func TestPrepareInsert(t *testing.T) { + qs := dORM.QueryTable("user") + i, err := qs.PrepareInsert() + throwFailNow(t, err) + + var user User + user.UserName = "testing1" + num, err := i.Insert(&user) + throwFail(t, err) + throwFail(t, AssertIs(num > 0, true)) + + user.UserName = "testing2" + num, err = i.Insert(&user) + throwFail(t, err) + throwFail(t, AssertIs(num > 0, true)) + + num, err = qs.Filter("user_name__in", "testing1", "testing2").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + err = i.Close() + throwFail(t, err) + err = i.Close() + throwFail(t, AssertIs(err, ErrStmtClosed)) +} + +func TestRawExec(t *testing.T) { + Q := dDbBaser.TableQuote() + + query := fmt.Sprintf("UPDATE %suser%s SET %suser_name%s = ? WHERE %suser_name%s = ?", Q, Q, Q, Q, Q, Q) + res, err := dORM.Raw(query, "testing", "slene").Exec() + throwFail(t, err) + num, err := res.RowsAffected() + throwFail(t, AssertIs(num, 1), err) + + res, err = dORM.Raw(query, "slene", "testing").Exec() + throwFail(t, err) + num, err = res.RowsAffected() + throwFail(t, AssertIs(num, 1), err) +} + +func TestRawQueryRow(t *testing.T) { + var ( + Boolean bool + Char string + Text string + Time time.Time + Date time.Time + DateTime time.Time + Byte byte + Rune rune + Int int + Int8 int + Int16 int16 + Int32 int32 + Int64 int64 + Uint uint + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Float32 float32 + Float64 float64 + Decimal float64 + ) + + dataValues := make(map[string]interface{}, len(DataValues)) + + for k, v := range DataValues { + dataValues[strings.ToLower(k)] = v + } + + Q := dDbBaser.TableQuote() + + cols := []string{ + "id", "boolean", "char", "text", "time", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", + "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", + } + sep := fmt.Sprintf("%s, %s", Q, Q) + query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q) + var id int + values := []interface{}{ + &id, &Boolean, &Char, &Text, &Time, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32, + &Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal, + } + err := dORM.Raw(query, 1).QueryRow(values...) + throwFailNow(t, err) + for i, col := range cols { + vu := values[i] + v := reflect.ValueOf(vu).Elem().Interface() + switch col { + case "id": + throwFail(t, AssertIs(id, 1)) + case "time": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testTime)) + case "date": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testDate)) + case "datetime": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testDateTime)) + default: + throwFail(t, AssertIs(v, dataValues[col])) + } + } + + var ( + uid int + status *int + pid *int + ) + + cols = []string{ + "id", "Status", "profile_id", + } + query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q) + err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid) + throwFail(t, err) + throwFail(t, AssertIs(uid, 4)) + throwFail(t, AssertIs(*status, 3)) + throwFail(t, AssertIs(pid, nil)) + + // test for sql.Null* fields + nData := &DataNull{ + NullString: sql.NullString{String: "test sql.null", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + } + newId, err := dORM.Insert(nData) + throwFailNow(t, err) + + var nd *DataNull + query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) + err = dORM.Raw(query, newId).QueryRow(&nd) + throwFailNow(t, err) + + throwFailNow(t, AssertNot(nd, nil)) + throwFail(t, AssertIs(nd.NullBool.Valid, true)) + throwFail(t, AssertIs(nd.NullBool.Bool, true)) + throwFail(t, AssertIs(nd.NullString.Valid, true)) + throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) + throwFail(t, AssertIs(nd.NullInt64.Valid, true)) + throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) + throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) + throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) +} + +// user_profile table +type userProfile struct { + User + Age int + Money float64 +} + +func TestQueryRows(t *testing.T) { + Q := dDbBaser.TableQuote() + + var datas []*Data + + query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) + num, err := dORM.Raw(query).QueryRows(&datas) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(datas), 1)) + + ind := reflect.Indirect(reflect.ValueOf(datas[0])) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } + + var datas2 []Data + + query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) + num, err = dORM.Raw(query).QueryRows(&datas2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(datas2), 1)) + + ind = reflect.Indirect(reflect.ValueOf(datas2[0])) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } + + var ids []int + var usernames []string + query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q) + num, err = dORM.Raw(query).QueryRows(&ids, &usernames) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(len(ids), 3)) + throwFailNow(t, AssertIs(ids[0], 2)) + throwFailNow(t, AssertIs(usernames[0], "slene")) + throwFailNow(t, AssertIs(ids[1], 3)) + throwFailNow(t, AssertIs(usernames[1], "astaxie")) + throwFailNow(t, AssertIs(ids[2], 4)) + throwFailNow(t, AssertIs(usernames[2], "nobody")) + + //test query rows by nested struct + var l []userProfile + query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q) + num, err = dORM.Raw(query).QueryRows(&l) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(l), 2)) + throwFailNow(t, AssertIs(l[0].UserName, "slene")) + throwFailNow(t, AssertIs(l[0].Age, 28)) + throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(l[1].Age, 30)) + + // test for sql.Null* fields + nData := &DataNull{ + NullString: sql.NullString{String: "test sql.null", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + } + newId, err := dORM.Insert(nData) + throwFailNow(t, err) + + var nDataList []*DataNull + query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) + num, err = dORM.Raw(query, newId).QueryRows(&nDataList) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + nd := nDataList[0] + throwFailNow(t, AssertNot(nd, nil)) + throwFail(t, AssertIs(nd.NullBool.Valid, true)) + throwFail(t, AssertIs(nd.NullBool.Bool, true)) + throwFail(t, AssertIs(nd.NullString.Valid, true)) + throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) + throwFail(t, AssertIs(nd.NullInt64.Valid, true)) + throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) + throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) + throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) +} + +func TestRawValues(t *testing.T) { + Q := dDbBaser.TableQuote() + + var maps []Params + query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sStatus%s = ?", Q, Q, Q, Q, Q, Q) + num, err := dORM.Raw(query, 1).Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + if num == 1 { + throwFail(t, AssertIs(maps[0]["user_name"], "slene")) + } + + var lists []ParamsList + num, err = dORM.Raw(query, 1).ValuesList(&lists) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + if num == 1 { + throwFail(t, AssertIs(lists[0][0], "slene")) + } + + query = fmt.Sprintf("SELECT %sprofile_id%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q) + var list ParamsList + num, err = dORM.Raw(query).ValuesFlat(&list) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0], "2")) + throwFail(t, AssertIs(list[1], "3")) + throwFail(t, AssertIs(list[2], nil)) + } +} + +func TestRawPrepare(t *testing.T) { + switch { + case IsMysql || IsSqlite: + + pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() + throwFail(t, err) + if pre != nil { + r, err := pre.Exec("name1") + throwFail(t, err) + + tid, err := r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(tid > 0, true)) + + r, err = pre.Exec("name2") + throwFail(t, err) + + id, err := r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id, tid+1)) + + r, err = pre.Exec("name3") + throwFail(t, err) + + id, err = r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id, tid+2)) + + err = pre.Close() + throwFail(t, err) + + res, err := dORM.Raw("DELETE FROM tag WHERE name IN (?, ?, ?)", []string{"name1", "name2", "name3"}).Exec() + throwFail(t, err) + + num, err := res.RowsAffected() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + } + + case IsPostgres: + + pre, err := dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() + throwFail(t, err) + if pre != nil { + _, err := pre.Exec("name1") + throwFail(t, err) + + _, err = pre.Exec("name2") + throwFail(t, err) + + _, err = pre.Exec("name3") + throwFail(t, err) + + err = pre.Close() + throwFail(t, err) + + res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec() + throwFail(t, err) + + if err == nil { + num, err := res.RowsAffected() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + } + } + } +} + +func TestUpdate(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.Filter("user_name", "slene").Filter("is_staff", false).Update(Params{ + "is_staff": true, + "is_active": true, + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + // with join + num, err = qs.Filter("user_name", "slene").Filter("profile__age", 28).Filter("is_staff", true).Update(Params{ + "is_staff": false, + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColAdd, 100), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColMinus, 50), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColMultiply, 3), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColExcept, 5), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + user := User{UserName: "slene"} + err = dORM.Read(&user, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(user.Nums, 30)) +} + +func TestDelete(t *testing.T) { + qs := dORM.QueryTable("user_profile") + num, err := qs.Filter("user__user_name", "slene").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 6)) + + qs = dORM.QueryTable("post") + num, err = qs.Filter("Id", 3).Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 4)) + + qs = dORM.QueryTable("comment") + num, err = qs.Filter("Post__User", 3).Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestTransaction(t *testing.T) { + // this test worked when database support transaction + + o := NewOrm() + err := o.Begin() + throwFail(t, err) + + var names = []string{"1", "2", "3"} + + var tag Tag + tag.Name = names[0] + id, err := o.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + switch { + case IsMysql || IsSqlite: + res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() + throwFail(t, err) + if err == nil { + id, err = res.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + } + + err = o.Rollback() + throwFail(t, err) + + num, err = o.QueryTable("tag").Filter("name__in", names).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + err = o.Begin() + throwFail(t, err) + + tag.Name = "commit" + id, err = o.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + o.Commit() + throwFail(t, err) + + num, err = o.QueryTable("tag").Filter("name", "commit").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + +} + +func TestTransactionIsolationLevel(t *testing.T) { + // this test worked when database support transaction isolation level + if IsSqlite { + return + } + + o1 := NewOrm() + o2 := NewOrm() + + // start two transaction with isolation level repeatable read + err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + throwFail(t, err) + err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + throwFail(t, err) + + // o1 insert tag + var tag Tag + tag.Name = "test-transaction" + id, err := o1.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + // o2 query tag table, no result + num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + // o1 commit + o1.Commit() + + // o2 query tag table, still no result + num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + // o2 commit and query tag table, get the result + o2.Commit() + num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = o1.QueryTable("tag").Filter("name", "test-transaction").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestBeginTxWithContextCanceled(t *testing.T) { + o := NewOrm() + ctx, cancel := context.WithCancel(context.Background()) + o.BeginTx(ctx, nil) + id, err := o.Insert(&Tag{Name: "test-context"}) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + // cancel the context before commit to make it error + cancel() + err = o.Commit() + throwFail(t, AssertIs(err, context.Canceled)) +} + +func TestReadOrCreate(t *testing.T) { + u := &User{ + UserName: "Kyle", + Email: "kylemcc@gmail.com", + Password: "other_pass", + Status: 7, + IsStaff: false, + IsActive: true, + } + + created, pk, err := dORM.ReadOrCreate(u, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(created, true)) + throwFail(t, AssertIs(u.ID, pk)) + throwFail(t, AssertIs(u.UserName, "Kyle")) + throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com")) + throwFail(t, AssertIs(u.Password, "other_pass")) + throwFail(t, AssertIs(u.Status, 7)) + throwFail(t, AssertIs(u.IsStaff, false)) + throwFail(t, AssertIs(u.IsActive, true)) + throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime)) + + nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"} + created, pk, err = dORM.ReadOrCreate(nu, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(nu.ID, u.ID)) + throwFail(t, AssertIs(pk, u.ID)) + throwFail(t, AssertIs(nu.UserName, u.UserName)) + throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above + throwFail(t, AssertIs(nu.Password, u.Password)) + throwFail(t, AssertIs(nu.Status, u.Status)) + throwFail(t, AssertIs(nu.IsStaff, u.IsStaff)) + throwFail(t, AssertIs(nu.IsActive, u.IsActive)) + + dORM.Delete(u) +} + +func TestInLine(t *testing.T) { + name := "inline" + email := "hello@go.com" + inline := NewInLine() + inline.Name = name + inline.Email = email + + id, err := dORM.Insert(inline) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + il := NewInLine() + il.ID = 1 + err = dORM.Read(il) + throwFail(t, err) + + throwFail(t, AssertIs(il.Name, name)) + throwFail(t, AssertIs(il.Email, email)) + throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) +} + +func TestInLineOneToOne(t *testing.T) { + name := "121" + email := "121@go.com" + inline := NewInLine() + inline.Name = name + inline.Email = email + + id, err := dORM.Insert(inline) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + note := "one2one" + il121 := NewInLineOneToOne() + il121.Note = note + il121.InLine = inline + _, err = dORM.Insert(il121) + throwFail(t, err) + throwFail(t, AssertIs(il121.ID, 1)) + + il := NewInLineOneToOne() + err = dORM.QueryTable(il).Filter("Id", 1).RelatedSel().One(il) + + throwFail(t, err) + throwFail(t, AssertIs(il.Note, note)) + throwFail(t, AssertIs(il.InLine.ID, id)) + throwFail(t, AssertIs(il.InLine.Name, name)) + throwFail(t, AssertIs(il.InLine.Email, email)) + + rinline := NewInLine() + err = dORM.QueryTable(rinline).Filter("InLineOneToOne__Id", 1).One(rinline) + + throwFail(t, err) + throwFail(t, AssertIs(rinline.ID, id)) + throwFail(t, AssertIs(rinline.Name, name)) + throwFail(t, AssertIs(rinline.Email, email)) +} + +func TestIntegerPk(t *testing.T) { + its := []IntegerPk{ + {ID: math.MinInt64, Value: "-"}, + {ID: 0, Value: "0"}, + {ID: math.MaxInt64, Value: "+"}, + } + + num, err := dORM.InsertMulti(len(its), its) + throwFail(t, err) + throwFail(t, AssertIs(num, len(its))) + + for _, intPk := range its { + out := IntegerPk{ID: intPk.ID} + err = dORM.Read(&out) + throwFail(t, err) + throwFail(t, AssertIs(out.Value, intPk.Value)) + } + + num, err = dORM.InsertMulti(1, []*IntegerPk{{ + ID: 1, Value: "ok", + }}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestInsertAuto(t *testing.T) { + u := &User{ + UserName: "autoPre", + Email: "autoPre@gmail.com", + } + + id, err := dORM.Insert(u) + throwFail(t, err) + + id += 100 + su := &User{ + ID: int(id), + UserName: "auto", + Email: "auto@gmail.com", + } + + nid, err := dORM.Insert(su) + throwFail(t, err) + throwFail(t, AssertIs(nid, id)) + + users := []User{ + {ID: int(id + 100), UserName: "auto_100"}, + {ID: int(id + 110), UserName: "auto_110"}, + {ID: int(id + 120), UserName: "auto_120"}, + } + num, err := dORM.InsertMulti(100, users) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + u = &User{ + UserName: "auto_121", + } + + nid, err = dORM.Insert(u) + throwFail(t, err) + throwFail(t, AssertIs(nid, id+120+1)) +} + +func TestUintPk(t *testing.T) { + name := "go" + u := &UintPk{ + ID: 8, + Name: name, + } + + created, _, err := dORM.ReadOrCreate(u, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, true)) + throwFail(t, AssertIs(u.Name, name)) + + nu := &UintPk{ID: 8} + created, pk, err := dORM.ReadOrCreate(nu, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(nu.ID, u.ID)) + throwFail(t, AssertIs(pk, u.ID)) + throwFail(t, AssertIs(nu.Name, name)) + + dORM.Delete(u) +} + +func TestPtrPk(t *testing.T) { + parent := &IntegerPk{ID: 10, Value: "10"} + + id, _ := dORM.Insert(parent) + if !IsMysql { + // MySql does not support last_insert_id in this case: see #2382 + throwFail(t, AssertIs(id, 10)) + } + + ptr := PtrPk{ID: parent, Positive: true} + num, err := dORM.InsertMulti(2, []PtrPk{ptr}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(ptr.ID, parent)) + + nptr := &PtrPk{ID: parent} + created, pk, err := dORM.ReadOrCreate(nptr, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(pk, 10)) + throwFail(t, AssertIs(nptr.ID, parent)) + throwFail(t, AssertIs(nptr.Positive, true)) + + nptr = &PtrPk{Positive: true} + created, pk, err = dORM.ReadOrCreate(nptr, "Positive") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(pk, 10)) + throwFail(t, AssertIs(nptr.ID, parent)) + + nptr.Positive = false + num, err = dORM.Update(nptr) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(nptr.ID, parent)) + throwFail(t, AssertIs(nptr.Positive, false)) + + num, err = dORM.Delete(nptr) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestSnake(t *testing.T) { + cases := map[string]string{ + "i": "i", + "I": "i", + "iD": "i_d", + "ID": "i_d", + "NO": "n_o", + "NOO": "n_o_o", + "NOOooOOoo": "n_o_ooo_o_ooo", + "OrderNO": "order_n_o", + "tagName": "tag_name", + "tag_Name": "tag__name", + "tag_name": "tag_name", + "_tag_name": "_tag_name", + "tag_666name": "tag_666name", + "tag_666Name": "tag_666_name", + } + for name, want := range cases { + got := snakeString(name) + throwFail(t, AssertIs(got, want)) + } +} + +func TestIgnoreCaseTag(t *testing.T) { + type testTagModel struct { + ID int `orm:"pk"` + NOO string `orm:"column(n)"` + Name01 string `orm:"NULL"` + Name02 string `orm:"COLUMN(Name)"` + Name03 string `orm:"Column(name)"` + } + modelCache.clean() + RegisterModel(&testTagModel{}) + info, ok := modelCache.get("test_tag_model") + throwFail(t, AssertIs(ok, true)) + throwFail(t, AssertNot(info, nil)) + if t == nil { + return + } + throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n")) + throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true)) + throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) + throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) +} + +func TestInsertOrUpdate(t *testing.T) { + RegisterModel(new(User)) + user := User{UserName: "unique_username133", Status: 1, Password: "o"} + user1 := User{UserName: "unique_username133", Status: 2, Password: "o"} + user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"} + dORM.Insert(&user) + test := User{UserName: "unique_username133"} + fmt.Println(dORM.Driver().Name()) + if dORM.Driver().Name() == "sqlite3" { + fmt.Println("sqlite3 is nonsupport") + return + } + //test1 + _, err := dORM.InsertOrUpdate(&user1, "user_name") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user1.Status, test.Status)) + } + //test2 + _, err = dORM.InsertOrUpdate(&user2, "user_name") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user2.Status, test.Status)) + throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) + } + + //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + if IsPostgres { + return + } + //test3 + + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user2.Status+1, test.Status)) + } + //test4 - + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status)) + } + //test5 * + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status)) + } + //test6 / + _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) + } +} diff --git a/pkg/orm/qb.go b/pkg/orm/qb.go new file mode 100644 index 00000000..e0655a17 --- /dev/null +++ b/pkg/orm/qb.go @@ -0,0 +1,62 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import "errors" + +// QueryBuilder is the Query builder interface +type QueryBuilder interface { + Select(fields ...string) QueryBuilder + ForUpdate() QueryBuilder + From(tables ...string) QueryBuilder + InnerJoin(table string) QueryBuilder + LeftJoin(table string) QueryBuilder + RightJoin(table string) QueryBuilder + On(cond string) QueryBuilder + Where(cond string) QueryBuilder + And(cond string) QueryBuilder + Or(cond string) QueryBuilder + In(vals ...string) QueryBuilder + OrderBy(fields ...string) QueryBuilder + Asc() QueryBuilder + Desc() QueryBuilder + Limit(limit int) QueryBuilder + Offset(offset int) QueryBuilder + GroupBy(fields ...string) QueryBuilder + Having(cond string) QueryBuilder + Update(tables ...string) QueryBuilder + Set(kv ...string) QueryBuilder + Delete(tables ...string) QueryBuilder + InsertInto(table string, fields ...string) QueryBuilder + Values(vals ...string) QueryBuilder + Subquery(sub string, alias string) string + String() string +} + +// NewQueryBuilder return the QueryBuilder +func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { + if driver == "mysql" { + qb = new(MySQLQueryBuilder) + } else if driver == "tidb" { + qb = new(TiDBQueryBuilder) + } else if driver == "postgres" { + err = errors.New("postgres query builder is not supported yet") + } else if driver == "sqlite" { + err = errors.New("sqlite query builder is not supported yet") + } else { + err = errors.New("unknown driver for query builder") + } + return +} diff --git a/pkg/orm/qb_mysql.go b/pkg/orm/qb_mysql.go new file mode 100644 index 00000000..23bdc9ee --- /dev/null +++ b/pkg/orm/qb_mysql.go @@ -0,0 +1,185 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strconv" + "strings" +) + +// CommaSpace is the separation +const CommaSpace = ", " + +// MySQLQueryBuilder is the SQL build +type MySQLQueryBuilder struct { + Tokens []string +} + +// Select will join the fields +func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) + return qb +} + +// ForUpdate add the FOR UPDATE clause +func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder { + qb.Tokens = append(qb.Tokens, "FOR UPDATE") + return qb +} + +// From join the tables +func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) + return qb +} + +// InnerJoin INNER JOIN the table +func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INNER JOIN", table) + return qb +} + +// LeftJoin LEFT JOIN the table +func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) + return qb +} + +// RightJoin RIGHT JOIN the table +func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) + return qb +} + +// On join with on cond +func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ON", cond) + return qb +} + +// Where join the Where cond +func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "WHERE", cond) + return qb +} + +// And join the and cond +func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "AND", cond) + return qb +} + +// Or join the or cond +func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OR", cond) + return qb +} + +// In join the IN (vals) +func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") + return qb +} + +// OrderBy join the Order by fields +func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) + return qb +} + +// Asc join the asc +func (qb *MySQLQueryBuilder) Asc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "ASC") + return qb +} + +// Desc join the desc +func (qb *MySQLQueryBuilder) Desc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "DESC") + return qb +} + +// Limit join the limit num +func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) + return qb +} + +// Offset join the offset num +func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) + return qb +} + +// GroupBy join the Group by fields +func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) + return qb +} + +// Having join the Having cond +func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "HAVING", cond) + return qb +} + +// Update join the update table +func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) + return qb +} + +// Set join the set kv +func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) + return qb +} + +// Delete join the Delete tables +func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "DELETE") + if len(tables) != 0 { + qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) + } + return qb +} + +// InsertInto join the insert SQL +func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INSERT INTO", table) + if len(fields) != 0 { + fieldsStr := strings.Join(fields, CommaSpace) + qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") + } + return qb +} + +// Values join the Values(vals) +func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { + valsStr := strings.Join(vals, CommaSpace) + qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") + return qb +} + +// Subquery join the sub as alias +func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { + return fmt.Sprintf("(%s) AS %s", sub, alias) +} + +// String join all Tokens +func (qb *MySQLQueryBuilder) String() string { + return strings.Join(qb.Tokens, " ") +} diff --git a/pkg/orm/qb_tidb.go b/pkg/orm/qb_tidb.go new file mode 100644 index 00000000..87b3ae84 --- /dev/null +++ b/pkg/orm/qb_tidb.go @@ -0,0 +1,182 @@ +// Copyright 2015 TiDB Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strconv" + "strings" +) + +// TiDBQueryBuilder is the SQL build +type TiDBQueryBuilder struct { + Tokens []string +} + +// Select will join the fields +func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) + return qb +} + +// ForUpdate add the FOR UPDATE clause +func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder { + qb.Tokens = append(qb.Tokens, "FOR UPDATE") + return qb +} + +// From join the tables +func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) + return qb +} + +// InnerJoin INNER JOIN the table +func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INNER JOIN", table) + return qb +} + +// LeftJoin LEFT JOIN the table +func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) + return qb +} + +// RightJoin RIGHT JOIN the table +func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) + return qb +} + +// On join with on cond +func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ON", cond) + return qb +} + +// Where join the Where cond +func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "WHERE", cond) + return qb +} + +// And join the and cond +func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "AND", cond) + return qb +} + +// Or join the or cond +func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OR", cond) + return qb +} + +// In join the IN (vals) +func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") + return qb +} + +// OrderBy join the Order by fields +func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) + return qb +} + +// Asc join the asc +func (qb *TiDBQueryBuilder) Asc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "ASC") + return qb +} + +// Desc join the desc +func (qb *TiDBQueryBuilder) Desc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "DESC") + return qb +} + +// Limit join the limit num +func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) + return qb +} + +// Offset join the offset num +func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) + return qb +} + +// GroupBy join the Group by fields +func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) + return qb +} + +// Having join the Having cond +func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "HAVING", cond) + return qb +} + +// Update join the update table +func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) + return qb +} + +// Set join the set kv +func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) + return qb +} + +// Delete join the Delete tables +func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "DELETE") + if len(tables) != 0 { + qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) + } + return qb +} + +// InsertInto join the insert SQL +func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INSERT INTO", table) + if len(fields) != 0 { + fieldsStr := strings.Join(fields, CommaSpace) + qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") + } + return qb +} + +// Values join the Values(vals) +func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder { + valsStr := strings.Join(vals, CommaSpace) + qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") + return qb +} + +// Subquery join the sub as alias +func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string { + return fmt.Sprintf("(%s) AS %s", sub, alias) +} + +// String join all Tokens +func (qb *TiDBQueryBuilder) String() string { + return strings.Join(qb.Tokens, " ") +} diff --git a/pkg/orm/types.go b/pkg/orm/types.go new file mode 100644 index 00000000..2fd10774 --- /dev/null +++ b/pkg/orm/types.go @@ -0,0 +1,473 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "reflect" + "time" +) + +// Driver define database driver +type Driver interface { + Name() string + Type() DriverType +} + +// Fielder define field info +type Fielder interface { + String() string + FieldType() int + SetRaw(interface{}) error + RawValue() interface{} +} + +// Ormer define the orm interface +type Ormer interface { + // read data to model + // for example: + // this will find User by Id field + // u = &User{Id: user.Id} + // err = Ormer.Read(u) + // this will find User by UserName field + // u = &User{UserName: "astaxie", Password: "pass"} + // err = Ormer.Read(u, "UserName") + Read(md interface{}, cols ...string) error + // Like Read(), but with "FOR UPDATE" clause, useful in transaction. + // Some databases are not support this feature. + ReadForUpdate(md interface{}, cols ...string) error + // Try to read a row from the database, or insert one if it doesn't exist + ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) + // insert model data to database + // for example: + // user := new(User) + // id, err = Ormer.Insert(user) + // user must be a pointer and Insert will set user's pk field + Insert(interface{}) (int64, error) + // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") + // if colu type is integer : can use(+-*/), string : convert(colu,"value") + // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") + // if colu type is integer : can use(+-*/), string : colu || "value" + InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) + // insert some models to database + InsertMulti(bulk int, mds interface{}) (int64, error) + // update model to database. + // cols set the columns those want to update. + // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns + // for example: + // user := User{Id: 2} + // user.Langs = append(user.Langs, "zh-CN", "en-US") + // user.Extra.Name = "beego" + // user.Extra.Data = "orm" + // num, err = Ormer.Update(&user, "Langs", "Extra") + Update(md interface{}, cols ...string) (int64, error) + // delete model in database + Delete(md interface{}, cols ...string) (int64, error) + // load related models to md model. + // args are limit, offset int and order string. + // + // example: + // Ormer.LoadRelated(post,"Tags") + // for _,tag := range post.Tags{...} + //args[0] bool true useDefaultRelsDepth ; false depth 0 + //args[0] int loadRelationDepth + //args[1] int limit default limit 1000 + //args[2] int offset default offset 0 + //args[3] string order for example : "-Id" + // make sure the relation is defined in model struct tags. + LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) + // create a models to models queryer + // for example: + // post := Post{Id: 4} + // m2m := Ormer.QueryM2M(&post, "Tags") + QueryM2M(md interface{}, name string) QueryM2Mer + // return a QuerySeter for table operations. + // table name can be string or struct. + // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), + QueryTable(ptrStructOrTableName interface{}) QuerySeter + // switch to another registered database driver by given name. + Using(name string) error + // begin transaction + // for example: + // o := NewOrm() + // err := o.Begin() + // ... + // err = o.Rollback() + Begin() error + // begin transaction with provided context and option + // the provided context is used until the transaction is committed or rolled back. + // if the context is canceled, the transaction will be rolled back. + // the provided TxOptions is optional and may be nil if defaults should be used. + // if a non-default isolation level is used that the driver doesn't support, an error will be returned. + // for example: + // o := NewOrm() + // err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + // ... + // err = o.Rollback() + BeginTx(ctx context.Context, opts *sql.TxOptions) error + // commit transaction + Commit() error + // rollback transaction + Rollback() error + // return a raw query seter for raw sql string. + // for example: + // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() + // // update user testing's name to slene + Raw(query string, args ...interface{}) RawSeter + Driver() Driver + DBStats() *sql.DBStats +} + +// Inserter insert prepared statement +type Inserter interface { + Insert(interface{}) (int64, error) + Close() error +} + +// QuerySeter query seter +type QuerySeter interface { + // add condition expression to QuerySeter. + // for example: + // filter by UserName == 'slene' + // qs.Filter("UserName", "slene") + // sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28 + // Filter("profile__Age", 28) + // // time compare + // qs.Filter("created", time.Now()) + Filter(string, ...interface{}) QuerySeter + // add raw sql to querySeter. + // for example: + // qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)") + // //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18) + FilterRaw(string, string) QuerySeter + // add NOT condition to querySeter. + // have the same usage as Filter + Exclude(string, ...interface{}) QuerySeter + // set condition to QuerySeter. + // sql's where condition + // cond := orm.NewCondition() + // cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) + // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 + // num, err := qs.SetCond(cond1).Count() + SetCond(*Condition) QuerySeter + // get condition from QuerySeter. + // sql's where condition + // cond := orm.NewCondition() + // cond = cond.And("profile__isnull", false).AndNot("status__in", 1) + // qs = qs.SetCond(cond) + // cond = qs.GetCond() + // cond := cond.Or("profile__age__gt", 2000) + // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 + // num, err := qs.SetCond(cond).Count() + GetCond() *Condition + // add LIMIT value. + // args[0] means offset, e.g. LIMIT num,offset. + // if Limit <= 0 then Limit will be set to default limit ,eg 1000 + // if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000 + // for example: + // qs.Limit(10, 2) + // // sql-> limit 10 offset 2 + Limit(limit interface{}, args ...interface{}) QuerySeter + // add OFFSET value + // same as Limit function's args[0] + Offset(offset interface{}) QuerySeter + // add GROUP BY expression + // for example: + // qs.GroupBy("id") + GroupBy(exprs ...string) QuerySeter + // add ORDER expression. + // "column" means ASC, "-column" means DESC. + // for example: + // qs.OrderBy("-status") + OrderBy(exprs ...string) QuerySeter + // set relation model to query together. + // it will query relation models and assign to parent model. + // for example: + // // will load all related fields use left join . + // qs.RelatedSel().One(&user) + // // will load related field only profile + // qs.RelatedSel("profile").One(&user) + // user.Profile.Age = 32 + RelatedSel(params ...interface{}) QuerySeter + // Set Distinct + // for example: + // o.QueryTable("policy").Filter("Groups__Group__Users__User", user). + // Distinct(). + // All(&permissions) + Distinct() QuerySeter + // set FOR UPDATE to query. + // for example: + // o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users) + ForUpdate() QuerySeter + // return QuerySeter execution result number + // for example: + // num, err = qs.Filter("profile__age__gt", 28).Count() + Count() (int64, error) + // check result empty or not after QuerySeter executed + // the same as QuerySeter.Count > 0 + Exist() bool + // execute update with parameters + // for example: + // num, err = qs.Filter("user_name", "slene").Update(Params{ + // "Nums": ColValue(Col_Minus, 50), + // }) // user slene's Nums will minus 50 + // num, err = qs.Filter("UserName", "slene").Update(Params{ + // "user_name": "slene2" + // }) // user slene's name will change to slene2 + Update(values Params) (int64, error) + // delete from table + //for example: + // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() + // //delete two user who's name is testing1 or testing2 + Delete() (int64, error) + // return a insert queryer. + // it can be used in times. + // example: + // i,err := sq.PrepareInsert() + // num, err = i.Insert(&user1) // user table will add one record user1 at once + // num, err = i.Insert(&user2) // user table will add one record user2 at once + // err = i.Close() //don't forget call Close + PrepareInsert() (Inserter, error) + // query all data and map to containers. + // cols means the columns when querying. + // for example: + // var users []*User + // qs.All(&users) // users[0],users[1],users[2] ... + All(container interface{}, cols ...string) (int64, error) + // query one row data and map to containers. + // cols means the columns when querying. + // for example: + // var user User + // qs.One(&user) //user.UserName == "slene" + One(container interface{}, cols ...string) error + // query all data and map to []map[string]interface. + // expres means condition expression. + // it converts data to []map[column]value. + // for example: + // var maps []Params + // qs.Values(&maps) //maps[0]["UserName"]=="slene" + Values(results *[]Params, exprs ...string) (int64, error) + // query all data and map to [][]interface + // it converts data to [][column_index]value + // for example: + // var list []ParamsList + // qs.ValuesList(&list) // list[0][1] == "slene" + ValuesList(results *[]ParamsList, exprs ...string) (int64, error) + // query all data and map to []interface. + // it's designed for one column record set, auto change to []value, not [][column]value. + // for example: + // var list ParamsList + // qs.ValuesFlat(&list, "UserName") // list[0] == "slene" + ValuesFlat(result *ParamsList, expr string) (int64, error) + // query all rows into map[string]interface with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to map[string]interface{}{ + // "total": 100, + // "found": 200, + // } + RowsToMap(result *Params, keyCol, valueCol string) (int64, error) + // query all rows into struct with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to struct { + // Total int + // Found int + // } + RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) +} + +// QueryM2Mer model to model query struct +// all operations are on the m2m table only, will not affect the origin model table +type QueryM2Mer interface { + // add models to origin models when creating queryM2M. + // example: + // m2m := orm.QueryM2M(post,"Tag") + // m2m.Add(&Tag1{},&Tag2{}) + // for _,tag := range post.Tags{}{ ... } + // param could also be any of the follow + // []*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}} + // &Tag{Id:5,Name: "TestTag3"} + // []interface{}{&Tag{Id:6,Name: "TestTag4"}} + // insert one or more rows to m2m table + // make sure the relation is defined in post model struct tag. + Add(...interface{}) (int64, error) + // remove models following the origin model relationship + // only delete rows from m2m table + // for example: + //tag3 := &Tag{Id:5,Name: "TestTag3"} + //num, err = m2m.Remove(tag3) + Remove(...interface{}) (int64, error) + // check model is existed in relationship of origin model + Exist(interface{}) bool + // clean all models in related of origin model + Clear() (int64, error) + // count all related models of origin model + Count() (int64, error) +} + +// RawPreparer raw query statement +type RawPreparer interface { + Exec(...interface{}) (sql.Result, error) + Close() error +} + +// RawSeter raw query seter +// create From Ormer.Raw +// for example: +// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) +// rs := Ormer.Raw(sql, 1) +type RawSeter interface { + //execute sql and get result + Exec() (sql.Result, error) + //query data and map to container + //for example: + // var name string + // var id int + // rs.QueryRow(&id,&name) // id==2 name=="slene" + QueryRow(containers ...interface{}) error + + // query data rows and map to container + // var ids []int + // var names []int + // query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q) + // num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"} + QueryRows(containers ...interface{}) (int64, error) + SetArgs(...interface{}) RawSeter + // query data to []map[string]interface + // see QuerySeter's Values + Values(container *[]Params, cols ...string) (int64, error) + // query data to [][]interface + // see QuerySeter's ValuesList + ValuesList(container *[]ParamsList, cols ...string) (int64, error) + // query data to []interface + // see QuerySeter's ValuesFlat + ValuesFlat(container *ParamsList, cols ...string) (int64, error) + // query all rows into map[string]interface with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to map[string]interface{}{ + // "total": 100, + // "found": 200, + // } + RowsToMap(result *Params, keyCol, valueCol string) (int64, error) + // query all rows into struct with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to struct { + // Total int + // Found int + // } + RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) + + // return prepared raw statement for used in times. + // for example: + // pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() + // r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`) + Prepare() (RawPreparer, error) +} + +// stmtQuerier statement querier +type stmtQuerier interface { + Close() error + Exec(args ...interface{}) (sql.Result, error) + //ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) + Query(args ...interface{}) (*sql.Rows, error) + //QueryContext(args ...interface{}) (*sql.Rows, error) + QueryRow(args ...interface{}) *sql.Row + //QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row +} + +// db querier +type dbQuerier interface { + Prepare(query string) (*sql.Stmt, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + Exec(query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +// type DB interface { +// Begin() (*sql.Tx, error) +// Prepare(query string) (stmtQuerier, error) +// Exec(query string, args ...interface{}) (sql.Result, error) +// Query(query string, args ...interface{}) (*sql.Rows, error) +// QueryRow(query string, args ...interface{}) *sql.Row +// } + +// transaction beginner +type txer interface { + Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +// transaction ending +type txEnder interface { + Commit() error + Rollback() error +} + +// base database struct +type dbBaser interface { + Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error + Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) + InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) + InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) + InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) + SupportUpdateJoin() bool + UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) + DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + OperatorSQL(string) string + GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) + GenerateOperatorLeftCol(*fieldInfo, string, *string) + PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) + ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) + RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) + MaxLimit() uint64 + TableQuote() string + ReplaceMarks(*string) + HasReturningID(*modelInfo, *string) bool + TimeFromDB(*time.Time, *time.Location) + TimeToDB(*time.Time, *time.Location) + DbTypes() map[string]string + GetTables(dbQuerier) (map[string]bool, error) + GetColumns(dbQuerier, string) (map[string][3]string, error) + ShowTablesQuery() string + ShowColumnsQuery(string) string + IndexExists(dbQuerier, string, string) bool + collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) + setval(dbQuerier, *modelInfo, []string) error +} diff --git a/pkg/orm/utils.go b/pkg/orm/utils.go new file mode 100644 index 00000000..3ff76772 --- /dev/null +++ b/pkg/orm/utils.go @@ -0,0 +1,319 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "math/big" + "reflect" + "strconv" + "strings" + "time" +) + +type fn func(string) string + +var ( + nameStrategyMap = map[string]fn{ + defaultNameStrategy: snakeString, + SnakeAcronymNameStrategy: snakeStringWithAcronym, + } + defaultNameStrategy = "snakeString" + SnakeAcronymNameStrategy = "snakeStringWithAcronym" + nameStrategy = defaultNameStrategy +) + +// StrTo is the target string +type StrTo string + +// Set string +func (f *StrTo) Set(v string) { + if v != "" { + *f = StrTo(v) + } else { + f.Clear() + } +} + +// Clear string +func (f *StrTo) Clear() { + *f = StrTo(0x1E) +} + +// Exist check string exist +func (f StrTo) Exist() bool { + return string(f) != string(0x1E) +} + +// Bool string to bool +func (f StrTo) Bool() (bool, error) { + return strconv.ParseBool(f.String()) +} + +// Float32 string to float32 +func (f StrTo) Float32() (float32, error) { + v, err := strconv.ParseFloat(f.String(), 32) + return float32(v), err +} + +// Float64 string to float64 +func (f StrTo) Float64() (float64, error) { + return strconv.ParseFloat(f.String(), 64) +} + +// Int string to int +func (f StrTo) Int() (int, error) { + v, err := strconv.ParseInt(f.String(), 10, 32) + return int(v), err +} + +// Int8 string to int8 +func (f StrTo) Int8() (int8, error) { + v, err := strconv.ParseInt(f.String(), 10, 8) + return int8(v), err +} + +// Int16 string to int16 +func (f StrTo) Int16() (int16, error) { + v, err := strconv.ParseInt(f.String(), 10, 16) + return int16(v), err +} + +// Int32 string to int32 +func (f StrTo) Int32() (int32, error) { + v, err := strconv.ParseInt(f.String(), 10, 32) + return int32(v), err +} + +// Int64 string to int64 +func (f StrTo) Int64() (int64, error) { + v, err := strconv.ParseInt(f.String(), 10, 64) + if err != nil { + i := new(big.Int) + ni, ok := i.SetString(f.String(), 10) // octal + if !ok { + return v, err + } + return ni.Int64(), nil + } + return v, err +} + +// Uint string to uint +func (f StrTo) Uint() (uint, error) { + v, err := strconv.ParseUint(f.String(), 10, 32) + return uint(v), err +} + +// Uint8 string to uint8 +func (f StrTo) Uint8() (uint8, error) { + v, err := strconv.ParseUint(f.String(), 10, 8) + return uint8(v), err +} + +// Uint16 string to uint16 +func (f StrTo) Uint16() (uint16, error) { + v, err := strconv.ParseUint(f.String(), 10, 16) + return uint16(v), err +} + +// Uint32 string to uint32 +func (f StrTo) Uint32() (uint32, error) { + v, err := strconv.ParseUint(f.String(), 10, 32) + return uint32(v), err +} + +// Uint64 string to uint64 +func (f StrTo) Uint64() (uint64, error) { + v, err := strconv.ParseUint(f.String(), 10, 64) + if err != nil { + i := new(big.Int) + ni, ok := i.SetString(f.String(), 10) + if !ok { + return v, err + } + return ni.Uint64(), nil + } + return v, err +} + +// String string to string +func (f StrTo) String() string { + if f.Exist() { + return string(f) + } + return "" +} + +// ToStr interface to string +func ToStr(value interface{}, args ...int) (s string) { + switch v := value.(type) { + case bool: + s = strconv.FormatBool(v) + case float32: + s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) + case float64: + s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) + case int: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int8: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int16: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int32: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int64: + s = strconv.FormatInt(v, argInt(args).Get(0, 10)) + case uint: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint8: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint16: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint32: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint64: + s = strconv.FormatUint(v, argInt(args).Get(0, 10)) + case string: + s = v + case []byte: + s = string(v) + default: + s = fmt.Sprintf("%v", v) + } + return s +} + +// ToInt64 interface to int64 +func ToInt64(value interface{}) (d int64) { + val := reflect.ValueOf(value) + switch value.(type) { + case int, int8, int16, int32, int64: + d = val.Int() + case uint, uint8, uint16, uint32, uint64: + d = int64(val.Uint()) + default: + panic(fmt.Errorf("ToInt64 need numeric not `%T`", value)) + } + return +} + +func snakeStringWithAcronym(s string) string { + data := make([]byte, 0, len(s)*2) + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + before := false + after := false + if i > 0 { + before = s[i-1] >= 'a' && s[i-1] <= 'z' + } + if i+1 < num { + after = s[i+1] >= 'a' && s[i+1] <= 'z' + } + if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { + data = append(data, '_') + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + +// snake string, XxYy to xx_yy , XxYY to xx_y_y +func snakeString(s string) string { + data := make([]byte, 0, len(s)*2) + j := false + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + if i > 0 && d >= 'A' && d <= 'Z' && j { + data = append(data, '_') + } + if d != '_' { + j = true + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + +// SetNameStrategy set different name strategy +func SetNameStrategy(s string) { + if SnakeAcronymNameStrategy != s { + nameStrategy = defaultNameStrategy + } + nameStrategy = s +} + +// camel string, xx_yy to XxYy +func camelString(s string) string { + data := make([]byte, 0, len(s)) + flag, num := true, len(s)-1 + for i := 0; i <= num; i++ { + d := s[i] + if d == '_' { + flag = true + continue + } else if flag { + if d >= 'a' && d <= 'z' { + d = d - 32 + } + flag = false + } + data = append(data, d) + } + return string(data[:]) +} + +type argString []string + +// get string by index from string slice +func (a argString) Get(i int, args ...string) (r string) { + if i >= 0 && i < len(a) { + r = a[i] + } else if len(args) > 0 { + r = args[0] + } + return +} + +type argInt []int + +// get int by index from int slice +func (a argInt) Get(i int, args ...int) (r int) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +// parse time to string with location +func timeParse(dateString, format string) (time.Time, error) { + tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) + return tp, err +} + +// get pointer indirect type +func indirectType(v reflect.Type) reflect.Type { + switch v.Kind() { + case reflect.Ptr: + return indirectType(v.Elem()) + default: + return v + } +} diff --git a/pkg/orm/utils_test.go b/pkg/orm/utils_test.go new file mode 100644 index 00000000..7d94cada --- /dev/null +++ b/pkg/orm/utils_test.go @@ -0,0 +1,70 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" +) + +func TestCamelString(t *testing.T) { + snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} + + answer := make(map[string]string) + for i, v := range snake { + answer[v] = camel[i] + } + + for _, v := range snake { + res := camelString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeString(t *testing.T) { + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeStringWithAcronym(t *testing.T) { + camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeStringWithAcronym(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} From b50fb44950a2687aa55652d0638d74a0e2967a6e Mon Sep 17 00:00:00 2001 From: playHing Date: Sat, 11 Jul 2020 02:00:48 +0800 Subject: [PATCH 025/207] Add bench test on context input query --- context/input_test.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/context/input_test.go b/context/input_test.go index db812a0f..3a6c2e7b 100644 --- a/context/input_test.go +++ b/context/input_test.go @@ -205,3 +205,13 @@ func TestParams(t *testing.T) { } } +func BenchmarkQuery(b *testing.B) { + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Request, _ = http.NewRequest("POST", "http://www.example.com/?q=foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + beegoInput.Query("q") + } + }) +} From 55e6298f290345f900bff9443cb8fb9a09e78965 Mon Sep 17 00:00:00 2001 From: playHing Date: Sat, 11 Jul 2020 02:06:09 +0800 Subject: [PATCH 026/207] Fix concurrent form parsing and getting --- context/input.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/context/input.go b/context/input.go index 7b522c36..46a18f03 100644 --- a/context/input.go +++ b/context/input.go @@ -333,7 +333,11 @@ func (input *BeegoInput) Query(key string) string { return val } if input.Context.Request.Form == nil { - input.Context.Request.ParseForm() + input.dataLock.Lock() + defer input.dataLock.Unlock() + if input.Context.Request.Form == nil { + input.Context.Request.ParseForm() + } } return input.Context.Request.Form.Get(key) } From 3e2c795410da9e912f9d4b46bf862a212bec27ca Mon Sep 17 00:00:00 2001 From: playHing Date: Mon, 13 Jul 2020 23:11:23 +0800 Subject: [PATCH 027/207] Rlock for form query --- context/input.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/context/input.go b/context/input.go index 46a18f03..385549c1 100644 --- a/context/input.go +++ b/context/input.go @@ -334,11 +334,13 @@ func (input *BeegoInput) Query(key string) string { } if input.Context.Request.Form == nil { input.dataLock.Lock() - defer input.dataLock.Unlock() if input.Context.Request.Form == nil { input.Context.Request.ParseForm() } + input.dataLock.Unlock() } + input.dataLock.RLock() + defer input.dataLock.RUnlock() return input.Context.Request.Form.Get(key) } From 192a278a2a7cd2d891d3ee25da1559702b1ec180 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 19 Jul 2020 12:56:58 +0000 Subject: [PATCH 028/207] Fix orm test when using Driver = mysql --- .travis.yml | 4 +- cache/redis/redis_test.go | 6 +- go.mod | 2 +- orm/cmd_utils.go | 10 +- orm/models_test.go | 497 -------- orm/orm_test.go | 2494 ------------------------------------- orm/utils_test.go | 70 -- pkg/orm/models_test.go | 20 +- pkg/orm/orm.go | 2 +- pkg/orm/orm_test.go | 32 +- 10 files changed, 43 insertions(+), 3094 deletions(-) delete mode 100644 orm/models_test.go delete mode 100644 orm/orm_test.go delete mode 100644 orm/utils_test.go diff --git a/.travis.yml b/.travis.yml index c019c999..26c3732e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ language: go go: - - "1.13.x" + - "1.14.x" services: - redis-server - mysql @@ -63,7 +63,7 @@ after_script: - killall -w ssdb-server - rm -rf ./res/var/* script: - - go test -v ./... + - go test ./... - staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" - unconvert $(go list ./... | grep -v /vendor/) - ineffassign . diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go index 7ac88f87..60a19180 100644 --- a/cache/redis/redis_test.go +++ b/cache/redis/redis_test.go @@ -21,6 +21,7 @@ import ( "github.com/astaxie/beego/cache" "github.com/gomodule/redigo/redis" + "github.com/stretchr/testify/assert" ) func TestRedisCache(t *testing.T) { @@ -124,9 +125,8 @@ func TestCache_Scan(t *testing.T) { if err != nil { t.Error("scan Error", err) } - if len(keys) != 10000 { - t.Error("scan all err") - } + + assert.Equal(t, 10000, len(keys), "scan all error") // clear all if err = bm.ClearAll(); err != nil { diff --git a/go.mod b/go.mod index ec500f51..adca28ad 100644 --- a/go.mod +++ b/go.mod @@ -37,4 +37,4 @@ replace golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 => github.com/gol replace gopkg.in/yaml.v2 v2.2.1 => github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d -go 1.13 +go 1.14 diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index 61f17346..eac85091 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -178,9 +178,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex column += " " + "NOT NULL" } - //if fi.initial.String() != "" { + // if fi.initial.String() != "" { // column += " DEFAULT " + fi.initial.String() - //} + // } // Append attribute DEFAULT column += getColumnDefault(fi) @@ -197,9 +197,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex if strings.Contains(column, "%COL%") { column = strings.Replace(column, "%COL%", fi.column, -1) } - - if fi.description != "" && al.Driver!=DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) + + if fi.description != "" && al.Driver != DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) } columns = append(columns, column) diff --git a/orm/models_test.go b/orm/models_test.go deleted file mode 100644 index e3a635f2..00000000 --- a/orm/models_test.go +++ /dev/null @@ -1,497 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "database/sql" - "encoding/json" - "fmt" - "os" - "strings" - "time" - - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - // As tidb can't use go get, so disable the tidb testing now - // _ "github.com/pingcap/tidb" -) - -// A slice string field. -type SliceStringField []string - -func (e SliceStringField) Value() []string { - return []string(e) -} - -func (e *SliceStringField) Set(d []string) { - *e = SliceStringField(d) -} - -func (e *SliceStringField) Add(v string) { - *e = append(*e, v) -} - -func (e *SliceStringField) String() string { - return strings.Join(e.Value(), ",") -} - -func (e *SliceStringField) FieldType() int { - return TypeVarCharField -} - -func (e *SliceStringField) SetRaw(value interface{}) error { - switch d := value.(type) { - case []string: - e.Set(d) - case string: - if len(d) > 0 { - parts := strings.Split(d, ",") - v := make([]string, 0, len(parts)) - for _, p := range parts { - v = append(v, strings.TrimSpace(p)) - } - e.Set(v) - } - default: - return fmt.Errorf(" unknown value `%v`", value) - } - return nil -} - -func (e *SliceStringField) RawValue() interface{} { - return e.String() -} - -var _ Fielder = new(SliceStringField) - -// A json field. -type JSONFieldTest struct { - Name string - Data string -} - -func (e *JSONFieldTest) String() string { - data, _ := json.Marshal(e) - return string(data) -} - -func (e *JSONFieldTest) FieldType() int { - return TypeTextField -} - -func (e *JSONFieldTest) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - return json.Unmarshal([]byte(d), e) - default: - return fmt.Errorf(" unknown value `%v`", value) - } -} - -func (e *JSONFieldTest) RawValue() interface{} { - return e.String() -} - -var _ Fielder = new(JSONFieldTest) - -type Data struct { - ID int `orm:"column(id)"` - Boolean bool - Char string `orm:"size(50)"` - Text string `orm:"type(text)"` - JSON string `orm:"type(json);default({\"name\":\"json\"})"` - Jsonb string `orm:"type(jsonb)"` - Time time.Time `orm:"type(time)"` - Date time.Time `orm:"type(date)"` - DateTime time.Time `orm:"column(datetime)"` - Byte byte - Rune rune - Int int - Int8 int8 - Int16 int16 - Int32 int32 - Int64 int64 - Uint uint - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - Float32 float32 - Float64 float64 - Decimal float64 `orm:"digits(8);decimals(4)"` -} - -type DataNull struct { - ID int `orm:"column(id)"` - Boolean bool `orm:"null"` - Char string `orm:"null;size(50)"` - Text string `orm:"null;type(text)"` - JSON string `orm:"type(json);null"` - Jsonb string `orm:"type(jsonb);null"` - Time time.Time `orm:"null;type(time)"` - Date time.Time `orm:"null;type(date)"` - DateTime time.Time `orm:"null;column(datetime)"` - Byte byte `orm:"null"` - Rune rune `orm:"null"` - Int int `orm:"null"` - Int8 int8 `orm:"null"` - Int16 int16 `orm:"null"` - Int32 int32 `orm:"null"` - Int64 int64 `orm:"null"` - Uint uint `orm:"null"` - Uint8 uint8 `orm:"null"` - Uint16 uint16 `orm:"null"` - Uint32 uint32 `orm:"null"` - Uint64 uint64 `orm:"null"` - Float32 float32 `orm:"null"` - Float64 float64 `orm:"null"` - Decimal float64 `orm:"digits(8);decimals(4);null"` - NullString sql.NullString `orm:"null"` - NullBool sql.NullBool `orm:"null"` - NullFloat64 sql.NullFloat64 `orm:"null"` - NullInt64 sql.NullInt64 `orm:"null"` - BooleanPtr *bool `orm:"null"` - CharPtr *string `orm:"null;size(50)"` - TextPtr *string `orm:"null;type(text)"` - BytePtr *byte `orm:"null"` - RunePtr *rune `orm:"null"` - IntPtr *int `orm:"null"` - Int8Ptr *int8 `orm:"null"` - Int16Ptr *int16 `orm:"null"` - Int32Ptr *int32 `orm:"null"` - Int64Ptr *int64 `orm:"null"` - UintPtr *uint `orm:"null"` - Uint8Ptr *uint8 `orm:"null"` - Uint16Ptr *uint16 `orm:"null"` - Uint32Ptr *uint32 `orm:"null"` - Uint64Ptr *uint64 `orm:"null"` - Float32Ptr *float32 `orm:"null"` - Float64Ptr *float64 `orm:"null"` - DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` - TimePtr *time.Time `orm:"null;type(time)"` - DatePtr *time.Time `orm:"null;type(date)"` - DateTimePtr *time.Time `orm:"null"` -} - -type String string -type Boolean bool -type Byte byte -type Rune rune -type Int int -type Int8 int8 -type Int16 int16 -type Int32 int32 -type Int64 int64 -type Uint uint -type Uint8 uint8 -type Uint16 uint16 -type Uint32 uint32 -type Uint64 uint64 -type Float32 float64 -type Float64 float64 - -type DataCustom struct { - ID int `orm:"column(id)"` - Boolean Boolean - Char string `orm:"size(50)"` - Text string `orm:"type(text)"` - Byte Byte - Rune Rune - Int Int - Int8 Int8 - Int16 Int16 - Int32 Int32 - Int64 Int64 - Uint Uint - Uint8 Uint8 - Uint16 Uint16 - Uint32 Uint32 - Uint64 Uint64 - Float32 Float32 - Float64 Float64 - Decimal Float64 `orm:"digits(8);decimals(4)"` -} - -// only for mysql -type UserBig struct { - ID uint64 `orm:"column(id)"` - Name string -} - -type User struct { - ID int `orm:"column(id)"` - UserName string `orm:"size(30);unique"` - Email string `orm:"size(100)"` - Password string `orm:"size(100)"` - Status int16 `orm:"column(Status)"` - IsStaff bool - IsActive bool `orm:"default(true)"` - Created time.Time `orm:"auto_now_add;type(date)"` - Updated time.Time `orm:"auto_now"` - Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` - Posts []*Post `orm:"reverse(many)" json:"-"` - ShouldSkip string `orm:"-"` - Nums int - Langs SliceStringField `orm:"size(100)"` - Extra JSONFieldTest `orm:"type(text)"` - unexport bool `orm:"-"` - unexportBool bool -} - -func (u *User) TableIndex() [][]string { - return [][]string{ - {"Id", "UserName"}, - {"Id", "Created"}, - } -} - -func (u *User) TableUnique() [][]string { - return [][]string{ - {"UserName", "Email"}, - } -} - -func NewUser() *User { - obj := new(User) - return obj -} - -type Profile struct { - ID int `orm:"column(id)"` - Age int16 - Money float64 - User *User `orm:"reverse(one)" json:"-"` - BestPost *Post `orm:"rel(one);null"` -} - -func (u *Profile) TableName() string { - return "user_profile" -} - -func NewProfile() *Profile { - obj := new(Profile) - return obj -} - -type Post struct { - ID int `orm:"column(id)"` - User *User `orm:"rel(fk)"` - Title string `orm:"size(60)"` - Content string `orm:"type(text)"` - Created time.Time `orm:"auto_now_add"` - Updated time.Time `orm:"auto_now"` - Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.PostTags)"` -} - -func (u *Post) TableIndex() [][]string { - return [][]string{ - {"Id", "Created"}, - } -} - -func NewPost() *Post { - obj := new(Post) - return obj -} - -type Tag struct { - ID int `orm:"column(id)"` - Name string `orm:"size(30)"` - BestPost *Post `orm:"rel(one);null"` - Posts []*Post `orm:"reverse(many)" json:"-"` -} - -func NewTag() *Tag { - obj := new(Tag) - return obj -} - -type PostTags struct { - ID int `orm:"column(id)"` - Post *Post `orm:"rel(fk)"` - Tag *Tag `orm:"rel(fk)"` -} - -func (m *PostTags) TableName() string { - return "prefix_post_tags" -} - -type Comment struct { - ID int `orm:"column(id)"` - Post *Post `orm:"rel(fk);column(post)"` - Content string `orm:"type(text)"` - Parent *Comment `orm:"null;rel(fk)"` - Created time.Time `orm:"auto_now_add"` -} - -func NewComment() *Comment { - obj := new(Comment) - return obj -} - -type Group struct { - ID int `orm:"column(gid);size(32)"` - Name string - Permissions []*Permission `orm:"reverse(many)" json:"-"` -} - -type Permission struct { - ID int `orm:"column(id)"` - Name string - Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"` -} - -type GroupPermissions struct { - ID int `orm:"column(id)"` - Group *Group `orm:"rel(fk)"` - Permission *Permission `orm:"rel(fk)"` -} - -type ModelID struct { - ID int64 -} - -type ModelBase struct { - ModelID - - Created time.Time `orm:"auto_now_add;type(datetime)"` - Updated time.Time `orm:"auto_now;type(datetime)"` -} - -type InLine struct { - // Common Fields - ModelBase - - // Other Fields - Name string `orm:"unique"` - Email string -} - -func NewInLine() *InLine { - return new(InLine) -} - -type InLineOneToOne struct { - // Common Fields - ModelBase - - Note string - InLine *InLine `orm:"rel(fk);column(inline)"` -} - -func NewInLineOneToOne() *InLineOneToOne { - return new(InLineOneToOne) -} - -type IntegerPk struct { - ID int64 `orm:"pk"` - Value string -} - -type UintPk struct { - ID uint32 `orm:"pk"` - Name string -} - -type PtrPk struct { - ID *IntegerPk `orm:"pk;rel(one)"` - Positive bool -} - -var DBARGS = struct { - Driver string - Source string - Debug string -}{ - os.Getenv("ORM_DRIVER"), - os.Getenv("ORM_SOURCE"), - os.Getenv("ORM_DEBUG"), -} - -var ( - IsMysql = DBARGS.Driver == "mysql" - IsSqlite = DBARGS.Driver == "sqlite3" - IsPostgres = DBARGS.Driver == "postgres" - IsTidb = DBARGS.Driver == "tidb" -) - -var ( - dORM Ormer - dDbBaser dbBaser -) - -var ( - helpinfo = `need driver and source! - - Default DB Drivers. - - driver: url - mysql: https://github.com/go-sql-driver/mysql - sqlite3: https://github.com/mattn/go-sqlite3 - postgres: https://github.com/lib/pq - tidb: https://github.com/pingcap/tidb - - usage: - - go get -u github.com/astaxie/beego/orm - go get -u github.com/go-sql-driver/mysql - go get -u github.com/mattn/go-sqlite3 - go get -u github.com/lib/pq - go get -u github.com/pingcap/tidb - - #### MySQL - mysql -u root -e 'create database orm_test;' - export ORM_DRIVER=mysql - export ORM_SOURCE="root:@/orm_test?charset=utf8" - go test -v github.com/astaxie/beego/orm - - - #### Sqlite3 - export ORM_DRIVER=sqlite3 - export ORM_SOURCE='file:memory_test?mode=memory' - go test -v github.com/astaxie/beego/orm - - - #### PostgreSQL - psql -c 'create database orm_test;' -U postgres - export ORM_DRIVER=postgres - export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - go test -v github.com/astaxie/beego/orm - - #### TiDB - export ORM_DRIVER=tidb - export ORM_SOURCE='memory://test/test' - go test -v github.com/astaxie/beego/orm - - ` -) - -func init() { - Debug, _ = StrTo(DBARGS.Debug).Bool() - - if DBARGS.Driver == "" || DBARGS.Source == "" { - fmt.Println(helpinfo) - os.Exit(2) - } - - RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) - - alias := getDbAlias("default") - if alias.Driver == DRMySQL { - alias.Engine = "INNODB" - } - -} diff --git a/orm/orm_test.go b/orm/orm_test.go deleted file mode 100644 index bdb430b6..00000000 --- a/orm/orm_test.go +++ /dev/null @@ -1,2494 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build go1.8 - -package orm - -import ( - "bytes" - "context" - "database/sql" - "fmt" - "io/ioutil" - "math" - "os" - "path/filepath" - "reflect" - "runtime" - "strings" - "testing" - "time" -) - -var _ = os.PathSeparator - -var ( - testDate = formatDate + " -0700" - testDateTime = formatDateTime + " -0700" - testTime = formatTime + " -0700" -) - -type argAny []interface{} - -// get interface by index from interface slice -func (a argAny) Get(i int, args ...interface{}) (r interface{}) { - if i >= 0 && i < len(a) { - r = a[i] - } - if len(args) > 0 { - r = args[0] - } - return -} - -func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) { - if len(args) == 0 { - return false, fmt.Errorf("miss args") - } - b := args[0] - arg := argAny(args) - - switch v := a.(type) { - case reflect.Kind: - ok = reflect.ValueOf(b).Kind() == v - case time.Time: - if v2, vo := b.(time.Time); vo { - if arg.Get(1) != nil { - format := ToStr(arg.Get(1)) - a = v.Format(format) - b = v2.Format(format) - ok = a == b - } else { - err = fmt.Errorf("compare datetime miss format") - goto wrongArg - } - } - default: - ok = ToStr(a) == ToStr(b) - } - ok = is && ok || !is && !ok - if !ok { - if is { - err = fmt.Errorf("expected: `%v`, get `%v`", b, a) - } else { - err = fmt.Errorf("expected: `%v`, get `%v`", b, a) - } - } - -wrongArg: - if err != nil { - return false, err - } - - return true, nil -} - -func AssertIs(a interface{}, args ...interface{}) error { - if ok, err := ValuesCompare(true, a, args...); !ok { - return err - } - return nil -} - -func AssertNot(a interface{}, args ...interface{}) error { - if ok, err := ValuesCompare(false, a, args...); !ok { - return err - } - return nil -} - -func getCaller(skip int) string { - pc, file, line, _ := runtime.Caller(skip) - fun := runtime.FuncForPC(pc) - _, fn := filepath.Split(file) - data, err := ioutil.ReadFile(file) - var codes []string - if err == nil { - lines := bytes.Split(data, []byte{'\n'}) - n := 10 - for i := 0; i < n; i++ { - o := line - n - if o < 0 { - continue - } - cur := o + i + 1 - flag := " " - if cur == line { - flag = ">>" - } - code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1)) - if code != "" { - codes = append(codes, code) - } - } - } - funName := fun.Name() - if i := strings.LastIndex(funName, "."); i > -1 { - funName = funName[i+1:] - } - return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) -} - -func throwFail(t *testing.T, err error, args ...interface{}) { - if err != nil { - con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) - if len(args) > 0 { - parts := make([]string, 0, len(args)) - for _, arg := range args { - parts = append(parts, fmt.Sprintf("%v", arg)) - } - con += " " + strings.Join(parts, ", ") - } - t.Error(con) - t.Fail() - } -} - -func throwFailNow(t *testing.T, err error, args ...interface{}) { - if err != nil { - con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) - if len(args) > 0 { - parts := make([]string, 0, len(args)) - for _, arg := range args { - parts = append(parts, fmt.Sprintf("%v", arg)) - } - con += " " + strings.Join(parts, ", ") - } - t.Error(con) - t.FailNow() - } -} - -func TestGetDB(t *testing.T) { - if db, err := GetDB(); err != nil { - throwFailNow(t, err) - } else { - err = db.Ping() - throwFailNow(t, err) - } -} - -func TestSyncDb(t *testing.T) { - RegisterModel(new(Data), new(DataNull), new(DataCustom)) - RegisterModel(new(User)) - RegisterModel(new(Profile)) - RegisterModel(new(Post)) - RegisterModel(new(Tag)) - RegisterModel(new(Comment)) - RegisterModel(new(UserBig)) - RegisterModel(new(PostTags)) - RegisterModel(new(Group)) - RegisterModel(new(Permission)) - RegisterModel(new(GroupPermissions)) - RegisterModel(new(InLine)) - RegisterModel(new(InLineOneToOne)) - RegisterModel(new(IntegerPk)) - RegisterModel(new(UintPk)) - RegisterModel(new(PtrPk)) - - err := RunSyncdb("default", true, Debug) - throwFail(t, err) - - modelCache.clean() -} - -func TestRegisterModels(t *testing.T) { - RegisterModel(new(Data), new(DataNull), new(DataCustom)) - RegisterModel(new(User)) - RegisterModel(new(Profile)) - RegisterModel(new(Post)) - RegisterModel(new(Tag)) - RegisterModel(new(Comment)) - RegisterModel(new(UserBig)) - RegisterModel(new(PostTags)) - RegisterModel(new(Group)) - RegisterModel(new(Permission)) - RegisterModel(new(GroupPermissions)) - RegisterModel(new(InLine)) - RegisterModel(new(InLineOneToOne)) - RegisterModel(new(IntegerPk)) - RegisterModel(new(UintPk)) - RegisterModel(new(PtrPk)) - - BootStrap() - - dORM = NewOrm() - dDbBaser = getDbAlias("default").DbBaser -} - -func TestModelSyntax(t *testing.T) { - user := &User{} - ind := reflect.ValueOf(user).Elem() - fn := getFullName(ind.Type()) - mi, ok := modelCache.getByFullName(fn) - throwFail(t, AssertIs(ok, true)) - - mi, ok = modelCache.get("user") - throwFail(t, AssertIs(ok, true)) - if ok { - throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) - } -} - -var DataValues = map[string]interface{}{ - "Boolean": true, - "Char": "char", - "Text": "text", - "JSON": `{"name":"json"}`, - "Jsonb": `{"name": "jsonb"}`, - "Time": time.Now(), - "Date": time.Now(), - "DateTime": time.Now(), - "Byte": byte(1<<8 - 1), - "Rune": rune(1<<31 - 1), - "Int": int(1<<31 - 1), - "Int8": int8(1<<7 - 1), - "Int16": int16(1<<15 - 1), - "Int32": int32(1<<31 - 1), - "Int64": int64(1<<63 - 1), - "Uint": uint(1<<32 - 1), - "Uint8": uint8(1<<8 - 1), - "Uint16": uint16(1<<16 - 1), - "Uint32": uint32(1<<32 - 1), - "Uint64": uint64(1<<63 - 1), // uint64 values with high bit set are not supported - "Float32": float32(100.1234), - "Float64": float64(100.1234), - "Decimal": float64(100.1234), -} - -func TestDataTypes(t *testing.T) { - d := Data{} - ind := reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - if name == "JSON" { - continue - } - e := ind.FieldByName(name) - e.Set(reflect.ValueOf(value)) - } - id, err := dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - d = Data{ID: 1} - err = dORM.Read(&d) - throwFail(t, err) - - ind = reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - e := ind.FieldByName(name) - vu := e.Interface() - switch name { - case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) - case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) - } - throwFail(t, AssertIs(vu == value, true), value, vu) - } -} - -func TestNullDataTypes(t *testing.T) { - d := DataNull{} - - if IsPostgres { - // can removed when this fixed - // https://github.com/lib/pq/pull/125 - d.DateTime = time.Now() - } - - id, err := dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - data := `{"ok":1,"data":{"arr":[1,2],"msg":"gopher"}}` - d = DataNull{ID: 1, JSON: data} - num, err := dORM.Update(&d) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - d = DataNull{ID: 1} - err = dORM.Read(&d) - throwFail(t, err) - - throwFail(t, AssertIs(d.JSON, data)) - - throwFail(t, AssertIs(d.NullBool.Valid, false)) - throwFail(t, AssertIs(d.NullString.Valid, false)) - throwFail(t, AssertIs(d.NullInt64.Valid, false)) - throwFail(t, AssertIs(d.NullFloat64.Valid, false)) - - throwFail(t, AssertIs(d.BooleanPtr, nil)) - throwFail(t, AssertIs(d.CharPtr, nil)) - throwFail(t, AssertIs(d.TextPtr, nil)) - throwFail(t, AssertIs(d.BytePtr, nil)) - throwFail(t, AssertIs(d.RunePtr, nil)) - throwFail(t, AssertIs(d.IntPtr, nil)) - throwFail(t, AssertIs(d.Int8Ptr, nil)) - throwFail(t, AssertIs(d.Int16Ptr, nil)) - throwFail(t, AssertIs(d.Int32Ptr, nil)) - throwFail(t, AssertIs(d.Int64Ptr, nil)) - throwFail(t, AssertIs(d.UintPtr, nil)) - throwFail(t, AssertIs(d.Uint8Ptr, nil)) - throwFail(t, AssertIs(d.Uint16Ptr, nil)) - throwFail(t, AssertIs(d.Uint32Ptr, nil)) - throwFail(t, AssertIs(d.Uint64Ptr, nil)) - throwFail(t, AssertIs(d.Float32Ptr, nil)) - throwFail(t, AssertIs(d.Float64Ptr, nil)) - throwFail(t, AssertIs(d.DecimalPtr, nil)) - throwFail(t, AssertIs(d.TimePtr, nil)) - throwFail(t, AssertIs(d.DatePtr, nil)) - throwFail(t, AssertIs(d.DateTimePtr, nil)) - - _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() - throwFail(t, err) - - d = DataNull{ID: 2} - err = dORM.Read(&d) - throwFail(t, err) - - booleanPtr := true - charPtr := string("test") - textPtr := string("test") - bytePtr := byte('t') - runePtr := rune('t') - intPtr := int(42) - int8Ptr := int8(42) - int16Ptr := int16(42) - int32Ptr := int32(42) - int64Ptr := int64(42) - uintPtr := uint(42) - uint8Ptr := uint8(42) - uint16Ptr := uint16(42) - uint32Ptr := uint32(42) - uint64Ptr := uint64(42) - float32Ptr := float32(42.0) - float64Ptr := float64(42.0) - decimalPtr := float64(42.0) - timePtr := time.Now() - datePtr := time.Now() - dateTimePtr := time.Now() - - d = DataNull{ - DateTime: time.Now(), - NullString: sql.NullString{String: "test", Valid: true}, - NullBool: sql.NullBool{Bool: true, Valid: true}, - NullInt64: sql.NullInt64{Int64: 42, Valid: true}, - NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, - BooleanPtr: &booleanPtr, - CharPtr: &charPtr, - TextPtr: &textPtr, - BytePtr: &bytePtr, - RunePtr: &runePtr, - IntPtr: &intPtr, - Int8Ptr: &int8Ptr, - Int16Ptr: &int16Ptr, - Int32Ptr: &int32Ptr, - Int64Ptr: &int64Ptr, - UintPtr: &uintPtr, - Uint8Ptr: &uint8Ptr, - Uint16Ptr: &uint16Ptr, - Uint32Ptr: &uint32Ptr, - Uint64Ptr: &uint64Ptr, - Float32Ptr: &float32Ptr, - Float64Ptr: &float64Ptr, - DecimalPtr: &decimalPtr, - TimePtr: &timePtr, - DatePtr: &datePtr, - DateTimePtr: &dateTimePtr, - } - - id, err = dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 3)) - - d = DataNull{ID: 3} - err = dORM.Read(&d) - throwFail(t, err) - - throwFail(t, AssertIs(d.NullBool.Valid, true)) - throwFail(t, AssertIs(d.NullBool.Bool, true)) - - throwFail(t, AssertIs(d.NullString.Valid, true)) - throwFail(t, AssertIs(d.NullString.String, "test")) - - throwFail(t, AssertIs(d.NullInt64.Valid, true)) - throwFail(t, AssertIs(d.NullInt64.Int64, 42)) - - throwFail(t, AssertIs(d.NullFloat64.Valid, true)) - throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) - - throwFail(t, AssertIs(*d.BooleanPtr, booleanPtr)) - throwFail(t, AssertIs(*d.CharPtr, charPtr)) - throwFail(t, AssertIs(*d.TextPtr, textPtr)) - throwFail(t, AssertIs(*d.BytePtr, bytePtr)) - throwFail(t, AssertIs(*d.RunePtr, runePtr)) - throwFail(t, AssertIs(*d.IntPtr, intPtr)) - throwFail(t, AssertIs(*d.Int8Ptr, int8Ptr)) - throwFail(t, AssertIs(*d.Int16Ptr, int16Ptr)) - throwFail(t, AssertIs(*d.Int32Ptr, int32Ptr)) - throwFail(t, AssertIs(*d.Int64Ptr, int64Ptr)) - throwFail(t, AssertIs(*d.UintPtr, uintPtr)) - throwFail(t, AssertIs(*d.Uint8Ptr, uint8Ptr)) - throwFail(t, AssertIs(*d.Uint16Ptr, uint16Ptr)) - throwFail(t, AssertIs(*d.Uint32Ptr, uint32Ptr)) - throwFail(t, AssertIs(*d.Uint64Ptr, uint64Ptr)) - throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) - throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) - throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) - throwFail(t, AssertIs((*d.TimePtr).UTC().Format(testTime), timePtr.UTC().Format(testTime))) - throwFail(t, AssertIs((*d.DatePtr).UTC().Format(testDate), datePtr.UTC().Format(testDate))) - throwFail(t, AssertIs((*d.DateTimePtr).UTC().Format(testDateTime), dateTimePtr.UTC().Format(testDateTime))) - - // test support for pointer fields using RawSeter.QueryRows() - var dnList []*DataNull - Q := dDbBaser.TableQuote() - num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - equal := reflect.DeepEqual(*dnList[0], d) - throwFailNow(t, AssertIs(equal, true)) -} - -func TestDataCustomTypes(t *testing.T) { - d := DataCustom{} - ind := reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - e := ind.FieldByName(name) - if !e.IsValid() { - continue - } - e.Set(reflect.ValueOf(value).Convert(e.Type())) - } - - id, err := dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - d = DataCustom{ID: 1} - err = dORM.Read(&d) - throwFail(t, err) - - ind = reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - e := ind.FieldByName(name) - if !e.IsValid() { - continue - } - vu := e.Interface() - value = reflect.ValueOf(value).Convert(e.Type()).Interface() - throwFail(t, AssertIs(vu == value, true), value, vu) - } -} - -func TestCRUD(t *testing.T) { - profile := NewProfile() - profile.Age = 30 - profile.Money = 1234.12 - id, err := dORM.Insert(profile) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - user := NewUser() - user.UserName = "slene" - user.Email = "vslene@gmail.com" - user.Password = "pass" - user.Status = 3 - user.IsStaff = true - user.IsActive = true - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - u := &User{ID: user.ID} - err = dORM.Read(u) - throwFail(t, err) - - throwFail(t, AssertIs(u.UserName, "slene")) - throwFail(t, AssertIs(u.Email, "vslene@gmail.com")) - throwFail(t, AssertIs(u.Password, "pass")) - throwFail(t, AssertIs(u.Status, 3)) - throwFail(t, AssertIs(u.IsStaff, true)) - throwFail(t, AssertIs(u.IsActive, true)) - throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime)) - - user.UserName = "astaxie" - user.Profile = profile - num, err := dORM.Update(user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: user.ID} - err = dORM.Read(u) - throwFailNow(t, err) - throwFail(t, AssertIs(u.UserName, "astaxie")) - throwFail(t, AssertIs(u.Profile.ID, profile.ID)) - - u = &User{UserName: "astaxie", Password: "pass"} - err = dORM.Read(u, "UserName") - throwFailNow(t, err) - throwFailNow(t, AssertIs(id, 1)) - - u.UserName = "QQ" - u.Password = "111" - num, err = dORM.Update(u, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: user.ID} - err = dORM.Read(u) - throwFailNow(t, err) - throwFail(t, AssertIs(u.UserName, "QQ")) - throwFail(t, AssertIs(u.Password, "pass")) - - num, err = dORM.Delete(profile) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: user.ID} - err = dORM.Read(u) - throwFail(t, err) - throwFail(t, AssertIs(true, u.Profile == nil)) - - num, err = dORM.Delete(user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: 100} - err = dORM.Read(u) - throwFail(t, AssertIs(err, ErrNoRows)) - - ub := UserBig{} - ub.Name = "name" - id, err = dORM.Insert(&ub) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - ub = UserBig{ID: 1} - err = dORM.Read(&ub) - throwFail(t, err) - throwFail(t, AssertIs(ub.Name, "name")) - - num, err = dORM.Delete(&ub, "name") - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestInsertTestData(t *testing.T) { - var users []*User - - profile := NewProfile() - profile.Age = 28 - profile.Money = 1234.12 - - id, err := dORM.Insert(profile) - throwFail(t, err) - throwFail(t, AssertIs(id, 2)) - - user := NewUser() - user.UserName = "slene" - user.Email = "vslene@gmail.com" - user.Password = "pass" - user.Status = 1 - user.IsStaff = false - user.IsActive = true - user.Profile = profile - - users = append(users, user) - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 2)) - - profile = NewProfile() - profile.Age = 30 - profile.Money = 4321.09 - - id, err = dORM.Insert(profile) - throwFail(t, err) - throwFail(t, AssertIs(id, 3)) - - user = NewUser() - user.UserName = "astaxie" - user.Email = "astaxie@gmail.com" - user.Password = "password" - user.Status = 2 - user.IsStaff = true - user.IsActive = false - user.Profile = profile - - users = append(users, user) - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 3)) - - user = NewUser() - user.UserName = "nobody" - user.Email = "nobody@gmail.com" - user.Password = "nobody" - user.Status = 3 - user.IsStaff = false - user.IsActive = false - - users = append(users, user) - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 4)) - - tags := []*Tag{ - {Name: "golang", BestPost: &Post{ID: 2}}, - {Name: "example"}, - {Name: "format"}, - {Name: "c++"}, - } - - posts := []*Post{ - {User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory result—Java programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand. -This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, all of which you should read first.`}, - {User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`}, - {User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide. -With Go we take an unusual approach and let the machine take care of most formatting issues. The gofmt program (also available as go fmt, which operates at the package level rather than source file level) reads a Go program and emits the source in a standard style of indentation and vertical alignment, retaining and if necessary reformatting comments. If you want to know how to handle some new layout situation, run gofmt; if the answer doesn't seem right, rearrange your program (or file a bug about gofmt), don't work around it.`}, - {User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code. -The program—and web server—godoc processes Go source files to extract documentation about the contents of the package. Comments that appear before top-level declarations, with no intervening newlines, are extracted along with the declaration to serve as explanatory text for the item. The nature and style of these comments determines the quality of the documentation godoc produces.`}, - } - - comments := []*Comment{ - {Post: posts[0], Content: "a comment"}, - {Post: posts[1], Content: "yes"}, - {Post: posts[1]}, - {Post: posts[1]}, - {Post: posts[2]}, - {Post: posts[2]}, - } - - for _, tag := range tags { - id, err := dORM.Insert(tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - - for _, post := range posts { - id, err := dORM.Insert(post) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - num := len(post.Tags) - if num > 0 { - nums, err := dORM.QueryM2M(post, "tags").Add(post.Tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(nums, num)) - } - } - - for _, comment := range comments { - id, err := dORM.Insert(comment) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - - permissions := []*Permission{ - {Name: "writePosts"}, - {Name: "readComments"}, - {Name: "readPosts"}, - } - - groups := []*Group{ - { - Name: "admins", - Permissions: []*Permission{permissions[0], permissions[1], permissions[2]}, - }, - { - Name: "users", - Permissions: []*Permission{permissions[1], permissions[2]}, - }, - } - - for _, permission := range permissions { - id, err := dORM.Insert(permission) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - - for _, group := range groups { - _, err := dORM.Insert(group) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - num := len(group.Permissions) - if num > 0 { - nums, err := dORM.QueryM2M(group, "permissions").Add(group.Permissions) - throwFailNow(t, err) - throwFailNow(t, AssertIs(nums, num)) - } - } - -} - -func TestCustomField(t *testing.T) { - user := User{ID: 2} - err := dORM.Read(&user) - throwFailNow(t, err) - - user.Langs = append(user.Langs, "zh-CN", "en-US") - user.Extra.Name = "beego" - user.Extra.Data = "orm" - _, err = dORM.Update(&user, "Langs", "Extra") - throwFailNow(t, err) - - user = User{ID: 2} - err = dORM.Read(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(len(user.Langs), 2)) - throwFailNow(t, AssertIs(user.Langs[0], "zh-CN")) - throwFailNow(t, AssertIs(user.Langs[1], "en-US")) - - throwFailNow(t, AssertIs(user.Extra.Name, "beego")) - throwFailNow(t, AssertIs(user.Extra.Data, "orm")) -} - -func TestExpr(t *testing.T) { - user := &User{} - qs := dORM.QueryTable(user) - qs = dORM.QueryTable((*User)(nil)) - qs = dORM.QueryTable("User") - qs = dORM.QueryTable("user") - num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("created", time.Now()).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - // num, err = qs.Filter("created", time.Now().Format(format_Date)).Count() - // throwFail(t, err) - // throwFail(t, AssertIs(num, 3)) -} - -func TestOperators(t *testing.T) { - qs := dORM.QueryTable("user") - num, err := qs.Filter("user_name", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__exact", String("slene")).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__exact", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__iexact", "Slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__contains", "e").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - var shouldNum int - - if IsSqlite || IsTidb { - shouldNum = 2 - } else { - shouldNum = 0 - } - - num, err = qs.Filter("user_name__contains", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, shouldNum)) - - num, err = qs.Filter("user_name__icontains", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("user_name__icontains", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__gt", 1).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__gte", 1).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - num, err = qs.Filter("status__lt", Uint(3)).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__lte", Int(3)).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - num, err = qs.Filter("user_name__startswith", "s").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - if IsSqlite || IsTidb { - shouldNum = 1 - } else { - shouldNum = 0 - } - - num, err = qs.Filter("user_name__startswith", "S").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, shouldNum)) - - num, err = qs.Filter("user_name__istartswith", "S").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__endswith", "e").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - if IsSqlite || IsTidb { - shouldNum = 2 - } else { - shouldNum = 0 - } - - num, err = qs.Filter("user_name__endswith", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, shouldNum)) - - num, err = qs.Filter("user_name__iendswith", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("profile__isnull", true).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("status__in", 1, 2).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__in", []int{1, 2}).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - n1, n2 := 1, 2 - num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("id__between", 2, 3).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("id__between", []int{2, 3}).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.FilterRaw("user_name", "= 'slene'").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.FilterRaw("status", "IN (1, 2)").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.FilterRaw("profile_id", "IN (SELECT id FROM user_profile WHERE age=30)").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestSetCond(t *testing.T) { - cond := NewCondition() - cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) - - qs := dORM.QueryTable("user") - num, err := qs.SetCond(cond1).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - cond2 := cond.AndCond(cond1).OrCond(cond.And("user_name", "slene")) - num, err = qs.SetCond(cond2).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - cond3 := cond.AndNotCond(cond.And("status__in", 1)) - num, err = qs.SetCond(cond3).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - cond4 := cond.And("user_name", "slene").OrNotCond(cond.And("user_name", "slene")) - num, err = qs.SetCond(cond4).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - cond5 := cond.Raw("user_name", "= 'slene'").OrNotCond(cond.And("user_name", "slene")) - num, err = qs.SetCond(cond5).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) -} - -func TestLimit(t *testing.T) { - var posts []*Post - qs := dORM.QueryTable("post") - num, err := qs.Limit(1).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Limit(-1).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 4)) - - num, err = qs.Limit(-1, 2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Limit(0, 2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) -} - -func TestOffset(t *testing.T) { - var posts []*Post - qs := dORM.QueryTable("post") - num, err := qs.Limit(1).Offset(2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Offset(2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) -} - -func TestOrderBy(t *testing.T) { - qs := dORM.QueryTable("user") - num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.OrderBy("status").Filter("user_name", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestAll(t *testing.T) { - var users []*User - qs := dORM.QueryTable("user") - num, err := qs.OrderBy("Id").All(&users) - throwFail(t, err) - throwFailNow(t, AssertIs(num, 3)) - - throwFail(t, AssertIs(users[0].UserName, "slene")) - throwFail(t, AssertIs(users[1].UserName, "astaxie")) - throwFail(t, AssertIs(users[2].UserName, "nobody")) - - var users2 []User - qs = dORM.QueryTable("user") - num, err = qs.OrderBy("Id").All(&users2) - throwFail(t, err) - throwFailNow(t, AssertIs(num, 3)) - - throwFailNow(t, AssertIs(users2[0].UserName, "slene")) - throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) - throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) - - qs = dORM.QueryTable("user") - num, err = qs.OrderBy("Id").RelatedSel().All(&users2, "UserName") - throwFail(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(len(users2), 3)) - throwFailNow(t, AssertIs(users2[0].UserName, "slene")) - throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) - throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) - throwFailNow(t, AssertIs(users2[0].ID, 0)) - throwFailNow(t, AssertIs(users2[1].ID, 0)) - throwFailNow(t, AssertIs(users2[2].ID, 0)) - throwFailNow(t, AssertIs(users2[0].Profile == nil, false)) - throwFailNow(t, AssertIs(users2[1].Profile == nil, false)) - throwFailNow(t, AssertIs(users2[2].Profile == nil, true)) - - qs = dORM.QueryTable("user") - num, err = qs.Filter("user_name", "nothing").All(&users) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - - var users3 []*User - qs = dORM.QueryTable("user") - num, err = qs.Filter("user_name", "nothing").All(&users3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - throwFailNow(t, AssertIs(users3 == nil, false)) -} - -func TestOne(t *testing.T) { - var user User - qs := dORM.QueryTable("user") - err := qs.One(&user) - throwFail(t, err) - - user = User{} - err = qs.OrderBy("Id").Limit(1).One(&user) - throwFailNow(t, err) - throwFail(t, AssertIs(user.UserName, "slene")) - throwFail(t, AssertNot(err, ErrMultiRows)) - - user = User{} - err = qs.OrderBy("-Id").Limit(100).One(&user) - throwFailNow(t, err) - throwFail(t, AssertIs(user.UserName, "nobody")) - throwFail(t, AssertNot(err, ErrMultiRows)) - - err = qs.Filter("user_name", "nothing").One(&user) - throwFail(t, AssertIs(err, ErrNoRows)) - -} - -func TestValues(t *testing.T) { - var maps []Params - qs := dORM.QueryTable("user") - - num, err := qs.OrderBy("Id").Values(&maps) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(maps[0]["UserName"], "slene")) - throwFail(t, AssertIs(maps[2]["Profile"], nil)) - } - - num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(maps[0]["UserName"], "slene")) - throwFail(t, AssertIs(maps[0]["Profile__Age"], 28)) - throwFail(t, AssertIs(maps[2]["Profile__Age"], nil)) - } - - num, err = qs.Filter("UserName", "slene").Values(&maps) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestValuesList(t *testing.T) { - var list []ParamsList - qs := dORM.QueryTable("user") - - num, err := qs.OrderBy("Id").ValuesList(&list) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0][1], "slene")) - throwFail(t, AssertIs(list[2][9], nil)) - } - - num, err = qs.OrderBy("Id").ValuesList(&list, "UserName", "Profile__Age") - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0][0], "slene")) - throwFail(t, AssertIs(list[0][1], 28)) - throwFail(t, AssertIs(list[2][1], nil)) - } -} - -func TestValuesFlat(t *testing.T) { - var list ParamsList - qs := dORM.QueryTable("user") - - num, err := qs.OrderBy("id").ValuesFlat(&list, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0], "slene")) - throwFail(t, AssertIs(list[1], "astaxie")) - throwFail(t, AssertIs(list[2], "nobody")) - } -} - -func TestRelatedSel(t *testing.T) { - if IsTidb { - // Skip it. TiDB does not support relation now. - return - } - qs := dORM.QueryTable("user") - num, err := qs.Filter("profile__age", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("profile__age__gt", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("profile__user__profile__age__gt", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - var user User - err = qs.Filter("user_name", "slene").RelatedSel("profile").One(&user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertNot(user.Profile, nil)) - if user.Profile != nil { - throwFail(t, AssertIs(user.Profile.Age, 28)) - } - - err = qs.Filter("user_name", "slene").RelatedSel().One(&user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertNot(user.Profile, nil)) - if user.Profile != nil { - throwFail(t, AssertIs(user.Profile.Age, 28)) - } - - err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertIs(user.Profile, nil)) - - qs = dORM.QueryTable("user_profile") - num, err = qs.Filter("user__username", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - var posts []*Post - qs = dORM.QueryTable("post") - num, err = qs.RelatedSel().All(&posts) - throwFail(t, err) - throwFailNow(t, AssertIs(num, 4)) - - throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) - throwFailNow(t, AssertIs(posts[1].User.UserName, "astaxie")) - throwFailNow(t, AssertIs(posts[2].User.UserName, "astaxie")) - throwFailNow(t, AssertIs(posts[3].User.UserName, "nobody")) -} - -func TestReverseQuery(t *testing.T) { - var profile Profile - err := dORM.QueryTable("user_profile").Filter("User", 3).One(&profile) - throwFailNow(t, err) - throwFailNow(t, AssertIs(profile.Age, 30)) - - profile = Profile{} - err = dORM.QueryTable("user_profile").Filter("User__UserName", "astaxie").One(&profile) - throwFailNow(t, err) - throwFailNow(t, AssertIs(profile.Age, 30)) - - var user User - err = dORM.QueryTable("user").Filter("Posts__Title", "Examples").One(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(user.UserName, "astaxie")) - - user = User{} - err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").Limit(1).One(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(user.UserName, "astaxie")) - - user = User{} - err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").RelatedSel().Limit(1).One(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(user.UserName, "astaxie")) - throwFailNow(t, AssertIs(user.Profile == nil, false)) - throwFailNow(t, AssertIs(user.Profile.Age, 30)) - - var posts []*Post - num, err := dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").All(&posts) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) - - posts = []*Post{} - num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").Filter("User__UserName", "slene").All(&posts) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) - - posts = []*Post{} - num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang"). - Filter("User__UserName", "slene").RelatedSel().All(&posts) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(posts[0].User == nil, false)) - throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) - - var tags []*Tag - num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction").All(&tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(tags[0].Name, "golang")) - - tags = []*Tag{} - num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). - Filter("BestPost__User__UserName", "astaxie").All(&tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(tags[0].Name, "golang")) - - tags = []*Tag{} - num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). - Filter("BestPost__User__UserName", "astaxie").RelatedSel().All(&tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(tags[0].Name, "golang")) - throwFailNow(t, AssertIs(tags[0].BestPost == nil, false)) - throwFailNow(t, AssertIs(tags[0].BestPost.Title, "Examples")) - throwFailNow(t, AssertIs(tags[0].BestPost.User == nil, false)) - throwFailNow(t, AssertIs(tags[0].BestPost.User.UserName, "astaxie")) -} - -func TestLoadRelated(t *testing.T) { - // load reverse foreign key - user := User{ID: 3} - - err := dORM.Read(&user) - throwFailNow(t, err) - - num, err := dORM.LoadRelated(&user, "Posts") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(user.Posts), 2)) - throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) - - num, err = dORM.LoadRelated(&user, "Posts", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(user.Posts), 2)) - throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) - - num, err = dORM.LoadRelated(&user, "Posts", true, 1) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(user.Posts), 1)) - - num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(user.Posts), 2)) - throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) - - num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(user.Posts), 1)) - throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) - - // load reverse one to one - profile := Profile{ID: 3} - profile.BestPost = &Post{ID: 2} - num, err = dORM.Update(&profile, "BestPost") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - err = dORM.Read(&profile) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&profile, "User") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(profile.User == nil, false)) - throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) - - num, err = dORM.LoadRelated(&profile, "User", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(profile.User == nil, false)) - throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) - throwFailNow(t, AssertIs(profile.User.Profile.Age, profile.Age)) - - // load rel one to one - err = dORM.Read(&user) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&user, "Profile") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(user.Profile == nil, false)) - throwFailNow(t, AssertIs(user.Profile.Age, 30)) - - num, err = dORM.LoadRelated(&user, "Profile", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(user.Profile == nil, false)) - throwFailNow(t, AssertIs(user.Profile.Age, 30)) - throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false)) - throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples")) - - post := Post{ID: 2} - - // load rel foreign key - err = dORM.Read(&post) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&post, "User") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(post.User == nil, false)) - throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) - - num, err = dORM.LoadRelated(&post, "User", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(post.User == nil, false)) - throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) - throwFailNow(t, AssertIs(post.User.Profile == nil, false)) - throwFailNow(t, AssertIs(post.User.Profile.Age, 30)) - - // load rel m2m - post = Post{ID: 2} - - err = dORM.Read(&post) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&post, "Tags") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(post.Tags), 2)) - throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) - - num, err = dORM.LoadRelated(&post, "Tags", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(post.Tags), 2)) - throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) - throwFailNow(t, AssertIs(post.Tags[0].BestPost == nil, false)) - throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie")) - - // load reverse m2m - tag := Tag{ID: 1} - - err = dORM.Read(&tag) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&tag, "Posts") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) - throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) - throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) - - num, err = dORM.LoadRelated(&tag, "Posts", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) - throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) - throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene")) -} - -func TestQueryM2M(t *testing.T) { - post := Post{ID: 4} - m2m := dORM.QueryM2M(&post, "Tags") - - tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}} - tag2 := &Tag{Name: "TestTag3"} - tag3 := []interface{}{&Tag{Name: "TestTag4"}} - - tags := []interface{}{tag1[0], tag1[1], tag2, tag3[0]} - - for _, tag := range tags { - _, err := dORM.Insert(tag) - throwFailNow(t, err) - } - - num, err := m2m.Add(tag1) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Add(tag2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Add(tag3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 5)) - - num, err = m2m.Remove(tag3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 4)) - - exist := m2m.Exist(tag2) - throwFailNow(t, AssertIs(exist, true)) - - num, err = m2m.Remove(tag2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - exist = m2m.Exist(tag2) - throwFailNow(t, AssertIs(exist, false)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - - num, err = m2m.Clear() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - - tag := Tag{Name: "test"} - _, err = dORM.Insert(&tag) - throwFailNow(t, err) - - m2m = dORM.QueryM2M(&tag, "Posts") - - post1 := []*Post{{Title: "TestPost1"}, {Title: "TestPost2"}} - post2 := &Post{Title: "TestPost3"} - post3 := []interface{}{&Post{Title: "TestPost4"}} - - posts := []interface{}{post1[0], post1[1], post2, post3[0]} - - for _, post := range posts { - p := post.(*Post) - p.User = &User{ID: 1} - _, err := dORM.Insert(post) - throwFailNow(t, err) - } - - num, err = m2m.Add(post1) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Add(post2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Add(post3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 4)) - - num, err = m2m.Remove(post3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - - exist = m2m.Exist(post2) - throwFailNow(t, AssertIs(exist, true)) - - num, err = m2m.Remove(post2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - exist = m2m.Exist(post2) - throwFailNow(t, AssertIs(exist, false)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Clear() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - - num, err = dORM.Delete(&tag) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) -} - -func TestQueryRelate(t *testing.T) { - // post := &Post{Id: 2} - - // qs := dORM.QueryRelate(post, "Tags") - // num, err := qs.Count() - // throwFailNow(t, err) - // throwFailNow(t, AssertIs(num, 2)) - - // var tags []*Tag - // num, err = qs.All(&tags) - // throwFailNow(t, err) - // throwFailNow(t, AssertIs(num, 2)) - // throwFailNow(t, AssertIs(tags[0].Name, "golang")) - - // num, err = dORM.QueryTable("Tag").Filter("Posts__Post", 2).Count() - // throwFailNow(t, err) - // throwFailNow(t, AssertIs(num, 2)) -} - -func TestPkManyRelated(t *testing.T) { - permission := &Permission{Name: "readPosts"} - err := dORM.Read(permission, "Name") - throwFailNow(t, err) - - var groups []*Group - qs := dORM.QueryTable("Group") - num, err := qs.Filter("Permissions__Permission", permission.ID).All(&groups) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) -} - -func TestPrepareInsert(t *testing.T) { - qs := dORM.QueryTable("user") - i, err := qs.PrepareInsert() - throwFailNow(t, err) - - var user User - user.UserName = "testing1" - num, err := i.Insert(&user) - throwFail(t, err) - throwFail(t, AssertIs(num > 0, true)) - - user.UserName = "testing2" - num, err = i.Insert(&user) - throwFail(t, err) - throwFail(t, AssertIs(num > 0, true)) - - num, err = qs.Filter("user_name__in", "testing1", "testing2").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - err = i.Close() - throwFail(t, err) - err = i.Close() - throwFail(t, AssertIs(err, ErrStmtClosed)) -} - -func TestRawExec(t *testing.T) { - Q := dDbBaser.TableQuote() - - query := fmt.Sprintf("UPDATE %suser%s SET %suser_name%s = ? WHERE %suser_name%s = ?", Q, Q, Q, Q, Q, Q) - res, err := dORM.Raw(query, "testing", "slene").Exec() - throwFail(t, err) - num, err := res.RowsAffected() - throwFail(t, AssertIs(num, 1), err) - - res, err = dORM.Raw(query, "slene", "testing").Exec() - throwFail(t, err) - num, err = res.RowsAffected() - throwFail(t, AssertIs(num, 1), err) -} - -func TestRawQueryRow(t *testing.T) { - var ( - Boolean bool - Char string - Text string - Time time.Time - Date time.Time - DateTime time.Time - Byte byte - Rune rune - Int int - Int8 int - Int16 int16 - Int32 int32 - Int64 int64 - Uint uint - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - Float32 float32 - Float64 float64 - Decimal float64 - ) - - dataValues := make(map[string]interface{}, len(DataValues)) - - for k, v := range DataValues { - dataValues[strings.ToLower(k)] = v - } - - Q := dDbBaser.TableQuote() - - cols := []string{ - "id", "boolean", "char", "text", "time", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", - "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", - } - sep := fmt.Sprintf("%s, %s", Q, Q) - query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q) - var id int - values := []interface{}{ - &id, &Boolean, &Char, &Text, &Time, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32, - &Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal, - } - err := dORM.Raw(query, 1).QueryRow(values...) - throwFailNow(t, err) - for i, col := range cols { - vu := values[i] - v := reflect.ValueOf(vu).Elem().Interface() - switch col { - case "id": - throwFail(t, AssertIs(id, 1)) - case "time": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testTime)) - case "date": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testDate)) - case "datetime": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testDateTime)) - default: - throwFail(t, AssertIs(v, dataValues[col])) - } - } - - var ( - uid int - status *int - pid *int - ) - - cols = []string{ - "id", "Status", "profile_id", - } - query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q) - err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid) - throwFail(t, err) - throwFail(t, AssertIs(uid, 4)) - throwFail(t, AssertIs(*status, 3)) - throwFail(t, AssertIs(pid, nil)) - - // test for sql.Null* fields - nData := &DataNull{ - NullString: sql.NullString{String: "test sql.null", Valid: true}, - NullBool: sql.NullBool{Bool: true, Valid: true}, - NullInt64: sql.NullInt64{Int64: 42, Valid: true}, - NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, - } - newId, err := dORM.Insert(nData) - throwFailNow(t, err) - - var nd *DataNull - query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) - err = dORM.Raw(query, newId).QueryRow(&nd) - throwFailNow(t, err) - - throwFailNow(t, AssertNot(nd, nil)) - throwFail(t, AssertIs(nd.NullBool.Valid, true)) - throwFail(t, AssertIs(nd.NullBool.Bool, true)) - throwFail(t, AssertIs(nd.NullString.Valid, true)) - throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) - throwFail(t, AssertIs(nd.NullInt64.Valid, true)) - throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) - throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) - throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) -} - -// user_profile table -type userProfile struct { - User - Age int - Money float64 -} - -func TestQueryRows(t *testing.T) { - Q := dDbBaser.TableQuote() - - var datas []*Data - - query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) - num, err := dORM.Raw(query).QueryRows(&datas) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(datas), 1)) - - ind := reflect.Indirect(reflect.ValueOf(datas[0])) - - for name, value := range DataValues { - e := ind.FieldByName(name) - vu := e.Interface() - switch name { - case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) - case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) - case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - } - throwFail(t, AssertIs(vu == value, true), value, vu) - } - - var datas2 []Data - - query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) - num, err = dORM.Raw(query).QueryRows(&datas2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(datas2), 1)) - - ind = reflect.Indirect(reflect.ValueOf(datas2[0])) - - for name, value := range DataValues { - e := ind.FieldByName(name) - vu := e.Interface() - switch name { - case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) - case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) - case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - } - throwFail(t, AssertIs(vu == value, true), value, vu) - } - - var ids []int - var usernames []string - query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q) - num, err = dORM.Raw(query).QueryRows(&ids, &usernames) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(len(ids), 3)) - throwFailNow(t, AssertIs(ids[0], 2)) - throwFailNow(t, AssertIs(usernames[0], "slene")) - throwFailNow(t, AssertIs(ids[1], 3)) - throwFailNow(t, AssertIs(usernames[1], "astaxie")) - throwFailNow(t, AssertIs(ids[2], 4)) - throwFailNow(t, AssertIs(usernames[2], "nobody")) - - //test query rows by nested struct - var l []userProfile - query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q) - num, err = dORM.Raw(query).QueryRows(&l) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(l), 2)) - throwFailNow(t, AssertIs(l[0].UserName, "slene")) - throwFailNow(t, AssertIs(l[0].Age, 28)) - throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) - throwFailNow(t, AssertIs(l[1].Age, 30)) - - // test for sql.Null* fields - nData := &DataNull{ - NullString: sql.NullString{String: "test sql.null", Valid: true}, - NullBool: sql.NullBool{Bool: true, Valid: true}, - NullInt64: sql.NullInt64{Int64: 42, Valid: true}, - NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, - } - newId, err := dORM.Insert(nData) - throwFailNow(t, err) - - var nDataList []*DataNull - query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) - num, err = dORM.Raw(query, newId).QueryRows(&nDataList) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - nd := nDataList[0] - throwFailNow(t, AssertNot(nd, nil)) - throwFail(t, AssertIs(nd.NullBool.Valid, true)) - throwFail(t, AssertIs(nd.NullBool.Bool, true)) - throwFail(t, AssertIs(nd.NullString.Valid, true)) - throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) - throwFail(t, AssertIs(nd.NullInt64.Valid, true)) - throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) - throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) - throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) -} - -func TestRawValues(t *testing.T) { - Q := dDbBaser.TableQuote() - - var maps []Params - query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sStatus%s = ?", Q, Q, Q, Q, Q, Q) - num, err := dORM.Raw(query, 1).Values(&maps) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - if num == 1 { - throwFail(t, AssertIs(maps[0]["user_name"], "slene")) - } - - var lists []ParamsList - num, err = dORM.Raw(query, 1).ValuesList(&lists) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - if num == 1 { - throwFail(t, AssertIs(lists[0][0], "slene")) - } - - query = fmt.Sprintf("SELECT %sprofile_id%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q) - var list ParamsList - num, err = dORM.Raw(query).ValuesFlat(&list) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0], "2")) - throwFail(t, AssertIs(list[1], "3")) - throwFail(t, AssertIs(list[2], nil)) - } -} - -func TestRawPrepare(t *testing.T) { - switch { - case IsMysql || IsSqlite: - - pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() - throwFail(t, err) - if pre != nil { - r, err := pre.Exec("name1") - throwFail(t, err) - - tid, err := r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(tid > 0, true)) - - r, err = pre.Exec("name2") - throwFail(t, err) - - id, err := r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(id, tid+1)) - - r, err = pre.Exec("name3") - throwFail(t, err) - - id, err = r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(id, tid+2)) - - err = pre.Close() - throwFail(t, err) - - res, err := dORM.Raw("DELETE FROM tag WHERE name IN (?, ?, ?)", []string{"name1", "name2", "name3"}).Exec() - throwFail(t, err) - - num, err := res.RowsAffected() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - } - - case IsPostgres: - - pre, err := dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() - throwFail(t, err) - if pre != nil { - _, err := pre.Exec("name1") - throwFail(t, err) - - _, err = pre.Exec("name2") - throwFail(t, err) - - _, err = pre.Exec("name3") - throwFail(t, err) - - err = pre.Close() - throwFail(t, err) - - res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec() - throwFail(t, err) - - if err == nil { - num, err := res.RowsAffected() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - } - } - } -} - -func TestUpdate(t *testing.T) { - qs := dORM.QueryTable("user") - num, err := qs.Filter("user_name", "slene").Filter("is_staff", false).Update(Params{ - "is_staff": true, - "is_active": true, - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - // with join - num, err = qs.Filter("user_name", "slene").Filter("profile__age", 28).Filter("is_staff", true).Update(Params{ - "is_staff": false, - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColAdd, 100), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColMinus, 50), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColMultiply, 3), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColExcept, 5), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - user := User{UserName: "slene"} - err = dORM.Read(&user, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(user.Nums, 30)) -} - -func TestDelete(t *testing.T) { - qs := dORM.QueryTable("user_profile") - num, err := qs.Filter("user__user_name", "slene").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - qs = dORM.QueryTable("user") - num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - qs = dORM.QueryTable("comment") - num, err = qs.Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 6)) - - qs = dORM.QueryTable("post") - num, err = qs.Filter("Id", 3).Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - qs = dORM.QueryTable("comment") - num, err = qs.Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 4)) - - qs = dORM.QueryTable("comment") - num, err = qs.Filter("Post__User", 3).Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - qs = dORM.QueryTable("comment") - num, err = qs.Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestTransaction(t *testing.T) { - // this test worked when database support transaction - - o := NewOrm() - err := o.Begin() - throwFail(t, err) - - var names = []string{"1", "2", "3"} - - var tag Tag - tag.Name = names[0] - id, err := o.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - switch { - case IsMysql || IsSqlite: - res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() - throwFail(t, err) - if err == nil { - id, err = res.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - } - - err = o.Rollback() - throwFail(t, err) - - num, err = o.QueryTable("tag").Filter("name__in", names).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) - - err = o.Begin() - throwFail(t, err) - - tag.Name = "commit" - id, err = o.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - o.Commit() - throwFail(t, err) - - num, err = o.QueryTable("tag").Filter("name", "commit").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - -} - -func TestTransactionIsolationLevel(t *testing.T) { - // this test worked when database support transaction isolation level - if IsSqlite { - return - } - - o1 := NewOrm() - o2 := NewOrm() - - // start two transaction with isolation level repeatable read - err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - throwFail(t, err) - err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - throwFail(t, err) - - // o1 insert tag - var tag Tag - tag.Name = "test-transaction" - id, err := o1.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - // o2 query tag table, no result - num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) - - // o1 commit - o1.Commit() - - // o2 query tag table, still no result - num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) - - // o2 commit and query tag table, get the result - o2.Commit() - num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = o1.QueryTable("tag").Filter("name", "test-transaction").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestBeginTxWithContextCanceled(t *testing.T) { - o := NewOrm() - ctx, cancel := context.WithCancel(context.Background()) - o.BeginTx(ctx, nil) - id, err := o.Insert(&Tag{Name: "test-context"}) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - // cancel the context before commit to make it error - cancel() - err = o.Commit() - throwFail(t, AssertIs(err, context.Canceled)) -} - -func TestReadOrCreate(t *testing.T) { - u := &User{ - UserName: "Kyle", - Email: "kylemcc@gmail.com", - Password: "other_pass", - Status: 7, - IsStaff: false, - IsActive: true, - } - - created, pk, err := dORM.ReadOrCreate(u, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(created, true)) - throwFail(t, AssertIs(u.ID, pk)) - throwFail(t, AssertIs(u.UserName, "Kyle")) - throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com")) - throwFail(t, AssertIs(u.Password, "other_pass")) - throwFail(t, AssertIs(u.Status, 7)) - throwFail(t, AssertIs(u.IsStaff, false)) - throwFail(t, AssertIs(u.IsActive, true)) - throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime)) - - nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"} - created, pk, err = dORM.ReadOrCreate(nu, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(nu.ID, u.ID)) - throwFail(t, AssertIs(pk, u.ID)) - throwFail(t, AssertIs(nu.UserName, u.UserName)) - throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above - throwFail(t, AssertIs(nu.Password, u.Password)) - throwFail(t, AssertIs(nu.Status, u.Status)) - throwFail(t, AssertIs(nu.IsStaff, u.IsStaff)) - throwFail(t, AssertIs(nu.IsActive, u.IsActive)) - - dORM.Delete(u) -} - -func TestInLine(t *testing.T) { - name := "inline" - email := "hello@go.com" - inline := NewInLine() - inline.Name = name - inline.Email = email - - id, err := dORM.Insert(inline) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - il := NewInLine() - il.ID = 1 - err = dORM.Read(il) - throwFail(t, err) - - throwFail(t, AssertIs(il.Name, name)) - throwFail(t, AssertIs(il.Email, email)) - throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) -} - -func TestInLineOneToOne(t *testing.T) { - name := "121" - email := "121@go.com" - inline := NewInLine() - inline.Name = name - inline.Email = email - - id, err := dORM.Insert(inline) - throwFail(t, err) - throwFail(t, AssertIs(id, 2)) - - note := "one2one" - il121 := NewInLineOneToOne() - il121.Note = note - il121.InLine = inline - _, err = dORM.Insert(il121) - throwFail(t, err) - throwFail(t, AssertIs(il121.ID, 1)) - - il := NewInLineOneToOne() - err = dORM.QueryTable(il).Filter("Id", 1).RelatedSel().One(il) - - throwFail(t, err) - throwFail(t, AssertIs(il.Note, note)) - throwFail(t, AssertIs(il.InLine.ID, id)) - throwFail(t, AssertIs(il.InLine.Name, name)) - throwFail(t, AssertIs(il.InLine.Email, email)) - - rinline := NewInLine() - err = dORM.QueryTable(rinline).Filter("InLineOneToOne__Id", 1).One(rinline) - - throwFail(t, err) - throwFail(t, AssertIs(rinline.ID, id)) - throwFail(t, AssertIs(rinline.Name, name)) - throwFail(t, AssertIs(rinline.Email, email)) -} - -func TestIntegerPk(t *testing.T) { - its := []IntegerPk{ - {ID: math.MinInt64, Value: "-"}, - {ID: 0, Value: "0"}, - {ID: math.MaxInt64, Value: "+"}, - } - - num, err := dORM.InsertMulti(len(its), its) - throwFail(t, err) - throwFail(t, AssertIs(num, len(its))) - - for _, intPk := range its { - out := IntegerPk{ID: intPk.ID} - err = dORM.Read(&out) - throwFail(t, err) - throwFail(t, AssertIs(out.Value, intPk.Value)) - } - - num, err = dORM.InsertMulti(1, []*IntegerPk{{ - ID: 1, Value: "ok", - }}) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestInsertAuto(t *testing.T) { - u := &User{ - UserName: "autoPre", - Email: "autoPre@gmail.com", - } - - id, err := dORM.Insert(u) - throwFail(t, err) - - id += 100 - su := &User{ - ID: int(id), - UserName: "auto", - Email: "auto@gmail.com", - } - - nid, err := dORM.Insert(su) - throwFail(t, err) - throwFail(t, AssertIs(nid, id)) - - users := []User{ - {ID: int(id + 100), UserName: "auto_100"}, - {ID: int(id + 110), UserName: "auto_110"}, - {ID: int(id + 120), UserName: "auto_120"}, - } - num, err := dORM.InsertMulti(100, users) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - u = &User{ - UserName: "auto_121", - } - - nid, err = dORM.Insert(u) - throwFail(t, err) - throwFail(t, AssertIs(nid, id+120+1)) -} - -func TestUintPk(t *testing.T) { - name := "go" - u := &UintPk{ - ID: 8, - Name: name, - } - - created, _, err := dORM.ReadOrCreate(u, "ID") - throwFail(t, err) - throwFail(t, AssertIs(created, true)) - throwFail(t, AssertIs(u.Name, name)) - - nu := &UintPk{ID: 8} - created, pk, err := dORM.ReadOrCreate(nu, "ID") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(nu.ID, u.ID)) - throwFail(t, AssertIs(pk, u.ID)) - throwFail(t, AssertIs(nu.Name, name)) - - dORM.Delete(u) -} - -func TestPtrPk(t *testing.T) { - parent := &IntegerPk{ID: 10, Value: "10"} - - id, _ := dORM.Insert(parent) - if !IsMysql { - // MySql does not support last_insert_id in this case: see #2382 - throwFail(t, AssertIs(id, 10)) - } - - ptr := PtrPk{ID: parent, Positive: true} - num, err := dORM.InsertMulti(2, []PtrPk{ptr}) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertIs(ptr.ID, parent)) - - nptr := &PtrPk{ID: parent} - created, pk, err := dORM.ReadOrCreate(nptr, "ID") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(pk, 10)) - throwFail(t, AssertIs(nptr.ID, parent)) - throwFail(t, AssertIs(nptr.Positive, true)) - - nptr = &PtrPk{Positive: true} - created, pk, err = dORM.ReadOrCreate(nptr, "Positive") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(pk, 10)) - throwFail(t, AssertIs(nptr.ID, parent)) - - nptr.Positive = false - num, err = dORM.Update(nptr) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertIs(nptr.ID, parent)) - throwFail(t, AssertIs(nptr.Positive, false)) - - num, err = dORM.Delete(nptr) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestSnake(t *testing.T) { - cases := map[string]string{ - "i": "i", - "I": "i", - "iD": "i_d", - "ID": "i_d", - "NO": "n_o", - "NOO": "n_o_o", - "NOOooOOoo": "n_o_ooo_o_ooo", - "OrderNO": "order_n_o", - "tagName": "tag_name", - "tag_Name": "tag__name", - "tag_name": "tag_name", - "_tag_name": "_tag_name", - "tag_666name": "tag_666name", - "tag_666Name": "tag_666_name", - } - for name, want := range cases { - got := snakeString(name) - throwFail(t, AssertIs(got, want)) - } -} - -func TestIgnoreCaseTag(t *testing.T) { - type testTagModel struct { - ID int `orm:"pk"` - NOO string `orm:"column(n)"` - Name01 string `orm:"NULL"` - Name02 string `orm:"COLUMN(Name)"` - Name03 string `orm:"Column(name)"` - } - modelCache.clean() - RegisterModel(&testTagModel{}) - info, ok := modelCache.get("test_tag_model") - throwFail(t, AssertIs(ok, true)) - throwFail(t, AssertNot(info, nil)) - if t == nil { - return - } - throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n")) - throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true)) - throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) - throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) -} - -func TestInsertOrUpdate(t *testing.T) { - RegisterModel(new(User)) - user := User{UserName: "unique_username133", Status: 1, Password: "o"} - user1 := User{UserName: "unique_username133", Status: 2, Password: "o"} - user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"} - dORM.Insert(&user) - test := User{UserName: "unique_username133"} - fmt.Println(dORM.Driver().Name()) - if dORM.Driver().Name() == "sqlite3" { - fmt.Println("sqlite3 is nonsupport") - return - } - //test1 - _, err := dORM.InsertOrUpdate(&user1, "user_name") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user1.Status, test.Status)) - } - //test2 - _, err = dORM.InsertOrUpdate(&user2, "user_name") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user2.Status, test.Status)) - throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) - } - - //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values - if IsPostgres { - return - } - //test3 + - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user2.Status+1, test.Status)) - } - //test4 - - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status)) - } - //test5 * - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status)) - } - //test6 / - _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) - } -} diff --git a/orm/utils_test.go b/orm/utils_test.go deleted file mode 100644 index 7d94cada..00000000 --- a/orm/utils_test.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "testing" -) - -func TestCamelString(t *testing.T) { - snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} - camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} - - answer := make(map[string]string) - for i, v := range snake { - answer[v] = camel[i] - } - - for _, v := range snake { - res := camelString(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } - } -} - -func TestSnakeString(t *testing.T) { - camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} - snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} - - answer := make(map[string]string) - for i, v := range camel { - answer[v] = snake[i] - } - - for _, v := range camel { - res := snakeString(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } - } -} - -func TestSnakeStringWithAcronym(t *testing.T) { - camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} - snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} - - answer := make(map[string]string) - for i, v := range camel { - answer[v] = snake[i] - } - - for _, v := range camel { - res := snakeStringWithAcronym(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } - } -} diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index e3a635f2..79e926d3 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -293,7 +293,7 @@ type Post struct { Content string `orm:"type(text)"` Created time.Time `orm:"auto_now_add"` Updated time.Time `orm:"auto_now"` - Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.PostTags)"` + Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/pkg/orm.PostTags)"` } func (u *Post) TableIndex() [][]string { @@ -351,7 +351,7 @@ type Group struct { type Permission struct { ID int `orm:"column(id)"` Name string - Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"` + Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/pkg/orm.GroupPermissions)"` } type GroupPermissions struct { @@ -446,7 +446,7 @@ var ( usage: - go get -u github.com/astaxie/beego/orm + go get -u github.com/astaxie/beego/pkg/orm go get -u github.com/go-sql-driver/mysql go get -u github.com/mattn/go-sqlite3 go get -u github.com/lib/pq @@ -456,25 +456,25 @@ var ( mysql -u root -e 'create database orm_test;' export ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" - go test -v github.com/astaxie/beego/orm + go test -v github.com/astaxie/beego/pkg/orm #### Sqlite3 export ORM_DRIVER=sqlite3 export ORM_SOURCE='file:memory_test?mode=memory' - go test -v github.com/astaxie/beego/orm + go test -v github.com/astaxie/beego/pkg/orm #### PostgreSQL psql -c 'create database orm_test;' -U postgres export ORM_DRIVER=postgres export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - go test -v github.com/astaxie/beego/orm + go test -v github.com/astaxie/beego/pkg/orm #### TiDB export ORM_DRIVER=tidb export ORM_SOURCE='memory://test/test' - go test -v github.com/astaxie/beego/orm + go test -v github.com/astaxie/beego/pgk/orm ` ) @@ -487,7 +487,11 @@ func init() { os.Exit(2) } - RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) + + if err != nil{ + panic(fmt.Sprintf("can not register database: %v", err)) + } alias := getDbAlias("default") if alias.Driver == DRMySQL { diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 0551b1cd..bd2cb783 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -21,7 +21,7 @@ // // import ( // "fmt" -// "github.com/astaxie/beego/orm" +// "github.com/astaxie/beego/pkg/orm" // _ "github.com/go-sql-driver/mysql" // import your used driver // ) // diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index bdb430b6..eac7b33a 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -30,6 +30,8 @@ import ( "strings" "testing" "time" + + "github.com/stretchr/testify/assert" ) var _ = os.PathSeparator @@ -141,6 +143,7 @@ func getCaller(skip int) string { return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) } +// Deprecated: Using stretchr/testify/assert func throwFail(t *testing.T, err error, args ...interface{}) { if err != nil { con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) @@ -455,9 +458,11 @@ func TestNullDataTypes(t *testing.T) { throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) - throwFail(t, AssertIs((*d.TimePtr).UTC().Format(testTime), timePtr.UTC().Format(testTime))) - throwFail(t, AssertIs((*d.DatePtr).UTC().Format(testDate), datePtr.UTC().Format(testDate))) - throwFail(t, AssertIs((*d.DateTimePtr).UTC().Format(testDateTime), dateTimePtr.UTC().Format(testDateTime))) + + // in mysql, there are some precision problem, (*d.TimePtr).UTC() != timePtr.UTC() + assert.True(t, (*d.TimePtr).UTC().Sub(timePtr.UTC()) <= time.Second) + assert.True(t, (*d.DatePtr).UTC().Sub(datePtr.UTC()) <= time.Second) + assert.True(t, (*d.DateTimePtr).UTC().Sub(dateTimePtr.UTC()) <= time.Second) // test support for pointer fields using RawSeter.QueryRows() var dnList []*DataNull @@ -532,8 +537,9 @@ func TestCRUD(t *testing.T) { throwFail(t, AssertIs(u.Status, 3)) throwFail(t, AssertIs(u.IsStaff, true)) throwFail(t, AssertIs(u.IsActive, true)) - throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime)) + + assert.True(t, u.Created.In(DefaultTimeLoc).Sub(user.Created.In(DefaultTimeLoc)) <= time.Second) + assert.True(t, u.Updated.In(DefaultTimeLoc).Sub(user.Updated.In(DefaultTimeLoc)) <= time.Second) user.UserName = "astaxie" user.Profile = profile @@ -1793,7 +1799,7 @@ func TestQueryRows(t *testing.T) { throwFailNow(t, AssertIs(ids[2], 4)) throwFailNow(t, AssertIs(usernames[2], "nobody")) - //test query rows by nested struct + // test query rows by nested struct var l []userProfile query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q) num, err = dORM.Raw(query).QueryRows(&l) @@ -2413,7 +2419,7 @@ func TestInsertOrUpdate(t *testing.T) { fmt.Println("sqlite3 is nonsupport") return } - //test1 + // test1 _, err := dORM.InsertOrUpdate(&user1, "user_name") if err != nil { fmt.Println(err) @@ -2425,7 +2431,7 @@ func TestInsertOrUpdate(t *testing.T) { dORM.Read(&test, "user_name") throwFailNow(t, AssertIs(user1.Status, test.Status)) } - //test2 + // test2 _, err = dORM.InsertOrUpdate(&user2, "user_name") if err != nil { fmt.Println(err) @@ -2439,11 +2445,11 @@ func TestInsertOrUpdate(t *testing.T) { throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) } - //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values if IsPostgres { return } - //test3 + + // test3 + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") if err != nil { fmt.Println(err) @@ -2455,7 +2461,7 @@ func TestInsertOrUpdate(t *testing.T) { dORM.Read(&test, "user_name") throwFailNow(t, AssertIs(user2.Status+1, test.Status)) } - //test4 - + // test4 - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1") if err != nil { fmt.Println(err) @@ -2467,7 +2473,7 @@ func TestInsertOrUpdate(t *testing.T) { dORM.Read(&test, "user_name") throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status)) } - //test5 * + // test5 * _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3") if err != nil { fmt.Println(err) @@ -2479,7 +2485,7 @@ func TestInsertOrUpdate(t *testing.T) { dORM.Read(&test, "user_name") throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status)) } - //test6 / + // test6 / _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") if err != nil { fmt.Println(err) From 7258ef113a0daa8478dc86afd41859dc8e659239 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 19 Jul 2020 14:34:57 +0000 Subject: [PATCH 029/207] Store nearest error info --- toolbox/task.go | 12 +++++++++--- toolbox/task_test.go | 22 ++++++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/toolbox/task.go b/toolbox/task.go index d2a94ba9..c902fdfc 100644 --- a/toolbox/task.go +++ b/toolbox/task.go @@ -102,6 +102,8 @@ type taskerr struct { } // Task task struct +// It's not a thread-safe structure. +// Only nearest errors will be saved in ErrList type Task struct { Taskname string Spec *Schedule @@ -111,6 +113,7 @@ type Task struct { Next time.Time Errlist []*taskerr // like errtime:errinfo ErrLimit int // max length for the errlist, 0 stand for no limit + errCnt int // records the error count during the execution } // NewTask add new task with name, time and func @@ -119,8 +122,11 @@ func NewTask(tname string, spec string, f TaskFunc) *Task { task := &Task{ Taskname: tname, DoFunc: f, + // Make configurable ErrLimit: 100, SpecStr: spec, + // we only store the pointer, so it won't use too many space + Errlist: make([]*taskerr, 100, 100), } task.SetCron(spec) return task @@ -144,9 +150,9 @@ func (t *Task) GetStatus() string { func (t *Task) Run() error { err := t.DoFunc() if err != nil { - if t.ErrLimit > 0 && t.ErrLimit > len(t.Errlist) { - t.Errlist = append(t.Errlist, &taskerr{t: t.Next, errinfo: err.Error()}) - } + index := t.errCnt % t.ErrLimit + t.Errlist[index] = &taskerr{t: t.Next, errinfo: err.Error()} + t.errCnt++ } return err } diff --git a/toolbox/task_test.go b/toolbox/task_test.go index 596bc9c5..3a4cce2f 100644 --- a/toolbox/task_test.go +++ b/toolbox/task_test.go @@ -15,10 +15,13 @@ package toolbox import ( + "errors" "fmt" "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestParse(t *testing.T) { @@ -53,6 +56,25 @@ func TestSpec(t *testing.T) { } } +func TestTask_Run(t *testing.T) { + cnt := -1 + task := func() error { + cnt ++ + fmt.Printf("Hello, world! %d \n", cnt) + return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) + } + tk := NewTask("taska", "0/30 * * * * *", task) + for i := 0; i < 200 ; i ++ { + e := tk.Run() + assert.NotNil(t, e) + } + + l := tk.Errlist + assert.Equal(t, 100, len(l)) + assert.Equal(t, "Hello, world! 100", l[0].errinfo) + assert.Equal(t, "Hello, world! 101", l[1].errinfo) +} + func wait(wg *sync.WaitGroup) chan bool { ch := make(chan bool) go func() { From 32da446eb1d8785b70037d648ca74261cb20fd2f Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Sun, 19 Jul 2020 23:46:42 +0800 Subject: [PATCH 030/207] refactor orm --- pkg/orm/db_alias.go | 53 +++++++ pkg/orm/orm.go | 307 +++++++++++++++++++++++++--------------- pkg/orm/orm_object.go | 4 +- pkg/orm/orm_querym2m.go | 2 +- pkg/orm/orm_queryset.go | 4 +- pkg/orm/orm_raw.go | 4 +- pkg/orm/orm_test.go | 34 ++--- pkg/orm/types.go | 162 +++++++++++++-------- 8 files changed, 370 insertions(+), 200 deletions(-) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index bf6c350c..b2a72f56 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -111,6 +111,9 @@ type DB struct { stmtDecorators *lru.Cache } +var _ dbQuerier = new(DB) +var _ txer = new(DB) + func (d *DB) Begin() (*sql.Tx, error) { return d.DB.Begin() } @@ -220,6 +223,56 @@ func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interfac return stmt.QueryRowContext(ctx, args) } +type TxDB struct { + tx *sql.Tx +} + +var _ dbQuerier = new(TxDB) +var _ txEnder = new(TxDB) + +func (t *TxDB) Commit() error { + return t.tx.Commit() +} + +func (t *TxDB) Rollback() error { + return t.tx.Rollback() +} + +var _ dbQuerier = new(TxDB) +var _ txEnder = new(TxDB) + +func (t *TxDB) Prepare(query string) (*sql.Stmt, error) { + return t.PrepareContext(context.Background(),query) +} + +func (t *TxDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return t.tx.PrepareContext(ctx, query) +} + +func (t *TxDB) Exec(query string, args ...interface{}) (sql.Result, error) { + return t.ExecContext(context.Background(), query, args...) +} + +func (t *TxDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return t.tx.ExecContext(ctx, query, args...) +} + +func (t *TxDB) Query(query string, args ...interface{}) (*sql.Rows, error) { + return t.QueryContext(context.Background(),query,args...) +} + +func (t *TxDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return t.tx.QueryContext(ctx, query, args...) +} + +func (t *TxDB) QueryRow(query string, args ...interface{}) *sql.Row { + return t.QueryRowContext(context.Background(),query,args...) +} + +func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + return t.tx.QueryRowContext(ctx, query, args...) +} + type alias struct { Name string Driver DriverType diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index bd2cb783..3db75751 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -62,6 +62,8 @@ import ( "reflect" "sync" "time" + + "github.com/astaxie/beego/logs" ) // DebugQueries define the debug @@ -76,8 +78,7 @@ var ( DefaultRowsLimit = -1 DefaultRelsDepth = 2 DefaultTimeLoc = time.Local - ErrTxHasBegan = errors.New(" transaction already begin") - ErrTxDone = errors.New(" transaction not begin") + ErrTxDone = errors.New(" transaction already done") ErrMultiRows = errors.New(" return multi rows") ErrNoRows = errors.New(" no row found") ErrStmtClosed = errors.New(" stmt already closed") @@ -91,16 +92,16 @@ type Params map[string]interface{} // ParamsList stores paramslist type ParamsList []interface{} -type orm struct { +type ormBase struct { alias *alias db dbQuerier - isTx bool } -var _ Ormer = new(orm) +var _ DQL = new(ormBase) +var _ DML = new(ormBase) // get model info and model reflect value -func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { +func (o *ormBase) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { val := reflect.ValueOf(md) ind = reflect.Indirect(val) typ := ind.Type() @@ -115,7 +116,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect } // get field info from model info by given field name -func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { +func (o *ormBase) getFieldInfo(mi *modelInfo, name string) *fieldInfo { fi, ok := mi.fields.GetByAny(name) if !ok { panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.fullName)) @@ -124,33 +125,42 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { } // read data to model -func (o *orm) Read(md interface{}, cols ...string) error { +func (o *ormBase) Read(md interface{}, cols ...string) error { + return o.ReadWithCtx(context.Background(), md, cols...) +} +func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) } // read data to model, like Read(), but use "SELECT FOR UPDATE" form -func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { +func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error { + return o.ReadForUpdateWithCtx(context.Background(), md, cols...) +} +func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) } // Try to read a row from the database, or insert one if it doesn't exist -func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { +func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + return o.ReadOrCreateWithCtx(context.Background(), md, col1, cols...) +} +func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { cols = append([]string{col1}, cols...) mi, ind := o.getMiInd(md, true) err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) if err == ErrNoRows { // Create - id, err := o.Insert(md) - return (err == nil), id, err + id, err := o.InsertWithCtx(ctx, md) + return err == nil, id, err } id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { id = int64(vid.Uint()) } else if mi.fields.pk.rel { - return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) + return o.ReadOrCreateWithCtx(ctx, vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) } else { id = vid.Int() } @@ -159,7 +169,10 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i } // insert model data to database -func (o *orm) Insert(md interface{}) (int64, error) { +func (o *ormBase) Insert(md interface{}) (int64, error) { + return o.InsertWithCtx(context.Background(), md) +} +func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { mi, ind := o.getMiInd(md, true) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) if err != nil { @@ -172,7 +185,7 @@ func (o *orm) Insert(md interface{}) (int64, error) { } // set auto pk field -func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { +func (o *ormBase) setPk(mi *modelInfo, ind reflect.Value, id int64) { if mi.fields.pk.auto { if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) @@ -183,7 +196,10 @@ func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { } // insert some models to database -func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { +func (o *ormBase) InsertMulti(bulk int, mds interface{}) (int64, error) { + return o.InsertMultiWithCtx(context.Background(), bulk, mds) +} +func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { var cnt int64 sind := reflect.Indirect(reflect.ValueOf(mds)) @@ -218,7 +234,10 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { } // InsertOrUpdate data to database -func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { +func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) (int64, error) { + return o.InsertOrUpdateWithCtx(context.Background(), md, colConflictAndArgs...) +} +func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { mi, ind := o.getMiInd(md, true) id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) if err != nil { @@ -232,14 +251,20 @@ func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64 // update model to database. // cols set the columns those want to update. -func (o *orm) Update(md interface{}, cols ...string) (int64, error) { +func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { + return o.UpdateWithCtx(context.Background(), md, cols...) +} +func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) } // delete model in database // cols shows the delete conditions values read from. default is pk -func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { +func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) { + return o.DeleteWithCtx(context.Background(), md, cols...) +} +func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) if err != nil { @@ -252,7 +277,10 @@ func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { } // create a models to models queryer -func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { +func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { + return o.QueryM2MWithCtx(context.Background(), md, name) +} +func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -274,7 +302,10 @@ func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { // for _,tag := range post.Tags{...} // // make sure the relation is defined in model struct tags. -func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { +func (o *ormBase) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { + return o.LoadRelatedWithCtx(context.Background(), md, name, args...) +} +func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { _, fi, ind, qseter := o.queryRelated(md, name) qs := qseter.(*querySet) @@ -341,14 +372,17 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int // qs := orm.QueryRelated(post,"Tag") // qs.All(&[]*Tag{}) // -func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { +func (o *ormBase) QueryRelated(md interface{}, name string) QuerySeter { + return o.QueryRelatedWithCtx(context.Background(), md, name) +} +func (o *ormBase) QueryRelatedWithCtx(ctx context.Context, md interface{}, name string) QuerySeter { // is this api needed ? _, _, _, qs := o.queryRelated(md, name) return qs } // get QuerySeter for related models to md model -func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { +func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -380,7 +414,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, } // get reverse relation QuerySeter -func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { +func (o *ormBase) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { switch fi.fieldType { case RelReverseOne, RelReverseMany: default: @@ -401,7 +435,7 @@ func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS } // get relation QuerySeter -func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { +func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { switch fi.fieldType { case RelOneToOne, RelForeignKey, RelManyToMany: default: @@ -423,7 +457,10 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { // return a QuerySeter for table operations. // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), -func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { +func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { + return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName) +} +func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { var name string if table, ok := ptrStructOrTableName.(string); ok { name = nameStrategyMap[defaultNameStrategy](table) @@ -442,11 +479,136 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { return } -// switch to another registered database driver by given name. -func (o *orm) Using(name string) error { - if o.isTx { - panic(fmt.Errorf(" transaction has been start, cannot change db")) +// return a raw query seter for raw sql string. +func (o *ormBase) Raw(query string, args ...interface{}) RawSeter { + return o.RawWithCtx(context.Background(), query, args...) +} +func (o *ormBase) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter { + return newRawSet(o, query, args) +} + +// return current using database Driver +func (o *ormBase) Driver() Driver { + return driver(o.alias.Name) +} + +// return sql.DBStats for current database +func (o *ormBase) DBStats() *sql.DBStats { + if o.alias != nil && o.alias.DB != nil { + stats := o.alias.DB.DB.Stats() + return &stats } + return nil +} + +type orm struct { + ormBase +} + +var _ Ormer = new(orm) + +func (o *orm) Begin() (TxOrmer, error) { + return o.BeginWithCtx(context.Background()) +} + +func (o *orm) BeginWithCtx(ctx context.Context) (TxOrmer, error) { + return o.BeginWithCtxAndOpts(ctx, nil) +} + +func (o *orm) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) { + return o.BeginWithCtxAndOpts(context.Background(), opts) +} + +func (o *orm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) { + tx, err := o.db.(txer).BeginTx(ctx, opts) + if err != nil { + return nil, err + } + + _txOrm := &txOrm{ + ormBase: ormBase{ + alias: o.alias, + db: &TxDB{tx: tx}, + }, + isClosed: false, + } + + var taskTxOrm TxOrmer = _txOrm + return taskTxOrm, nil +} + +func (o *orm) DoTx(task func(txOrm TxOrmer) error) error { + return o.DoTxWithCtx(context.Background(), task) +} + +func (o *orm) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error { + return o.DoTxWithCtxAndOpts(ctx, nil, task) +} + +func (o *orm) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { + return o.DoTxWithCtxAndOpts(context.Background(), opts, task) +} + +func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { + _txOrm, err := o.BeginWithCtxAndOpts(ctx, opts) + if err != nil { + return err + } + panicked := true + defer func() { + if panicked || err != nil { + e := _txOrm.Rollback() + logs.Error("rollback transaction failed: %v", e) + } else { + e := _txOrm.Commit() + logs.Error("commit transaction failed: %v", e) + } + }() + + var taskTxOrm = _txOrm + err = task(taskTxOrm) + panicked = false + return err +} + +type txOrm struct { + ormBase + isClosed bool + closeMutex sync.Mutex +} + +var _ TxOrmer = new(txOrm) + +func (t *txOrm) Commit() error { + t.closeMutex.Lock() + defer t.closeMutex.Unlock() + + if t.isClosed { + return ErrTxDone + } + t.isClosed = true + + return t.db.(txEnder).Commit() +} + +func (t *txOrm) Rollback() error { + t.closeMutex.Lock() + defer t.closeMutex.Unlock() + + if t.isClosed { + return ErrTxDone + } + t.isClosed = true + + return t.db.(txEnder).Rollback() +} + +// NewOrm create new orm +func NewOrm() Ormer { + BootStrap() // execute only once + + o := new(orm) + name := `default` if al, ok := dataBaseCache.get(name); ok { o.alias = al if Debug { @@ -455,92 +617,9 @@ func (o *orm) Using(name string) error { o.db = al.DB } } else { - return fmt.Errorf(" unknown db alias name `%s`", name) + panic(fmt.Errorf(" unknown db alias name `%s`", name)) } - return nil -} -// begin transaction -func (o *orm) Begin() error { - return o.BeginTx(context.Background(), nil) -} - -func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error { - if o.isTx { - return ErrTxHasBegan - } - var tx *sql.Tx - tx, err := o.db.(txer).BeginTx(ctx, opts) - if err != nil { - return err - } - o.isTx = true - if Debug { - o.db.(*dbQueryLog).SetDB(tx) - } else { - o.db = tx - } - return nil -} - -// commit transaction -func (o *orm) Commit() error { - if !o.isTx { - return ErrTxDone - } - err := o.db.(txEnder).Commit() - if err == nil { - o.isTx = false - o.Using(o.alias.Name) - } else if err == sql.ErrTxDone { - return ErrTxDone - } - return err -} - -// rollback transaction -func (o *orm) Rollback() error { - if !o.isTx { - return ErrTxDone - } - err := o.db.(txEnder).Rollback() - if err == nil { - o.isTx = false - o.Using(o.alias.Name) - } else if err == sql.ErrTxDone { - return ErrTxDone - } - return err -} - -// return a raw query seter for raw sql string. -func (o *orm) Raw(query string, args ...interface{}) RawSeter { - return newRawSet(o, query, args) -} - -// return current using database Driver -func (o *orm) Driver() Driver { - return driver(o.alias.Name) -} - -// return sql.DBStats for current database -func (o *orm) DBStats() *sql.DBStats { - if o.alias != nil && o.alias.DB != nil { - stats := o.alias.DB.DB.Stats() - return &stats - } - return nil -} - -// NewOrm create new orm -func NewOrm() Ormer { - BootStrap() // execute only once - - o := new(orm) - err := o.Using("default") - if err != nil { - panic(err) - } return o } diff --git a/pkg/orm/orm_object.go b/pkg/orm/orm_object.go index de3181ce..6f9798d3 100644 --- a/pkg/orm/orm_object.go +++ b/pkg/orm/orm_object.go @@ -22,7 +22,7 @@ import ( // an insert queryer struct type insertSet struct { mi *modelInfo - orm *orm + orm *ormBase stmt stmtQuerier closed bool } @@ -70,7 +70,7 @@ func (o *insertSet) Close() error { } // create new insert queryer. -func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { +func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) { bi := new(insertSet) bi.orm = orm bi.mi = mi diff --git a/pkg/orm/orm_querym2m.go b/pkg/orm/orm_querym2m.go index 6a270a0d..17e1b5d1 100644 --- a/pkg/orm/orm_querym2m.go +++ b/pkg/orm/orm_querym2m.go @@ -129,7 +129,7 @@ func (o *queryM2M) Count() (int64, error) { var _ QueryM2Mer = new(queryM2M) // create new M2M queryer. -func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { +func newQueryM2M(md interface{}, o *ormBase, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { qm2m := new(queryM2M) qm2m.md = md qm2m.mi = mi diff --git a/pkg/orm/orm_queryset.go b/pkg/orm/orm_queryset.go index 878b836b..83168de7 100644 --- a/pkg/orm/orm_queryset.go +++ b/pkg/orm/orm_queryset.go @@ -72,7 +72,7 @@ type querySet struct { orders []string distinct bool forupdate bool - orm *orm + orm *ormBase ctx context.Context forContext bool } @@ -292,7 +292,7 @@ func (o querySet) WithContext(ctx context.Context) QuerySeter { } // create new QuerySeter. -func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { +func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { o := new(querySet) o.mi = mi o.orm = orm diff --git a/pkg/orm/orm_raw.go b/pkg/orm/orm_raw.go index 3325a7ea..5e05eded 100644 --- a/pkg/orm/orm_raw.go +++ b/pkg/orm/orm_raw.go @@ -63,7 +63,7 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) { type rawSet struct { query string args []interface{} - orm *orm + orm *ormBase } var _ RawSeter = new(rawSet) @@ -858,7 +858,7 @@ func (o *rawSet) Prepare() (RawPreparer, error) { return newRawPreparer(o) } -func newRawSet(orm *orm, query string, args []interface{}) RawSeter { +func newRawSet(orm *ormBase, query string, args []interface{}) RawSeter { o := new(rawSet) o.query = query o.args = args diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index eac7b33a..b7b2d9a7 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -2026,24 +2026,24 @@ func TestTransaction(t *testing.T) { // this test worked when database support transaction o := NewOrm() - err := o.Begin() + to, err := o.Begin() throwFail(t, err) var names = []string{"1", "2", "3"} var tag Tag tag.Name = names[0] - id, err := o.Insert(&tag) + id, err := to.Insert(&tag) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) - num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) + num, err := to.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) throwFail(t, err) throwFail(t, AssertIs(num, 1)) switch { case IsMysql || IsSqlite: - res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() + res, err := to.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() throwFail(t, err) if err == nil { id, err = res.LastInsertId() @@ -2052,22 +2052,22 @@ func TestTransaction(t *testing.T) { } } - err = o.Rollback() + err = to.Rollback() throwFail(t, err) num, err = o.QueryTable("tag").Filter("name__in", names).Count() throwFail(t, err) throwFail(t, AssertIs(num, 0)) - err = o.Begin() + to, err = o.Begin() throwFail(t, err) tag.Name = "commit" - id, err = o.Insert(&tag) + id, err = to.Insert(&tag) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) - o.Commit() + to.Commit() throwFail(t, err) num, err = o.QueryTable("tag").Filter("name", "commit").Delete() @@ -2086,15 +2086,15 @@ func TestTransactionIsolationLevel(t *testing.T) { o2 := NewOrm() // start two transaction with isolation level repeatable read - err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + to1, err := o1.BeginWithCtxAndOpts(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) throwFail(t, err) - err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + to2, err := o2.BeginWithCtxAndOpts(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) throwFail(t, err) // o1 insert tag var tag Tag tag.Name = "test-transaction" - id, err := o1.Insert(&tag) + id, err := to1.Insert(&tag) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) @@ -2104,15 +2104,15 @@ func TestTransactionIsolationLevel(t *testing.T) { throwFail(t, AssertIs(num, 0)) // o1 commit - o1.Commit() + to1.Commit() // o2 query tag table, still no result - num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + num, err = to2.QueryTable("tag").Filter("name", "test-transaction").Count() throwFail(t, err) throwFail(t, AssertIs(num, 0)) // o2 commit and query tag table, get the result - o2.Commit() + to2.Commit() num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) @@ -2125,14 +2125,14 @@ func TestTransactionIsolationLevel(t *testing.T) { func TestBeginTxWithContextCanceled(t *testing.T) { o := NewOrm() ctx, cancel := context.WithCancel(context.Background()) - o.BeginTx(ctx, nil) - id, err := o.Insert(&Tag{Name: "test-context"}) + to, _ := o.BeginWithCtx(ctx) + id, err := to.Insert(&Tag{Name: "test-context"}) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) // cancel the context before commit to make it error cancel() - err = o.Commit() + err = to.Commit() throwFail(t, AssertIs(err, context.Canceled)) } diff --git a/pkg/orm/types.go b/pkg/orm/types.go index 2fd10774..b7a38826 100644 --- a/pkg/orm/types.go +++ b/pkg/orm/types.go @@ -35,35 +35,43 @@ type Fielder interface { RawValue() interface{} } -// Ormer define the orm interface -type Ormer interface { - // read data to model - // for example: - // this will find User by Id field - // u = &User{Id: user.Id} - // err = Ormer.Read(u) - // this will find User by UserName field - // u = &User{UserName: "astaxie", Password: "pass"} - // err = Ormer.Read(u, "UserName") - Read(md interface{}, cols ...string) error - // Like Read(), but with "FOR UPDATE" clause, useful in transaction. - // Some databases are not support this feature. - ReadForUpdate(md interface{}, cols ...string) error - // Try to read a row from the database, or insert one if it doesn't exist - ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) +type TxBeginner interface { + //self control transaction + Begin() (TxOrmer, error) + BeginWithCtx(ctx context.Context) (TxOrmer, error) + BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) + BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) + + //closure control transaction + DoTx(task func(txOrm TxOrmer) error) error + DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error + DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error + DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error +} + +type TxCommitter interface { + Commit() error + Rollback() error +} + +//Data Manipulation Language +type DML interface { // insert model data to database // for example: // user := new(User) // id, err = Ormer.Insert(user) // user must be a pointer and Insert will set user's pk field - Insert(interface{}) (int64, error) + Insert(md interface{}) (int64, error) + InsertWithCtx(ctx context.Context, md interface{}) (int64, error) // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") // if colu type is integer : can use(+-*/), string : convert(colu,"value") // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") // if colu type is integer : can use(+-*/), string : colu || "value" InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) + InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) // insert some models to database InsertMulti(bulk int, mds interface{}) (int64, error) + InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) // update model to database. // cols set the columns those want to update. // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns @@ -74,63 +82,93 @@ type Ormer interface { // user.Extra.Data = "orm" // num, err = Ormer.Update(&user, "Langs", "Extra") Update(md interface{}, cols ...string) (int64, error) + UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) // delete model in database Delete(md interface{}, cols ...string) (int64, error) + DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) + + // return a raw query seter for raw sql string. + // for example: + // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() + // // update user testing's name to slene + Raw(query string, args ...interface{}) RawSeter + RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter +} + +// Data Query Language +type DQL interface { + // read data to model + // for example: + // this will find User by Id field + // u = &User{Id: user.Id} + // err = Ormer.Read(u) + // this will find User by UserName field + // u = &User{UserName: "astaxie", Password: "pass"} + // err = Ormer.Read(u, "UserName") + Read(md interface{}, cols ...string) error + ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error + + // Like Read(), but with "FOR UPDATE" clause, useful in transaction. + // Some databases are not support this feature. + ReadForUpdate( md interface{}, cols ...string) error + ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error + + // Try to read a row from the database, or insert one if it doesn't exist + ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) + ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) + // load related models to md model. // args are limit, offset int and order string. // // example: // Ormer.LoadRelated(post,"Tags") // for _,tag := range post.Tags{...} - //args[0] bool true useDefaultRelsDepth ; false depth 0 - //args[0] int loadRelationDepth - //args[1] int limit default limit 1000 - //args[2] int offset default offset 0 - //args[3] string order for example : "-Id" + // args[0] bool true useDefaultRelsDepth ; false depth 0 + // args[0] int loadRelationDepth + // args[1] int limit default limit 1000 + // args[2] int offset default offset 0 + // args[3] string order for example : "-Id" // make sure the relation is defined in model struct tags. - LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) + LoadRelated( md interface{}, name string, args ...interface{}) (int64, error) + LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) + // create a models to models queryer // for example: // post := Post{Id: 4} // m2m := Ormer.QueryM2M(&post, "Tags") - QueryM2M(md interface{}, name string) QueryM2Mer + QueryM2M( md interface{}, name string) QueryM2Mer + QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer + // return a QuerySeter for table operations. // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), QueryTable(ptrStructOrTableName interface{}) QuerySeter + QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter + // switch to another registered database driver by given name. - Using(name string) error - // begin transaction - // for example: - // o := NewOrm() - // err := o.Begin() - // ... - // err = o.Rollback() - Begin() error - // begin transaction with provided context and option - // the provided context is used until the transaction is committed or rolled back. - // if the context is canceled, the transaction will be rolled back. - // the provided TxOptions is optional and may be nil if defaults should be used. - // if a non-default isolation level is used that the driver doesn't support, an error will be returned. - // for example: - // o := NewOrm() - // err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - // ... - // err = o.Rollback() - BeginTx(ctx context.Context, opts *sql.TxOptions) error - // commit transaction - Commit() error - // rollback transaction - Rollback() error - // return a raw query seter for raw sql string. - // for example: - // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() - // // update user testing's name to slene - Raw(query string, args ...interface{}) RawSeter - Driver() Driver + // Using(name string) error + DBStats() *sql.DBStats } +type DriverGetter interface { + Driver() Driver +} + +type Ormer interface { + DQL + DML + DriverGetter + TxBeginner +} + +type TxOrmer interface { + DQL + DML + DriverGetter + TxCommitter +} + // Inserter insert prepared statement type Inserter interface { Insert(interface{}) (int64, error) @@ -229,7 +267,7 @@ type QuerySeter interface { // }) // user slene's name will change to slene2 Update(values Params) (int64, error) // delete from table - //for example: + // for example: // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() // //delete two user who's name is testing1 or testing2 Delete() (int64, error) @@ -314,8 +352,8 @@ type QueryM2Mer interface { // remove models following the origin model relationship // only delete rows from m2m table // for example: - //tag3 := &Tag{Id:5,Name: "TestTag3"} - //num, err = m2m.Remove(tag3) + // tag3 := &Tag{Id:5,Name: "TestTag3"} + // num, err = m2m.Remove(tag3) Remove(...interface{}) (int64, error) // check model is existed in relationship of origin model Exist(interface{}) bool @@ -337,10 +375,10 @@ type RawPreparer interface { // sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) // rs := Ormer.Raw(sql, 1) type RawSeter interface { - //execute sql and get result + // execute sql and get result Exec() (sql.Result, error) - //query data and map to container - //for example: + // query data and map to container + // for example: // var name string // var id int // rs.QueryRow(&id,&name) // id==2 name=="slene" @@ -396,11 +434,11 @@ type RawSeter interface { type stmtQuerier interface { Close() error Exec(args ...interface{}) (sql.Result, error) - //ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) + // ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) Query(args ...interface{}) (*sql.Rows, error) - //QueryContext(args ...interface{}) (*sql.Rows, error) + // QueryContext(args ...interface{}) (*sql.Rows, error) QueryRow(args ...interface{}) *sql.Row - //QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row + // QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row } // db querier From aefe21b63af934b346960b4a99d560936e2b0b49 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 20 Jul 2020 17:25:27 +0800 Subject: [PATCH 031/207] complete error log --- pkg/orm/orm.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 3db75751..24632270 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -558,10 +558,14 @@ func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task defer func() { if panicked || err != nil { e := _txOrm.Rollback() - logs.Error("rollback transaction failed: %v", e) + if e != nil { + logs.Error("rollback transaction failed: %v,%v", e, panicked) + } } else { e := _txOrm.Commit() - logs.Error("commit transaction failed: %v", e) + if e != nil { + logs.Error("commit transaction failed: %v,%v", e, panicked) + } } }() From 4aad313de7fbf4da9fd74e89d1e722f2702a29b2 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 20 Jul 2020 17:34:58 +0800 Subject: [PATCH 032/207] do not judge tx status in txOrm --- pkg/orm/orm.go | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 24632270..8ef761f4 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -530,7 +530,6 @@ func (o *orm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxO alias: o.alias, db: &TxDB{tx: tx}, }, - isClosed: false, } var taskTxOrm TxOrmer = _txOrm @@ -577,33 +576,15 @@ func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task type txOrm struct { ormBase - isClosed bool - closeMutex sync.Mutex } var _ TxOrmer = new(txOrm) func (t *txOrm) Commit() error { - t.closeMutex.Lock() - defer t.closeMutex.Unlock() - - if t.isClosed { - return ErrTxDone - } - t.isClosed = true - return t.db.(txEnder).Commit() } func (t *txOrm) Rollback() error { - t.closeMutex.Lock() - defer t.closeMutex.Unlock() - - if t.isClosed { - return ErrTxDone - } - t.isClosed = true - return t.db.(txEnder).Rollback() } From b6f7d30f9f6192d6046e6b9db0f6a4fd261a829e Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 20 Jul 2020 19:10:57 +0800 Subject: [PATCH 033/207] fix unit test --- pkg/orm/orm_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index b7b2d9a7..54ecc0fd 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -2099,7 +2099,7 @@ func TestTransactionIsolationLevel(t *testing.T) { throwFail(t, AssertIs(id > 0, true)) // o2 query tag table, no result - num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count() + num, err := to2.QueryTable("tag").Filter("name", "test-transaction").Count() throwFail(t, err) throwFail(t, AssertIs(num, 0)) From 44460bc4570b58cefd7b4b1c65e8a1610ceefcbc Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 20 Jul 2020 15:23:17 +0000 Subject: [PATCH 034/207] Refactor RegisterDatabase --- orm/db_alias.go | 87 ++++++++----------------------------- orm/orm_alias_adapt_test.go | 46 ++++++++++++++++++++ pkg/common/kv.go | 69 +++++++++++++++++++++++++++++ pkg/common/kv_test.go | 40 +++++++++++++++++ pkg/orm/constant.go | 21 +++++++++ pkg/orm/db_alias.go | 64 ++++++++++++++------------- pkg/orm/db_alias_test.go | 44 +++++++++++++++++++ pkg/orm/models_test.go | 7 ++- 8 files changed, 279 insertions(+), 99 deletions(-) create mode 100644 orm/orm_alias_adapt_test.go create mode 100644 pkg/common/kv.go create mode 100644 pkg/common/kv_test.go create mode 100644 pkg/orm/constant.go create mode 100644 pkg/orm/db_alias_test.go diff --git a/orm/db_alias.go b/orm/db_alias.go index bf6c350c..a84070b4 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -12,16 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Deprecated: we will remove this package, please using pkg/orm package orm import ( "context" "database/sql" "fmt" - lru "github.com/hashicorp/golang-lru" "reflect" "sync" "time" + + lru "github.com/hashicorp/golang-lru" + + "github.com/astaxie/beego/pkg/common" + orm2 "github.com/astaxie/beego/pkg/orm" ) // DriverType database driver constant int. @@ -63,7 +68,7 @@ var ( "tidb": DRTiDB, "oracle": DROracle, "oci8": DROracle, // github.com/mattn/go-oci8 - "ora": DROracle, //https://github.com/rana/ora + "ora": DROracle, // https://github.com/rana/ora } dbBasers = map[DriverType]dbBaser{ DRMySQL: newdbBaseMysql(), @@ -119,7 +124,7 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) return d.DB.BeginTx(ctx, opts) } -//su must call release to release *sql.Stmt after using +// su must call release to release *sql.Stmt after using func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.RLock() c, ok := d.stmtDecorators.Get(query) @@ -289,82 +294,26 @@ func detectTZ(al *alias) { } } -func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { - al := new(alias) - al.Name = aliasName - al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), - } - - if dr, ok := drivers[driverName]; ok { - al.DbBaser = dbBasers[dr] - al.Driver = dr - } else { - return nil, fmt.Errorf("driver name `%s` have not registered", driverName) - } - - err := db.Ping() - if err != nil { - return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) - } - - if !dataBaseCache.add(aliasName, al) { - return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) - } - - return al, nil -} - // AddAliasWthDB add a aliasName for the drivename +// Deprecated: please using pkg/orm func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { - _, err := addAliasWthDB(aliasName, driverName, db) - return err + return orm2.AddAliasWthDB(aliasName, driverName, db) } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { - var ( - err error - db *sql.DB - al *alias - ) - - db, err = sql.Open(driverName, dataSource) - if err != nil { - err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) - goto end - } - - al, err = addAliasWthDB(aliasName, driverName, db) - if err != nil { - goto end - } - - al.DataSource = dataSource - - detectTZ(al) - + kvs := make([]common.KV, 0, 2) for i, v := range params { switch i { case 0: - SetMaxIdleConns(al.Name, v) + kvs = append(kvs, common.KV{Key: orm2.MaxIdleConnsKey, Value: v}) case 1: - SetMaxOpenConns(al.Name, v) + kvs = append(kvs, common.KV{Key: orm2.MaxOpenConnsKey, Value: v}) + case 2: + kvs = append(kvs, common.KV{Key: orm2.ConnMaxLifetimeKey, Value: time.Duration(v) * time.Millisecond}) } } - -end: - if err != nil { - if db != nil { - db.Close() - } - DebugLog.Println(err.Error()) - } - - return err + return orm2.RegisterDataBase(aliasName, driverName, dataSource, kvs...) } // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. @@ -424,7 +373,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } type stmtDecorator struct { - wg sync.WaitGroup + wg sync.WaitGroup stmt *sql.Stmt } @@ -444,7 +393,7 @@ func (s *stmtDecorator) release() { s.wg.Done() } -//garbage recycle for stmt +// garbage recycle for stmt func (s *stmtDecorator) destroy() { go func() { s.wg.Wait() diff --git a/orm/orm_alias_adapt_test.go b/orm/orm_alias_adapt_test.go new file mode 100644 index 00000000..d7724527 --- /dev/null +++ b/orm/orm_alias_adapt_test.go @@ -0,0 +1,46 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "os" + "testing" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" +) + +var DBARGS = struct { + Driver string + Source string + Debug string +}{ + os.Getenv("ORM_DRIVER"), + os.Getenv("ORM_SOURCE"), + os.Getenv("ORM_DEBUG"), +} + +func TestRegisterDataBase(t *testing.T) { + err := RegisterDataBase("test-adapt1", DBARGS.Driver, DBARGS.Source) + assert.Nil(t, err) + err = RegisterDataBase("test-adapt2", DBARGS.Driver, DBARGS.Source, 20) + assert.Nil(t, err) + err = RegisterDataBase("test-adapt3", DBARGS.Driver, DBARGS.Source, 20, 300) + assert.Nil(t, err) + err = RegisterDataBase("test-adapt4", DBARGS.Driver, DBARGS.Source, 20, 300, 60*1000) + assert.Nil(t, err) +} diff --git a/pkg/common/kv.go b/pkg/common/kv.go new file mode 100644 index 00000000..508e6b5c --- /dev/null +++ b/pkg/common/kv.go @@ -0,0 +1,69 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +// KV is common structure to store key-value data. +// when you need something like Pair, you can use this +type KV struct { + Key interface{} + Value interface{} +} + +// KVs will store KV collection as map +type KVs struct { + kvs map[interface{}]interface{} +} + +// GetValueOr check whether this contains the key, +// if the key not found, the default value will be return +func (kvs *KVs) GetValueOr(key interface{}, defValue interface{}) interface{} { + v, ok := kvs.kvs[key] + if ok { + return v + } + return defValue +} + +// Contains will check whether contains the key +func (kvs *KVs) Contains(key interface{}) bool { + _, ok := kvs.kvs[key] + return ok +} + +// IfContains is a functional API that if the key is in KVs, the action will be invoked +func (kvs *KVs) IfContains(key interface{}, action func(value interface{})) *KVs { + v, ok := kvs.kvs[key] + if ok { + action(v) + } + return kvs +} + +// Put store the value +func (kvs *KVs) Put(key interface{}, value interface{}) *KVs { + kvs.kvs[key] = value + return kvs +} + +// NewKVs will create the *KVs instance +func NewKVs(kvs ...KV) *KVs { + res := &KVs{ + kvs: make(map[interface{}]interface{}, len(kvs)), + } + for _, kv := range kvs { + res.kvs[kv.Key] = kv.Value + } + return res +} diff --git a/pkg/common/kv_test.go b/pkg/common/kv_test.go new file mode 100644 index 00000000..ed7dc7ef --- /dev/null +++ b/pkg/common/kv_test.go @@ -0,0 +1,40 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestKVs(t *testing.T) { + key := "my-key" + kvs := NewKVs(KV{ + Key: key, + Value: 12, + }) + + assert.True(t, kvs.Contains(key)) + + kvs.IfContains(key, func(value interface{}) { + kvs.Put("my-key1", "") + }) + + assert.True(t, kvs.Contains("my-key1")) + + v := kvs.GetValueOr(key, 13) + assert.Equal(t, 12, v) +} diff --git a/pkg/orm/constant.go b/pkg/orm/constant.go new file mode 100644 index 00000000..14f40a7b --- /dev/null +++ b/pkg/orm/constant.go @@ -0,0 +1,21 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +const ( + MaxIdleConnsKey = "MaxIdleConns" + MaxOpenConnsKey = "MaxOpenConns" + ConnMaxLifetimeKey = "ConnMaxLifetime" +) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index b2a72f56..90c5de3c 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -18,10 +18,12 @@ import ( "context" "database/sql" "fmt" - lru "github.com/hashicorp/golang-lru" - "reflect" "sync" "time" + + lru "github.com/hashicorp/golang-lru" + + "github.com/astaxie/beego/pkg/common" ) // DriverType database driver constant int. @@ -63,7 +65,7 @@ var ( "tidb": DRTiDB, "oracle": DROracle, "oci8": DROracle, // github.com/mattn/go-oci8 - "ora": DROracle, //https://github.com/rana/ora + "ora": DROracle, // https://github.com/rana/ora } dbBasers = map[DriverType]dbBaser{ DRMySQL: newdbBaseMysql(), @@ -122,7 +124,7 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) return d.DB.BeginTx(ctx, opts) } -//su must call release to release *sql.Stmt after using +// su must call release to release *sql.Stmt after using func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.RLock() c, ok := d.stmtDecorators.Get(query) @@ -274,16 +276,17 @@ func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interf } type alias struct { - Name string - Driver DriverType - DriverName string - DataSource string - MaxIdleConns int - MaxOpenConns int - DB *DB - DbBaser dbBaser - TZ *time.Location - Engine string + Name string + Driver DriverType + DriverName string + DataSource string + MaxIdleConns int + MaxOpenConns int + ConnMaxLifetime time.Duration + DB *DB + DbBaser dbBaser + TZ *time.Location + Engine string } func detectTZ(al *alias) { @@ -378,13 +381,15 @@ func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { +func RegisterDataBase(aliasName, driverName, dataSource string, params ...common.KV) error { var ( err error db *sql.DB al *alias ) + kvs := common.NewKVs(params...) + db, err = sql.Open(driverName, dataSource) if err != nil { err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) @@ -400,14 +405,13 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) e detectTZ(al) - for i, v := range params { - switch i { - case 0: - SetMaxIdleConns(al.Name, v) - case 1: - SetMaxOpenConns(al.Name, v) - } - } + kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { + SetMaxIdleConns(al.Name, value.(int)) + }).IfContains(MaxOpenConnsKey, func(value interface{}) { + SetMaxOpenConns(al.Name, value.(int)) + }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { + SetConnMaxLifetime(al.Name, value.(time.Duration)) + }) end: if err != nil { @@ -454,10 +458,12 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) { al := getDbAlias(aliasName) al.MaxOpenConns = maxOpenConns al.DB.DB.SetMaxOpenConns(maxOpenConns) - // for tip go 1.2 - if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() { - fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) - } +} + +func SetConnMaxLifetime(aliasName string, lifeTime time.Duration) { + al := getDbAlias(aliasName) + al.ConnMaxLifetime = lifeTime + al.DB.DB.SetConnMaxLifetime(lifeTime) } // GetDB Get *sql.DB from registered database by db alias name. @@ -477,7 +483,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } type stmtDecorator struct { - wg sync.WaitGroup + wg sync.WaitGroup stmt *sql.Stmt } @@ -497,7 +503,7 @@ func (s *stmtDecorator) release() { s.wg.Done() } -//garbage recycle for stmt +// garbage recycle for stmt func (s *stmtDecorator) destroy() { go func() { s.wg.Wait() diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go new file mode 100644 index 00000000..a0cdcd44 --- /dev/null +++ b/pkg/orm/db_alias_test.go @@ -0,0 +1,44 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/common" +) + +func TestRegisterDataBase(t *testing.T) { + err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxIdleConnsKey, + Value: 20, + }, common.KV{ + Key: MaxOpenConnsKey, + Value: 300, + }, common.KV{ + Key: ConnMaxLifetimeKey, + Value: time.Minute, + }) + assert.Nil(t, err) + + al := getDbAlias("test-params") + assert.NotNil(t, al) + assert.Equal(t, al.MaxIdleConns, 20) + assert.Equal(t, al.MaxOpenConns, 300) + assert.Equal(t, al.ConnMaxLifetime, time.Minute) +} diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index 79e926d3..f14ee9cf 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -27,6 +27,8 @@ import ( _ "github.com/mattn/go-sqlite3" // As tidb can't use go get, so disable the tidb testing now // _ "github.com/pingcap/tidb" + + "github.com/astaxie/beego/pkg/common" ) // A slice string field. @@ -487,7 +489,10 @@ func init() { os.Exit(2) } - err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, common.KV{ + Key:MaxIdleConnsKey, + Value:20, + }) if err != nil{ panic(fmt.Sprintf("can not register database: %v", err)) From a66b9950e7e1db8d64140f5fa4a6559b04d3a207 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Mon, 20 Jul 2020 21:21:59 +0100 Subject: [PATCH 035/207] Add Content-length field for logging --- router.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router.go b/router.go index a993a1af..6a8ac6f7 100644 --- a/router.go +++ b/router.go @@ -1046,7 +1046,7 @@ func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { HTTPReferrer: r.Header.Get("Referer"), HTTPUserAgent: r.Header.Get("User-Agent"), RemoteUser: r.Header.Get("Remote-User"), - BodyBytesSent: 0, // @todo this one is missing! + BodyBytesSent: r.ContentLength, } logs.AccessLog(record, BConfig.Log.AccessLogsFormat) } From 9c51952db485cb32a6658df173f622f6930199cd Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 22 Jul 2020 22:50:08 +0800 Subject: [PATCH 036/207] Move package --- orm/cmd_utils.go | 10 +- orm/db_alias.go | 87 +- orm/orm_alias_adapt_test.go | 46 - pkg/admin.go | 458 +++++++ pkg/admin_test.go | 239 ++++ pkg/adminui.go | 356 ++++++ pkg/app.go | 496 ++++++++ pkg/beego.go | 123 ++ pkg/build_info.go | 27 + pkg/cache/README.md | 59 + pkg/cache/cache.go | 103 ++ pkg/cache/cache_test.go | 191 +++ pkg/cache/conv.go | 100 ++ pkg/cache/conv_test.go | 143 +++ pkg/cache/file.go | 258 ++++ pkg/cache/memcache/memcache.go | 188 +++ pkg/cache/memcache/memcache_test.go | 108 ++ pkg/cache/memory.go | 256 ++++ pkg/cache/redis/redis.go | 272 +++++ pkg/cache/redis/redis_test.go | 144 +++ pkg/cache/ssdb/ssdb.go | 231 ++++ pkg/cache/ssdb/ssdb_test.go | 104 ++ pkg/config.go | 524 ++++++++ pkg/config/config.go | 242 ++++ pkg/config/config_test.go | 55 + pkg/config/env/env.go | 87 ++ pkg/config/env/env_test.go | 75 ++ pkg/config/fake.go | 134 +++ pkg/config/ini.go | 504 ++++++++ pkg/config/ini_test.go | 190 +++ pkg/config/json.go | 269 +++++ pkg/config/json_test.go | 222 ++++ pkg/config/xml/xml.go | 228 ++++ pkg/config/xml/xml_test.go | 125 ++ pkg/config/yaml/yaml.go | 316 +++++ pkg/config/yaml/yaml_test.go | 115 ++ pkg/config_test.go | 146 +++ pkg/context/acceptencoder.go | 232 ++++ pkg/context/acceptencoder_test.go | 59 + pkg/context/context.go | 263 +++++ pkg/context/context_test.go | 47 + pkg/context/input.go | 689 +++++++++++ pkg/context/input_test.go | 217 ++++ pkg/context/output.go | 408 +++++++ pkg/context/param/conv.go | 78 ++ pkg/context/param/methodparams.go | 69 ++ pkg/context/param/options.go | 37 + pkg/context/param/parsers.go | 149 +++ pkg/context/param/parsers_test.go | 84 ++ pkg/context/renderer.go | 12 + pkg/context/response.go | 27 + pkg/controller.go | 706 +++++++++++ pkg/controller_test.go | 181 +++ pkg/doc.go | 17 + pkg/error.go | 488 ++++++++ pkg/error_test.go | 88 ++ pkg/filter.go | 44 + pkg/filter_test.go | 68 ++ pkg/flash.go | 110 ++ pkg/flash_test.go | 54 + pkg/fs.go | 74 ++ pkg/grace/grace.go | 166 +++ pkg/grace/server.go | 356 ++++++ pkg/hooks.go | 104 ++ pkg/httplib/README.md | 97 ++ pkg/httplib/httplib.go | 654 ++++++++++ pkg/httplib/httplib_test.go | 286 +++++ pkg/log.go | 127 ++ pkg/logs/README.md | 72 ++ pkg/logs/accesslog.go | 83 ++ pkg/logs/alils/alils.go | 186 +++ pkg/logs/alils/config.go | 13 + pkg/logs/alils/log.pb.go | 1038 ++++++++++++++++ pkg/logs/alils/log_config.go | 42 + pkg/logs/alils/log_project.go | 819 +++++++++++++ pkg/logs/alils/log_store.go | 271 +++++ pkg/logs/alils/machine_group.go | 91 ++ pkg/logs/alils/request.go | 62 + pkg/logs/alils/signature.go | 111 ++ pkg/logs/conn.go | 119 ++ pkg/logs/conn_test.go | 79 ++ pkg/logs/console.go | 99 ++ pkg/logs/console_test.go | 64 + pkg/logs/es/es.go | 102 ++ pkg/logs/file.go | 409 +++++++ pkg/logs/file_test.go | 420 +++++++ pkg/logs/jianliao.go | 72 ++ pkg/logs/log.go | 669 +++++++++++ pkg/logs/logger.go | 176 +++ pkg/logs/logger_test.go | 57 + pkg/logs/multifile.go | 119 ++ pkg/logs/multifile_test.go | 78 ++ pkg/logs/slack.go | 60 + pkg/logs/smtp.go | 149 +++ pkg/logs/smtp_test.go | 27 + pkg/metric/prometheus.go | 99 ++ pkg/metric/prometheus_test.go | 42 + pkg/migration/ddl.go | 395 +++++++ pkg/migration/doc.go | 32 + pkg/migration/migration.go | 330 ++++++ pkg/mime.go | 556 +++++++++ pkg/namespace.go | 396 +++++++ pkg/namespace_test.go | 168 +++ pkg/parser.go | 591 +++++++++ pkg/plugins/apiauth/apiauth.go | 165 +++ pkg/plugins/apiauth/apiauth_test.go | 20 + pkg/plugins/auth/basic.go | 107 ++ pkg/plugins/authz/authz.go | 86 ++ pkg/plugins/authz/authz_model.conf | 14 + pkg/plugins/authz/authz_policy.csv | 7 + pkg/plugins/authz/authz_test.go | 107 ++ pkg/plugins/cors/cors.go | 228 ++++ pkg/plugins/cors/cors_test.go | 253 ++++ pkg/policy.go | 97 ++ pkg/router.go | 1052 +++++++++++++++++ pkg/router_test.go | 732 ++++++++++++ pkg/session/README.md | 114 ++ pkg/session/couchbase/sess_couchbase.go | 247 ++++ pkg/session/ledis/ledis_session.go | 173 +++ pkg/session/memcache/sess_memcache.go | 230 ++++ pkg/session/mysql/sess_mysql.go | 228 ++++ pkg/session/postgres/sess_postgresql.go | 243 ++++ pkg/session/redis/sess_redis.go | 261 ++++ pkg/session/redis_cluster/redis_cluster.go | 220 ++++ .../redis_sentinel/sess_redis_sentinel.go | 234 ++++ .../sess_redis_sentinel_test.go | 90 ++ pkg/session/sess_cookie.go | 180 +++ pkg/session/sess_cookie_test.go | 105 ++ pkg/session/sess_file.go | 315 +++++ pkg/session/sess_file_test.go | 387 ++++++ pkg/session/sess_mem.go | 196 +++ pkg/session/sess_mem_test.go | 58 + pkg/session/sess_test.go | 131 ++ pkg/session/sess_utils.go | 207 ++++ pkg/session/session.go | 377 ++++++ pkg/session/ssdb/sess_ssdb.go | 199 ++++ pkg/staticfile.go | 234 ++++ pkg/staticfile_test.go | 99 ++ pkg/swagger/swagger.go | 174 +++ pkg/template.go | 406 +++++++ pkg/template_test.go | 316 +++++ pkg/templatefunc.go | 780 ++++++++++++ pkg/templatefunc_test.go | 380 ++++++ pkg/testdata/Makefile | 2 + pkg/testdata/bindata.go | 296 +++++ pkg/testdata/views/blocks/block.tpl | 3 + pkg/testdata/views/header.tpl | 3 + pkg/testdata/views/index.tpl | 15 + pkg/testing/assertions.go | 15 + pkg/testing/client.go | 65 + pkg/toolbox/healthcheck.go | 48 + pkg/toolbox/profile.go | 184 +++ pkg/toolbox/profile_test.go | 28 + pkg/toolbox/statistics.go | 149 +++ pkg/toolbox/statistics_test.go | 40 + pkg/toolbox/task.go | 640 ++++++++++ pkg/toolbox/task_test.go | 85 ++ pkg/tree.go | 585 +++++++++ pkg/tree_test.go | 306 +++++ pkg/unregroute_test.go | 226 ++++ pkg/utils/caller.go | 25 + pkg/utils/caller_test.go | 28 + pkg/utils/captcha/LICENSE | 19 + pkg/utils/captcha/README.md | 45 + pkg/utils/captcha/captcha.go | 270 +++++ pkg/utils/captcha/image.go | 501 ++++++++ pkg/utils/captcha/image_test.go | 52 + pkg/utils/captcha/siprng.go | 277 +++++ pkg/utils/captcha/siprng_test.go | 33 + pkg/utils/debug.go | 478 ++++++++ pkg/utils/debug_test.go | 46 + pkg/utils/file.go | 101 ++ pkg/utils/file_test.go | 75 ++ pkg/utils/mail.go | 424 +++++++ pkg/utils/mail_test.go | 41 + pkg/utils/pagination/controller.go | 26 + pkg/utils/pagination/doc.go | 58 + pkg/utils/pagination/paginator.go | 189 +++ pkg/utils/pagination/utils.go | 34 + pkg/utils/rand.go | 44 + pkg/utils/rand_test.go | 33 + pkg/utils/safemap.go | 91 ++ pkg/utils/safemap_test.go | 89 ++ pkg/utils/slice.go | 170 +++ pkg/utils/slice_test.go | 29 + pkg/utils/testdata/grepe.test | 7 + pkg/utils/utils.go | 89 ++ pkg/utils/utils_test.go | 36 + pkg/validation/README.md | 147 +++ pkg/validation/util.go | 298 +++++ pkg/validation/util_test.go | 128 ++ pkg/validation/validation.go | 456 +++++++ pkg/validation/validation_test.go | 609 ++++++++++ pkg/validation/validators.go | 738 ++++++++++++ 194 files changed, 39077 insertions(+), 69 deletions(-) delete mode 100644 orm/orm_alias_adapt_test.go create mode 100644 pkg/admin.go create mode 100644 pkg/admin_test.go create mode 100644 pkg/adminui.go create mode 100644 pkg/app.go create mode 100644 pkg/beego.go create mode 100644 pkg/build_info.go create mode 100644 pkg/cache/README.md create mode 100644 pkg/cache/cache.go create mode 100644 pkg/cache/cache_test.go create mode 100644 pkg/cache/conv.go create mode 100644 pkg/cache/conv_test.go create mode 100644 pkg/cache/file.go create mode 100644 pkg/cache/memcache/memcache.go create mode 100644 pkg/cache/memcache/memcache_test.go create mode 100644 pkg/cache/memory.go create mode 100644 pkg/cache/redis/redis.go create mode 100644 pkg/cache/redis/redis_test.go create mode 100644 pkg/cache/ssdb/ssdb.go create mode 100644 pkg/cache/ssdb/ssdb_test.go create mode 100644 pkg/config.go create mode 100644 pkg/config/config.go create mode 100644 pkg/config/config_test.go create mode 100644 pkg/config/env/env.go create mode 100644 pkg/config/env/env_test.go create mode 100644 pkg/config/fake.go create mode 100644 pkg/config/ini.go create mode 100644 pkg/config/ini_test.go create mode 100644 pkg/config/json.go create mode 100644 pkg/config/json_test.go create mode 100644 pkg/config/xml/xml.go create mode 100644 pkg/config/xml/xml_test.go create mode 100644 pkg/config/yaml/yaml.go create mode 100644 pkg/config/yaml/yaml_test.go create mode 100644 pkg/config_test.go create mode 100644 pkg/context/acceptencoder.go create mode 100644 pkg/context/acceptencoder_test.go create mode 100644 pkg/context/context.go create mode 100644 pkg/context/context_test.go create mode 100644 pkg/context/input.go create mode 100644 pkg/context/input_test.go create mode 100644 pkg/context/output.go create mode 100644 pkg/context/param/conv.go create mode 100644 pkg/context/param/methodparams.go create mode 100644 pkg/context/param/options.go create mode 100644 pkg/context/param/parsers.go create mode 100644 pkg/context/param/parsers_test.go create mode 100644 pkg/context/renderer.go create mode 100644 pkg/context/response.go create mode 100644 pkg/controller.go create mode 100644 pkg/controller_test.go create mode 100644 pkg/doc.go create mode 100644 pkg/error.go create mode 100644 pkg/error_test.go create mode 100644 pkg/filter.go create mode 100644 pkg/filter_test.go create mode 100644 pkg/flash.go create mode 100644 pkg/flash_test.go create mode 100644 pkg/fs.go create mode 100644 pkg/grace/grace.go create mode 100644 pkg/grace/server.go create mode 100644 pkg/hooks.go create mode 100644 pkg/httplib/README.md create mode 100644 pkg/httplib/httplib.go create mode 100644 pkg/httplib/httplib_test.go create mode 100644 pkg/log.go create mode 100644 pkg/logs/README.md create mode 100644 pkg/logs/accesslog.go create mode 100644 pkg/logs/alils/alils.go create mode 100755 pkg/logs/alils/config.go create mode 100755 pkg/logs/alils/log.pb.go create mode 100755 pkg/logs/alils/log_config.go create mode 100755 pkg/logs/alils/log_project.go create mode 100755 pkg/logs/alils/log_store.go create mode 100755 pkg/logs/alils/machine_group.go create mode 100755 pkg/logs/alils/request.go create mode 100755 pkg/logs/alils/signature.go create mode 100644 pkg/logs/conn.go create mode 100644 pkg/logs/conn_test.go create mode 100644 pkg/logs/console.go create mode 100644 pkg/logs/console_test.go create mode 100644 pkg/logs/es/es.go create mode 100644 pkg/logs/file.go create mode 100644 pkg/logs/file_test.go create mode 100644 pkg/logs/jianliao.go create mode 100644 pkg/logs/log.go create mode 100644 pkg/logs/logger.go create mode 100644 pkg/logs/logger_test.go create mode 100644 pkg/logs/multifile.go create mode 100644 pkg/logs/multifile_test.go create mode 100644 pkg/logs/slack.go create mode 100644 pkg/logs/smtp.go create mode 100644 pkg/logs/smtp_test.go create mode 100644 pkg/metric/prometheus.go create mode 100644 pkg/metric/prometheus_test.go create mode 100644 pkg/migration/ddl.go create mode 100644 pkg/migration/doc.go create mode 100644 pkg/migration/migration.go create mode 100644 pkg/mime.go create mode 100644 pkg/namespace.go create mode 100644 pkg/namespace_test.go create mode 100644 pkg/parser.go create mode 100644 pkg/plugins/apiauth/apiauth.go create mode 100644 pkg/plugins/apiauth/apiauth_test.go create mode 100644 pkg/plugins/auth/basic.go create mode 100644 pkg/plugins/authz/authz.go create mode 100644 pkg/plugins/authz/authz_model.conf create mode 100644 pkg/plugins/authz/authz_policy.csv create mode 100644 pkg/plugins/authz/authz_test.go create mode 100644 pkg/plugins/cors/cors.go create mode 100644 pkg/plugins/cors/cors_test.go create mode 100644 pkg/policy.go create mode 100644 pkg/router.go create mode 100644 pkg/router_test.go create mode 100644 pkg/session/README.md create mode 100644 pkg/session/couchbase/sess_couchbase.go create mode 100644 pkg/session/ledis/ledis_session.go create mode 100644 pkg/session/memcache/sess_memcache.go create mode 100644 pkg/session/mysql/sess_mysql.go create mode 100644 pkg/session/postgres/sess_postgresql.go create mode 100644 pkg/session/redis/sess_redis.go create mode 100644 pkg/session/redis_cluster/redis_cluster.go create mode 100644 pkg/session/redis_sentinel/sess_redis_sentinel.go create mode 100644 pkg/session/redis_sentinel/sess_redis_sentinel_test.go create mode 100644 pkg/session/sess_cookie.go create mode 100644 pkg/session/sess_cookie_test.go create mode 100644 pkg/session/sess_file.go create mode 100644 pkg/session/sess_file_test.go create mode 100644 pkg/session/sess_mem.go create mode 100644 pkg/session/sess_mem_test.go create mode 100644 pkg/session/sess_test.go create mode 100644 pkg/session/sess_utils.go create mode 100644 pkg/session/session.go create mode 100644 pkg/session/ssdb/sess_ssdb.go create mode 100644 pkg/staticfile.go create mode 100644 pkg/staticfile_test.go create mode 100644 pkg/swagger/swagger.go create mode 100644 pkg/template.go create mode 100644 pkg/template_test.go create mode 100644 pkg/templatefunc.go create mode 100644 pkg/templatefunc_test.go create mode 100644 pkg/testdata/Makefile create mode 100644 pkg/testdata/bindata.go create mode 100644 pkg/testdata/views/blocks/block.tpl create mode 100644 pkg/testdata/views/header.tpl create mode 100644 pkg/testdata/views/index.tpl create mode 100644 pkg/testing/assertions.go create mode 100644 pkg/testing/client.go create mode 100644 pkg/toolbox/healthcheck.go create mode 100644 pkg/toolbox/profile.go create mode 100644 pkg/toolbox/profile_test.go create mode 100644 pkg/toolbox/statistics.go create mode 100644 pkg/toolbox/statistics_test.go create mode 100644 pkg/toolbox/task.go create mode 100644 pkg/toolbox/task_test.go create mode 100644 pkg/tree.go create mode 100644 pkg/tree_test.go create mode 100644 pkg/unregroute_test.go create mode 100644 pkg/utils/caller.go create mode 100644 pkg/utils/caller_test.go create mode 100644 pkg/utils/captcha/LICENSE create mode 100644 pkg/utils/captcha/README.md create mode 100644 pkg/utils/captcha/captcha.go create mode 100644 pkg/utils/captcha/image.go create mode 100644 pkg/utils/captcha/image_test.go create mode 100644 pkg/utils/captcha/siprng.go create mode 100644 pkg/utils/captcha/siprng_test.go create mode 100644 pkg/utils/debug.go create mode 100644 pkg/utils/debug_test.go create mode 100644 pkg/utils/file.go create mode 100644 pkg/utils/file_test.go create mode 100644 pkg/utils/mail.go create mode 100644 pkg/utils/mail_test.go create mode 100644 pkg/utils/pagination/controller.go create mode 100644 pkg/utils/pagination/doc.go create mode 100644 pkg/utils/pagination/paginator.go create mode 100644 pkg/utils/pagination/utils.go create mode 100644 pkg/utils/rand.go create mode 100644 pkg/utils/rand_test.go create mode 100644 pkg/utils/safemap.go create mode 100644 pkg/utils/safemap_test.go create mode 100644 pkg/utils/slice.go create mode 100644 pkg/utils/slice_test.go create mode 100644 pkg/utils/testdata/grepe.test create mode 100644 pkg/utils/utils.go create mode 100644 pkg/utils/utils_test.go create mode 100644 pkg/validation/README.md create mode 100644 pkg/validation/util.go create mode 100644 pkg/validation/util_test.go create mode 100644 pkg/validation/validation.go create mode 100644 pkg/validation/validation_test.go create mode 100644 pkg/validation/validators.go diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index eac85091..61f17346 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -178,9 +178,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex column += " " + "NOT NULL" } - // if fi.initial.String() != "" { + //if fi.initial.String() != "" { // column += " DEFAULT " + fi.initial.String() - // } + //} // Append attribute DEFAULT column += getColumnDefault(fi) @@ -197,9 +197,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex if strings.Contains(column, "%COL%") { column = strings.Replace(column, "%COL%", fi.column, -1) } - - if fi.description != "" && al.Driver != DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) + + if fi.description != "" && al.Driver!=DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) } columns = append(columns, column) diff --git a/orm/db_alias.go b/orm/db_alias.go index a84070b4..bf6c350c 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -12,21 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Deprecated: we will remove this package, please using pkg/orm package orm import ( "context" "database/sql" "fmt" + lru "github.com/hashicorp/golang-lru" "reflect" "sync" "time" - - lru "github.com/hashicorp/golang-lru" - - "github.com/astaxie/beego/pkg/common" - orm2 "github.com/astaxie/beego/pkg/orm" ) // DriverType database driver constant int. @@ -68,7 +63,7 @@ var ( "tidb": DRTiDB, "oracle": DROracle, "oci8": DROracle, // github.com/mattn/go-oci8 - "ora": DROracle, // https://github.com/rana/ora + "ora": DROracle, //https://github.com/rana/ora } dbBasers = map[DriverType]dbBaser{ DRMySQL: newdbBaseMysql(), @@ -124,7 +119,7 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) return d.DB.BeginTx(ctx, opts) } -// su must call release to release *sql.Stmt after using +//su must call release to release *sql.Stmt after using func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.RLock() c, ok := d.stmtDecorators.Get(query) @@ -294,26 +289,82 @@ func detectTZ(al *alias) { } } +func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { + al := new(alias) + al.Name = aliasName + al.DriverName = driverName + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: newStmtDecoratorLruWithEvict(), + } + + if dr, ok := drivers[driverName]; ok { + al.DbBaser = dbBasers[dr] + al.Driver = dr + } else { + return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + } + + err := db.Ping() + if err != nil { + return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) + } + + if !dataBaseCache.add(aliasName, al) { + return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) + } + + return al, nil +} + // AddAliasWthDB add a aliasName for the drivename -// Deprecated: please using pkg/orm func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { - return orm2.AddAliasWthDB(aliasName, driverName, db) + _, err := addAliasWthDB(aliasName, driverName, db) + return err } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { - kvs := make([]common.KV, 0, 2) + var ( + err error + db *sql.DB + al *alias + ) + + db, err = sql.Open(driverName, dataSource) + if err != nil { + err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) + goto end + } + + al, err = addAliasWthDB(aliasName, driverName, db) + if err != nil { + goto end + } + + al.DataSource = dataSource + + detectTZ(al) + for i, v := range params { switch i { case 0: - kvs = append(kvs, common.KV{Key: orm2.MaxIdleConnsKey, Value: v}) + SetMaxIdleConns(al.Name, v) case 1: - kvs = append(kvs, common.KV{Key: orm2.MaxOpenConnsKey, Value: v}) - case 2: - kvs = append(kvs, common.KV{Key: orm2.ConnMaxLifetimeKey, Value: time.Duration(v) * time.Millisecond}) + SetMaxOpenConns(al.Name, v) } } - return orm2.RegisterDataBase(aliasName, driverName, dataSource, kvs...) + +end: + if err != nil { + if db != nil { + db.Close() + } + DebugLog.Println(err.Error()) + } + + return err } // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. @@ -373,7 +424,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } type stmtDecorator struct { - wg sync.WaitGroup + wg sync.WaitGroup stmt *sql.Stmt } @@ -393,7 +444,7 @@ func (s *stmtDecorator) release() { s.wg.Done() } -// garbage recycle for stmt +//garbage recycle for stmt func (s *stmtDecorator) destroy() { go func() { s.wg.Wait() diff --git a/orm/orm_alias_adapt_test.go b/orm/orm_alias_adapt_test.go deleted file mode 100644 index d7724527..00000000 --- a/orm/orm_alias_adapt_test.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2020 beego-dev -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "os" - "testing" - - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - "github.com/stretchr/testify/assert" -) - -var DBARGS = struct { - Driver string - Source string - Debug string -}{ - os.Getenv("ORM_DRIVER"), - os.Getenv("ORM_SOURCE"), - os.Getenv("ORM_DEBUG"), -} - -func TestRegisterDataBase(t *testing.T) { - err := RegisterDataBase("test-adapt1", DBARGS.Driver, DBARGS.Source) - assert.Nil(t, err) - err = RegisterDataBase("test-adapt2", DBARGS.Driver, DBARGS.Source, 20) - assert.Nil(t, err) - err = RegisterDataBase("test-adapt3", DBARGS.Driver, DBARGS.Source, 20, 300) - assert.Nil(t, err) - err = RegisterDataBase("test-adapt4", DBARGS.Driver, DBARGS.Source, 20, 300, 60*1000) - assert.Nil(t, err) -} diff --git a/pkg/admin.go b/pkg/admin.go new file mode 100644 index 00000000..db52647e --- /dev/null +++ b/pkg/admin.go @@ -0,0 +1,458 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "os" + "reflect" + "strconv" + "text/template" + "time" + + "github.com/prometheus/client_golang/prometheus/promhttp" + + "github.com/astaxie/beego/grace" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/toolbox" + "github.com/astaxie/beego/utils" +) + +// BeeAdminApp is the default adminApp used by admin module. +var beeAdminApp *adminApp + +// FilterMonitorFunc is default monitor filter when admin module is enable. +// if this func returns, admin module records qps for this request by condition of this function logic. +// usage: +// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool { +// if method == "POST" { +// return false +// } +// if t.Nanoseconds() < 100 { +// return false +// } +// if strings.HasPrefix(requestPath, "/astaxie") { +// return false +// } +// return true +// } +// beego.FilterMonitorFunc = MyFilterMonitor. +var FilterMonitorFunc func(string, string, time.Duration, string, int) bool + +func init() { + beeAdminApp = &adminApp{ + routers: make(map[string]http.HandlerFunc), + } + // keep in mind that all data should be html escaped to avoid XSS attack + beeAdminApp.Route("/", adminIndex) + beeAdminApp.Route("/qps", qpsIndex) + beeAdminApp.Route("/prof", profIndex) + beeAdminApp.Route("/healthcheck", healthcheck) + beeAdminApp.Route("/task", taskStatus) + beeAdminApp.Route("/listconf", listConf) + beeAdminApp.Route("/metrics", promhttp.Handler().ServeHTTP) + FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true } +} + +// AdminIndex is the default http.Handler for admin module. +// it matches url pattern "/". +func adminIndex(rw http.ResponseWriter, _ *http.Request) { + writeTemplate(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) +} + +// QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter. +// it's registered with url pattern "/qps" in admin module. +func qpsIndex(rw http.ResponseWriter, _ *http.Request) { + data := make(map[interface{}]interface{}) + data["Content"] = toolbox.StatisticsMap.GetMap() + + // do html escape before display path, avoid xss + if content, ok := (data["Content"]).(M); ok { + if resultLists, ok := (content["Data"]).([][]string); ok { + for i := range resultLists { + if len(resultLists[i]) > 0 { + resultLists[i][0] = template.HTMLEscapeString(resultLists[i][0]) + } + } + } + } + + writeTemplate(rw, data, qpsTpl, defaultScriptsTpl) +} + +// ListConf is the http.Handler of displaying all beego configuration values as key/value pair. +// it's registered with url pattern "/listconf" in admin module. +func listConf(rw http.ResponseWriter, r *http.Request) { + r.ParseForm() + command := r.Form.Get("command") + if command == "" { + rw.Write([]byte("command not support")) + return + } + + data := make(map[interface{}]interface{}) + switch command { + case "conf": + m := make(M) + list("BConfig", BConfig, m) + m["AppConfigPath"] = template.HTMLEscapeString(appConfigPath) + m["AppConfigProvider"] = template.HTMLEscapeString(appConfigProvider) + tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) + tmpl = template.Must(tmpl.Parse(configTpl)) + tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) + + data["Content"] = m + + tmpl.Execute(rw, data) + + case "router": + content := PrintTree() + content["Fields"] = []string{ + "Router Pattern", + "Methods", + "Controller", + } + data["Content"] = content + data["Title"] = "Routers" + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) + case "filter": + var ( + content = M{ + "Fields": []string{ + "Router Pattern", + "Filter Function", + }, + } + filterTypes = []string{} + filterTypeData = make(M) + ) + + if BeeApp.Handlers.enableFilter { + var filterType string + for k, fr := range map[int]string{ + BeforeStatic: "Before Static", + BeforeRouter: "Before Router", + BeforeExec: "Before Exec", + AfterExec: "After Exec", + FinishRouter: "Finish Router"} { + if bf := BeeApp.Handlers.filters[k]; len(bf) > 0 { + filterType = fr + filterTypes = append(filterTypes, filterType) + resultList := new([][]string) + for _, f := range bf { + var result = []string{ + // void xss + template.HTMLEscapeString(f.pattern), + template.HTMLEscapeString(utils.GetFuncName(f.filterFunc)), + } + *resultList = append(*resultList, result) + } + filterTypeData[filterType] = resultList + } + } + } + + content["Data"] = filterTypeData + content["Methods"] = filterTypes + + data["Content"] = content + data["Title"] = "Filters" + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) + default: + rw.Write([]byte("command not support")) + } +} + +func list(root string, p interface{}, m M) { + pt := reflect.TypeOf(p) + pv := reflect.ValueOf(p) + if pt.Kind() == reflect.Ptr { + pt = pt.Elem() + pv = pv.Elem() + } + for i := 0; i < pv.NumField(); i++ { + var key string + if root == "" { + key = pt.Field(i).Name + } else { + key = root + "." + pt.Field(i).Name + } + if pv.Field(i).Kind() == reflect.Struct { + list(key, pv.Field(i).Interface(), m) + } else { + m[key] = pv.Field(i).Interface() + } + } +} + +// PrintTree prints all registered routers. +func PrintTree() M { + var ( + content = M{} + methods = []string{} + methodsData = make(M) + ) + for method, t := range BeeApp.Handlers.routers { + + resultList := new([][]string) + + printTree(resultList, t) + + methods = append(methods, template.HTMLEscapeString(method)) + methodsData[template.HTMLEscapeString(method)] = resultList + } + + content["Data"] = methodsData + content["Methods"] = methods + return content +} + +func printTree(resultList *[][]string, t *Tree) { + for _, tr := range t.fixrouters { + printTree(resultList, tr) + } + if t.wildcard != nil { + printTree(resultList, t.wildcard) + } + for _, l := range t.leaves { + if v, ok := l.runObject.(*ControllerInfo); ok { + if v.routerType == routerTypeBeego { + var result = []string{ + template.HTMLEscapeString(v.pattern), + template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), + template.HTMLEscapeString(v.controllerType.String()), + } + *resultList = append(*resultList, result) + } else if v.routerType == routerTypeRESTFul { + var result = []string{ + template.HTMLEscapeString(v.pattern), + template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), + "", + } + *resultList = append(*resultList, result) + } else if v.routerType == routerTypeHandler { + var result = []string{ + template.HTMLEscapeString(v.pattern), + "", + "", + } + *resultList = append(*resultList, result) + } + } + } +} + +// ProfIndex is a http.Handler for showing profile command. +// it's in url pattern "/prof" in admin module. +func profIndex(rw http.ResponseWriter, r *http.Request) { + r.ParseForm() + command := r.Form.Get("command") + if command == "" { + return + } + + var ( + format = r.Form.Get("format") + data = make(map[interface{}]interface{}) + result bytes.Buffer + ) + toolbox.ProcessInput(command, &result) + data["Content"] = template.HTMLEscapeString(result.String()) + + if format == "json" && command == "gc summary" { + dataJSON, err := json.Marshal(data) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(rw, dataJSON) + return + } + + data["Title"] = template.HTMLEscapeString(command) + defaultTpl := defaultScriptsTpl + if command == "gc summary" { + defaultTpl = gcAjaxTpl + } + writeTemplate(rw, data, profillingTpl, defaultTpl) +} + +// Healthcheck is a http.Handler calling health checking and showing the result. +// it's in "/healthcheck" pattern in admin module. +func healthcheck(rw http.ResponseWriter, r *http.Request) { + var ( + result []string + data = make(map[interface{}]interface{}) + resultList = new([][]string) + content = M{ + "Fields": []string{"Name", "Message", "Status"}, + } + ) + + for name, h := range toolbox.AdminCheckList { + if err := h.Check(); err != nil { + result = []string{ + "error", + template.HTMLEscapeString(name), + template.HTMLEscapeString(err.Error()), + } + } else { + result = []string{ + "success", + template.HTMLEscapeString(name), + "OK", + } + } + *resultList = append(*resultList, result) + } + + queryParams := r.URL.Query() + jsonFlag := queryParams.Get("json") + shouldReturnJSON, _ := strconv.ParseBool(jsonFlag) + + if shouldReturnJSON { + response := buildHealthCheckResponseList(resultList) + jsonResponse, err := json.Marshal(response) + + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } else { + writeJSON(rw, jsonResponse) + } + return + } + + content["Data"] = resultList + data["Content"] = content + data["Title"] = "Health Check" + + writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) +} + +func buildHealthCheckResponseList(healthCheckResults *[][]string) []map[string]interface{} { + response := make([]map[string]interface{}, len(*healthCheckResults)) + + for i, healthCheckResult := range *healthCheckResults { + currentResultMap := make(map[string]interface{}) + + currentResultMap["name"] = healthCheckResult[0] + currentResultMap["message"] = healthCheckResult[1] + currentResultMap["status"] = healthCheckResult[2] + + response[i] = currentResultMap + } + + return response + +} + +func writeJSON(rw http.ResponseWriter, jsonData []byte) { + rw.Header().Set("Content-Type", "application/json") + rw.Write(jsonData) +} + +// TaskStatus is a http.Handler with running task status (task name, status and the last execution). +// it's in "/task" pattern in admin module. +func taskStatus(rw http.ResponseWriter, req *http.Request) { + data := make(map[interface{}]interface{}) + + // Run Task + req.ParseForm() + taskname := req.Form.Get("taskname") + if taskname != "" { + if t, ok := toolbox.AdminTaskList[taskname]; ok { + if err := t.Run(); err != nil { + data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", err))} + } + data["Message"] = []string{"success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus()))} + } else { + data["Message"] = []string{"warning", template.HTMLEscapeString(fmt.Sprintf("there's no task which named: %s", taskname))} + } + } + + // List Tasks + content := make(M) + resultList := new([][]string) + var fields = []string{ + "Task Name", + "Task Spec", + "Task Status", + "Last Time", + "", + } + for tname, tk := range toolbox.AdminTaskList { + result := []string{ + template.HTMLEscapeString(tname), + template.HTMLEscapeString(tk.GetSpec()), + template.HTMLEscapeString(tk.GetStatus()), + template.HTMLEscapeString(tk.GetPrev().String()), + } + *resultList = append(*resultList, result) + } + + content["Fields"] = fields + content["Data"] = resultList + data["Content"] = content + data["Title"] = "Tasks" + writeTemplate(rw, data, tasksTpl, defaultScriptsTpl) +} + +func writeTemplate(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { + tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) + for _, tpl := range tpls { + tmpl = template.Must(tmpl.Parse(tpl)) + } + tmpl.Execute(rw, data) +} + +// adminApp is an http.HandlerFunc map used as beeAdminApp. +type adminApp struct { + routers map[string]http.HandlerFunc +} + +// Route adds http.HandlerFunc to adminApp with url pattern. +func (admin *adminApp) Route(pattern string, f http.HandlerFunc) { + admin.routers[pattern] = f +} + +// Run adminApp http server. +// Its addr is defined in configuration file as adminhttpaddr and adminhttpport. +func (admin *adminApp) Run() { + if len(toolbox.AdminTaskList) > 0 { + toolbox.StartTask() + } + addr := BConfig.Listen.AdminAddr + + if BConfig.Listen.AdminPort != 0 { + addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort) + } + for p, f := range admin.routers { + http.Handle(p, f) + } + logs.Info("Admin server Running on %s", addr) + + var err error + if BConfig.Listen.Graceful { + err = grace.ListenAndServe(addr, nil) + } else { + err = http.ListenAndServe(addr, nil) + } + if err != nil { + logs.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) + } +} diff --git a/pkg/admin_test.go b/pkg/admin_test.go new file mode 100644 index 00000000..3f3612e4 --- /dev/null +++ b/pkg/admin_test.go @@ -0,0 +1,239 @@ +package beego + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + + "github.com/astaxie/beego/toolbox" +) + +type SampleDatabaseCheck struct { +} + +type SampleCacheCheck struct { +} + +func (dc *SampleDatabaseCheck) Check() error { + return nil +} + +func (cc *SampleCacheCheck) Check() error { + return errors.New("no cache detected") +} + +func TestList_01(t *testing.T) { + m := make(M) + list("BConfig", BConfig, m) + t.Log(m) + om := oldMap() + for k, v := range om { + if fmt.Sprint(m[k]) != fmt.Sprint(v) { + t.Log(k, "old-key", v, "new-key", m[k]) + t.FailNow() + } + } +} + +func oldMap() M { + m := make(M) + m["BConfig.AppName"] = BConfig.AppName + m["BConfig.RunMode"] = BConfig.RunMode + m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive + m["BConfig.ServerName"] = BConfig.ServerName + m["BConfig.RecoverPanic"] = BConfig.RecoverPanic + m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody + m["BConfig.EnableGzip"] = BConfig.EnableGzip + m["BConfig.MaxMemory"] = BConfig.MaxMemory + m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow + m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful + m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut + m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4 + m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP + m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr + m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort + m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS + m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr + m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort + m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile + m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile + m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin + m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr + m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort + m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi + m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo + m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender + m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs + m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName + m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator + m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex + m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir + m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip + m["BConfig.WebConfig.StaticCacheFileSize"] = BConfig.WebConfig.StaticCacheFileSize + m["BConfig.WebConfig.StaticCacheFileNum"] = BConfig.WebConfig.StaticCacheFileNum + m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft + m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight + m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath + m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF + m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire + m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn + m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider + m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName + m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime + m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig + m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime + m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie + m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain + m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly + m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs + m["BConfig.Log.EnableStaticLogs"] = BConfig.Log.EnableStaticLogs + m["BConfig.Log.AccessLogsFormat"] = BConfig.Log.AccessLogsFormat + m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum + m["BConfig.Log.Outputs"] = BConfig.Log.Outputs + return m +} + +func TestWriteJSON(t *testing.T) { + t.Log("Testing the adding of JSON to the response") + + w := httptest.NewRecorder() + originalBody := []int{1, 2, 3} + + res, _ := json.Marshal(originalBody) + + writeJSON(w, res) + + decodedBody := []int{} + err := json.NewDecoder(w.Body).Decode(&decodedBody) + + if err != nil { + t.Fatal("Could not decode response body into slice.") + } + + for i := range decodedBody { + if decodedBody[i] != originalBody[i] { + t.Fatalf("Expected %d but got %d in decoded body slice", originalBody[i], decodedBody[i]) + } + } +} + +func TestHealthCheckHandlerDefault(t *testing.T) { + endpointPath := "/healthcheck" + + toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) + toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) + + req, err := http.NewRequest("GET", endpointPath, nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + handler := http.HandlerFunc(healthcheck) + + handler.ServeHTTP(w, req) + + if status := w.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + if !strings.Contains(w.Body.String(), "database") { + t.Errorf("Expected 'database' in generated template.") + } + +} + +func TestBuildHealthCheckResponseList(t *testing.T) { + healthCheckResults := [][]string{ + []string{ + "error", + "Database", + "Error occured whie starting the db", + }, + []string{ + "success", + "Cache", + "Cache started successfully", + }, + } + + responseList := buildHealthCheckResponseList(&healthCheckResults) + + if len(responseList) != len(healthCheckResults) { + t.Errorf("invalid response map length: got %d want %d", + len(responseList), len(healthCheckResults)) + } + + responseFields := []string{"name", "message", "status"} + + for _, response := range responseList { + for _, field := range responseFields { + _, ok := response[field] + if !ok { + t.Errorf("expected %s to be in the response %v", field, response) + } + } + + } + +} + +func TestHealthCheckHandlerReturnsJSON(t *testing.T) { + + toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) + toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) + + req, err := http.NewRequest("GET", "/healthcheck?json=true", nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + handler := http.HandlerFunc(healthcheck) + + handler.ServeHTTP(w, req) + if status := w.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + decodedResponseBody := []map[string]interface{}{} + expectedResponseBody := []map[string]interface{}{} + + expectedJSONString := []byte(` + [ + { + "message":"database", + "name":"success", + "status":"OK" + }, + { + "message":"cache", + "name":"error", + "status":"no cache detected" + } + ] + `) + + json.Unmarshal(expectedJSONString, &expectedResponseBody) + + json.Unmarshal(w.Body.Bytes(), &decodedResponseBody) + + if len(expectedResponseBody) != len(decodedResponseBody) { + t.Errorf("invalid response map length: got %d want %d", + len(decodedResponseBody), len(expectedResponseBody)) + } + + if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { + t.Errorf("handler returned unexpected body: got %v want %v", + decodedResponseBody, expectedResponseBody) + } + +} diff --git a/pkg/adminui.go b/pkg/adminui.go new file mode 100644 index 00000000..cdcdef33 --- /dev/null +++ b/pkg/adminui.go @@ -0,0 +1,356 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +var indexTpl = ` +{{define "content"}} +

Beego Admin Dashboard

+

+For detail usage please check our document: +

+

+Toolbox +

+

+Live Monitor +

+{{.Content}} +{{end}}` + +var profillingTpl = ` +{{define "content"}} +

{{.Title}}

+
+
{{.Content}}
+
+{{end}}` + +var defaultScriptsTpl = `` + +var gcAjaxTpl = ` +{{define "scripts"}} + +{{end}} +` + +var qpsTpl = `{{define "content"}} +

Requests statistics

+ + + + {{range .Content.Fields}} + + {{end}} + + + + + {{range $i, $elem := .Content.Data}} + + + + + + + + + + + {{end}} + + +
+ {{.}} +
{{index $elem 0}}{{index $elem 1}}{{index $elem 2}}{{index $elem 4}}{{index $elem 6}}{{index $elem 8}}{{index $elem 10}}
+{{end}}` + +var configTpl = ` +{{define "content"}} +

Configurations

+
+{{range $index, $elem := .Content}}
+{{$index}}={{$elem}}
+{{end}}
+
+{{end}} +` + +var routerAndFilterTpl = `{{define "content"}} + + +

{{.Title}}

+ +{{range .Content.Methods}} + +
+
{{.}}
+
+ + + + {{range $.Content.Fields}} + + {{end}} + + + + + {{$slice := index $.Content.Data .}} + {{range $i, $elem := $slice}} + + + {{range $elem}} + + {{end}} + + + {{end}} + + +
+ {{.}} +
+ {{.}} +
+
+
+{{end}} + + +{{end}}` + +var tasksTpl = `{{define "content"}} + +

{{.Title}}

+ +{{if .Message }} +{{ $messageType := index .Message 0}} +

+{{index .Message 1}} +

+{{end}} + + + + + +{{range .Content.Fields}} + +{{end}} + + + + +{{range $i, $slice := .Content.Data}} + + {{range $slice}} + + {{end}} + + +{{end}} + +
+{{.}} +
+ {{.}} + + Run +
+ +{{end}}` + +var healthCheckTpl = ` +{{define "content"}} + +

{{.Title}}

+ + + +{{range .Content.Fields}} + +{{end}} + + + +{{range $i, $slice := .Content.Data}} + {{ $header := index $slice 0}} + {{ if eq "success" $header}} + + {{else if eq "error" $header}} + + {{else}} + + {{end}} + {{range $j, $elem := $slice}} + {{if ne $j 0}} + + {{end}} + {{end}} + + +{{end}} + + +
+ {{.}} +
+ {{$elem}} + + {{$header}} +
+{{end}}` + +// The base dashboardTpl +var dashboardTpl = ` + + + + + + + + + + +Welcome to Beego Admin Dashboard + + + + + + + + + + + + + +
+{{template "content" .}} +
+ + + + + + + +{{template "scripts" .}} + + +` diff --git a/pkg/app.go b/pkg/app.go new file mode 100644 index 00000000..f3fe6f7b --- /dev/null +++ b/pkg/app.go @@ -0,0 +1,496 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/fcgi" + "os" + "path" + "strings" + "time" + + "github.com/astaxie/beego/grace" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/utils" + "golang.org/x/crypto/acme/autocert" +) + +var ( + // BeeApp is an application instance + BeeApp *App +) + +func init() { + // create beego application + BeeApp = NewApp() +} + +// App defines beego application with a new PatternServeMux. +type App struct { + Handlers *ControllerRegister + Server *http.Server +} + +// NewApp returns a new beego application. +func NewApp() *App { + cr := NewControllerRegister() + app := &App{Handlers: cr, Server: &http.Server{}} + return app +} + +// MiddleWare function for http.Handler +type MiddleWare func(http.Handler) http.Handler + +// Run beego application. +func (app *App) Run(mws ...MiddleWare) { + addr := BConfig.Listen.HTTPAddr + + if BConfig.Listen.HTTPPort != 0 { + addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPAddr, BConfig.Listen.HTTPPort) + } + + var ( + err error + l net.Listener + endRunning = make(chan bool, 1) + ) + + // run cgi server + if BConfig.Listen.EnableFcgi { + if BConfig.Listen.EnableStdIo { + if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O + logs.Info("Use FCGI via standard I/O") + } else { + logs.Critical("Cannot use FCGI via standard I/O", err) + } + return + } + if BConfig.Listen.HTTPPort == 0 { + // remove the Socket file before start + if utils.FileExists(addr) { + os.Remove(addr) + } + l, err = net.Listen("unix", addr) + } else { + l, err = net.Listen("tcp", addr) + } + if err != nil { + logs.Critical("Listen: ", err) + } + if err = fcgi.Serve(l, app.Handlers); err != nil { + logs.Critical("fcgi.Serve: ", err) + } + return + } + + app.Server.Handler = app.Handlers + for i := len(mws) - 1; i >= 0; i-- { + if mws[i] == nil { + continue + } + app.Server.Handler = mws[i](app.Server.Handler) + } + app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second + app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second + app.Server.ErrorLog = logs.GetLogger("HTTP") + + // run graceful mode + if BConfig.Listen.Graceful { + httpsAddr := BConfig.Listen.HTTPSAddr + app.Server.Addr = httpsAddr + if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { + go func() { + time.Sleep(1000 * time.Microsecond) + if BConfig.Listen.HTTPSPort != 0 { + httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) + app.Server.Addr = httpsAddr + } + server := grace.NewServer(httpsAddr, app.Server.Handler) + server.Server.ReadTimeout = app.Server.ReadTimeout + server.Server.WriteTimeout = app.Server.WriteTimeout + if BConfig.Listen.EnableMutualHTTPS { + if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } + } else { + if BConfig.Listen.AutoTLS { + m := autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), + Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), + } + app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} + BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" + } + if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } + } + endRunning <- true + }() + } + if BConfig.Listen.EnableHTTP { + go func() { + server := grace.NewServer(addr, app.Server.Handler) + server.Server.ReadTimeout = app.Server.ReadTimeout + server.Server.WriteTimeout = app.Server.WriteTimeout + if BConfig.Listen.ListenTCP4 { + server.Network = "tcp4" + } + if err := server.ListenAndServe(); err != nil { + logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } + endRunning <- true + }() + } + <-endRunning + return + } + + // run normal mode + if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { + go func() { + time.Sleep(1000 * time.Microsecond) + if BConfig.Listen.HTTPSPort != 0 { + app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) + } else if BConfig.Listen.EnableHTTP { + logs.Info("Start https server error, conflict with http. Please reset https port") + return + } + logs.Info("https server Running on https://%s", app.Server.Addr) + if BConfig.Listen.AutoTLS { + m := autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), + Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), + } + app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} + BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" + } else if BConfig.Listen.EnableMutualHTTPS { + pool := x509.NewCertPool() + data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile) + if err != nil { + logs.Info("MutualHTTPS should provide TrustCaFile") + return + } + pool.AppendCertsFromPEM(data) + app.Server.TLSConfig = &tls.Config{ + ClientCAs: pool, + ClientAuth: tls.RequireAndVerifyClientCert, + } + } + if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + }() + + } + if BConfig.Listen.EnableHTTP { + go func() { + app.Server.Addr = addr + logs.Info("http server Running on http://%s", app.Server.Addr) + if BConfig.Listen.ListenTCP4 { + ln, err := net.Listen("tcp4", app.Server.Addr) + if err != nil { + logs.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + if err = app.Server.Serve(ln); err != nil { + logs.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + } else { + if err := app.Server.ListenAndServe(); err != nil { + logs.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + } + }() + } + <-endRunning +} + +// Router adds a patterned controller handler to BeeApp. +// it's an alias method of App.Router. +// usage: +// simple router +// beego.Router("/admin", &admin.UserController{}) +// beego.Router("/admin/index", &admin.ArticleController{}) +// +// regex router +// +// beego.Router("/api/:id([0-9]+)", &controllers.RController{}) +// +// custom rules +// beego.Router("/api/list",&RestController{},"*:ListFood") +// beego.Router("/api/create",&RestController{},"post:CreateFood") +// beego.Router("/api/update",&RestController{},"put:UpdateFood") +// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") +func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { + BeeApp.Handlers.Add(rootpath, c, mappingMethods...) + return BeeApp +} + +// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful +// in web applications that inherit most routes from a base webapp via the underscore +// import, and aim to overwrite only certain paths. +// The method parameter can be empty or "*" for all HTTP methods, or a particular +// method type (e.g. "GET" or "POST") for selective removal. +// +// Usage (replace "GET" with "*" for all methods): +// beego.UnregisterFixedRoute("/yourpreviouspath", "GET") +// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage") +func UnregisterFixedRoute(fixedRoute string, method string) *App { + subPaths := splitPath(fixedRoute) + if method == "" || method == "*" { + for m := range HTTPMETHOD { + if _, ok := BeeApp.Handlers.routers[m]; !ok { + continue + } + if BeeApp.Handlers.routers[m].prefix == strings.Trim(fixedRoute, "/ ") { + findAndRemoveSingleTree(BeeApp.Handlers.routers[m]) + continue + } + findAndRemoveTree(subPaths, BeeApp.Handlers.routers[m], m) + } + return BeeApp + } + // Single HTTP method + um := strings.ToUpper(method) + if _, ok := BeeApp.Handlers.routers[um]; ok { + if BeeApp.Handlers.routers[um].prefix == strings.Trim(fixedRoute, "/ ") { + findAndRemoveSingleTree(BeeApp.Handlers.routers[um]) + return BeeApp + } + findAndRemoveTree(subPaths, BeeApp.Handlers.routers[um], um) + } + return BeeApp +} + +func findAndRemoveTree(paths []string, entryPointTree *Tree, method string) { + for i := range entryPointTree.fixrouters { + if entryPointTree.fixrouters[i].prefix == paths[0] { + if len(paths) == 1 { + if len(entryPointTree.fixrouters[i].fixrouters) > 0 { + // If the route had children subtrees, remove just the functional leaf, + // to allow children to function as before + if len(entryPointTree.fixrouters[i].leaves) > 0 { + entryPointTree.fixrouters[i].leaves[0] = nil + entryPointTree.fixrouters[i].leaves = entryPointTree.fixrouters[i].leaves[1:] + } + } else { + // Remove the *Tree from the fixrouters slice + entryPointTree.fixrouters[i] = nil + + if i == len(entryPointTree.fixrouters)-1 { + entryPointTree.fixrouters = entryPointTree.fixrouters[:i] + } else { + entryPointTree.fixrouters = append(entryPointTree.fixrouters[:i], entryPointTree.fixrouters[i+1:len(entryPointTree.fixrouters)]...) + } + } + return + } + findAndRemoveTree(paths[1:], entryPointTree.fixrouters[i], method) + } + } +} + +func findAndRemoveSingleTree(entryPointTree *Tree) { + if entryPointTree == nil { + return + } + if len(entryPointTree.fixrouters) > 0 { + // If the route had children subtrees, remove just the functional leaf, + // to allow children to function as before + if len(entryPointTree.leaves) > 0 { + entryPointTree.leaves[0] = nil + entryPointTree.leaves = entryPointTree.leaves[1:] + } + } +} + +// Include will generate router file in the router/xxx.go from the controller's comments +// usage: +// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +// type BankAccount struct{ +// beego.Controller +// } +// +// register the function +// func (b *BankAccount)Mapping(){ +// b.Mapping("ShowAccount" , b.ShowAccount) +// b.Mapping("ModifyAccount", b.ModifyAccount) +//} +// +// //@router /account/:id [get] +// func (b *BankAccount) ShowAccount(){ +// //logic +// } +// +// +// //@router /account/:id [post] +// func (b *BankAccount) ModifyAccount(){ +// //logic +// } +// +// the comments @router url methodlist +// url support all the function Router's pattern +// methodlist [get post head put delete options *] +func Include(cList ...ControllerInterface) *App { + BeeApp.Handlers.Include(cList...) + return BeeApp +} + +// RESTRouter adds a restful controller handler to BeeApp. +// its' controller implements beego.ControllerInterface and +// defines a param "pattern/:objectId" to visit each resource. +func RESTRouter(rootpath string, c ControllerInterface) *App { + Router(rootpath, c) + Router(path.Join(rootpath, ":objectId"), c) + return BeeApp +} + +// AutoRouter adds defined controller handler to BeeApp. +// it's same to App.AutoRouter. +// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, +// visit the url /main/list to exec List function or /main/page to exec Page function. +func AutoRouter(c ControllerInterface) *App { + BeeApp.Handlers.AddAuto(c) + return BeeApp +} + +// AutoPrefix adds controller handler to BeeApp with prefix. +// it's same to App.AutoRouterWithPrefix. +// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, +// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. +func AutoPrefix(prefix string, c ControllerInterface) *App { + BeeApp.Handlers.AddAutoPrefix(prefix, c) + return BeeApp +} + +// Get used to register router for Get method +// usage: +// beego.Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Get(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Get(rootpath, f) + return BeeApp +} + +// Post used to register router for Post method +// usage: +// beego.Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Post(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Post(rootpath, f) + return BeeApp +} + +// Delete used to register router for Delete method +// usage: +// beego.Delete("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Delete(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Delete(rootpath, f) + return BeeApp +} + +// Put used to register router for Put method +// usage: +// beego.Put("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Put(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Put(rootpath, f) + return BeeApp +} + +// Head used to register router for Head method +// usage: +// beego.Head("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Head(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Head(rootpath, f) + return BeeApp +} + +// Options used to register router for Options method +// usage: +// beego.Options("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Options(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Options(rootpath, f) + return BeeApp +} + +// Patch used to register router for Patch method +// usage: +// beego.Patch("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Patch(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Patch(rootpath, f) + return BeeApp +} + +// Any used to register router for all methods +// usage: +// beego.Any("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Any(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Any(rootpath, f) + return BeeApp +} + +// Handler used to register a Handler router +// usage: +// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { +// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) +// })) +func Handler(rootpath string, h http.Handler, options ...interface{}) *App { + BeeApp.Handlers.Handler(rootpath, h, options...) + return BeeApp +} + +// InsertFilter adds a FilterFunc with pattern condition and action constant. +// The pos means action constant including +// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. +// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) +func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { + BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) + return BeeApp +} diff --git a/pkg/beego.go b/pkg/beego.go new file mode 100644 index 00000000..8ebe0bab --- /dev/null +++ b/pkg/beego.go @@ -0,0 +1,123 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "os" + "path/filepath" + "strconv" + "strings" +) + +const ( + // VERSION represent beego web framework version. + VERSION = "1.12.2" + + // DEV is for develop + DEV = "dev" + // PROD is for production + PROD = "prod" +) + +// M is Map shortcut +type M map[string]interface{} + +// Hook function to run +type hookfunc func() error + +var ( + hooks = make([]hookfunc, 0) //hook function slice to store the hookfunc +) + +// AddAPPStartHook is used to register the hookfunc +// The hookfuncs will run in beego.Run() +// such as initiating session , starting middleware , building template, starting admin control and so on. +func AddAPPStartHook(hf ...hookfunc) { + hooks = append(hooks, hf...) +} + +// Run beego application. +// beego.Run() default run on HttpPort +// beego.Run("localhost") +// beego.Run(":8089") +// beego.Run("127.0.0.1:8089") +func Run(params ...string) { + + initBeforeHTTPRun() + + if len(params) > 0 && params[0] != "" { + strs := strings.Split(params[0], ":") + if len(strs) > 0 && strs[0] != "" { + BConfig.Listen.HTTPAddr = strs[0] + } + if len(strs) > 1 && strs[1] != "" { + BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) + } + + BConfig.Listen.Domains = params + } + + BeeApp.Run() +} + +// RunWithMiddleWares Run beego application with middlewares. +func RunWithMiddleWares(addr string, mws ...MiddleWare) { + initBeforeHTTPRun() + + strs := strings.Split(addr, ":") + if len(strs) > 0 && strs[0] != "" { + BConfig.Listen.HTTPAddr = strs[0] + BConfig.Listen.Domains = []string{strs[0]} + } + if len(strs) > 1 && strs[1] != "" { + BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) + } + + BeeApp.Run(mws...) +} + +func initBeforeHTTPRun() { + //init hooks + AddAPPStartHook( + registerMime, + registerDefaultErrorHandler, + registerSession, + registerTemplate, + registerAdmin, + registerGzip, + ) + + for _, hk := range hooks { + if err := hk(); err != nil { + panic(err) + } + } +} + +// TestBeegoInit is for test package init +func TestBeegoInit(ap string) { + path := filepath.Join(ap, "conf", "app.conf") + os.Chdir(ap) + InitBeegoBeforeTest(path) +} + +// InitBeegoBeforeTest is for test package init +func InitBeegoBeforeTest(appConfigPath string) { + if err := LoadAppConfig(appConfigProvider, appConfigPath); err != nil { + panic(err) + } + BConfig.RunMode = "test" + initBeforeHTTPRun() +} diff --git a/pkg/build_info.go b/pkg/build_info.go new file mode 100644 index 00000000..6dc2835e --- /dev/null +++ b/pkg/build_info.go @@ -0,0 +1,27 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +var ( + BuildVersion string + BuildGitRevision string + BuildStatus string + BuildTag string + BuildTime string + + GoVersion string + + GitBranch string +) diff --git a/pkg/cache/README.md b/pkg/cache/README.md new file mode 100644 index 00000000..b467760a --- /dev/null +++ b/pkg/cache/README.md @@ -0,0 +1,59 @@ +## cache +cache is a Go cache manager. It can use many cache adapters. The repo is inspired by `database/sql` . + + +## How to install? + + go get github.com/astaxie/beego/cache + + +## What adapters are supported? + +As of now this cache support memory, Memcache and Redis. + + +## How to use it? + +First you must import it + + import ( + "github.com/astaxie/beego/cache" + ) + +Then init a Cache (example with memory adapter) + + bm, err := cache.NewCache("memory", `{"interval":60}`) + +Use it like this: + + bm.Put("astaxie", 1, 10 * time.Second) + bm.Get("astaxie") + bm.IsExist("astaxie") + bm.Delete("astaxie") + + +## Memory adapter + +Configure memory adapter like this: + + {"interval":60} + +interval means the gc time. The cache will check at each time interval, whether item has expired. + + +## Memcache adapter + +Memcache adapter use the [gomemcache](http://github.com/bradfitz/gomemcache) client. + +Configure like this: + + {"conn":"127.0.0.1:11211"} + + +## Redis adapter + +Redis adapter use the [redigo](http://github.com/gomodule/redigo) client. + +Configure like this: + + {"conn":":6039"} diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go new file mode 100644 index 00000000..82585c4e --- /dev/null +++ b/pkg/cache/cache.go @@ -0,0 +1,103 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cache provide a Cache interface and some implement engine +// Usage: +// +// import( +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("memory", `{"interval":60}`) +// +// Use it like this: +// +// bm.Put("astaxie", 1, 10 * time.Second) +// bm.Get("astaxie") +// bm.IsExist("astaxie") +// bm.Delete("astaxie") +// +// more docs http://beego.me/docs/module/cache.md +package cache + +import ( + "fmt" + "time" +) + +// Cache interface contains all behaviors for cache adapter. +// usage: +// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go. +// c,err := cache.NewCache("file","{....}") +// c.Put("key",value, 3600 * time.Second) +// v := c.Get("key") +// +// c.Incr("counter") // now is 1 +// c.Incr("counter") // now is 2 +// count := c.Get("counter").(int) +type Cache interface { + // get cached value by key. + Get(key string) interface{} + // GetMulti is a batch version of Get. + GetMulti(keys []string) []interface{} + // set cached value with key and expire time. + Put(key string, val interface{}, timeout time.Duration) error + // delete cached value by key. + Delete(key string) error + // increase cached int value by key, as a counter. + Incr(key string) error + // decrease cached int value by key, as a counter. + Decr(key string) error + // check if cached value exists or not. + IsExist(key string) bool + // clear all cache. + ClearAll() error + // start gc routine based on config string settings. + StartAndGC(config string) error +} + +// Instance is a function create a new Cache Instance +type Instance func() Cache + +var adapters = make(map[string]Instance) + +// Register makes a cache adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Instance) { + if adapter == nil { + panic("cache: Register adapter is nil") + } + if _, ok := adapters[name]; ok { + panic("cache: Register called twice for adapter " + name) + } + adapters[name] = adapter +} + +// NewCache Create a new cache driver by adapter name and config string. +// config need to be correct JSON as string: {"interval":360}. +// it will start gc automatically. +func NewCache(adapterName, config string) (adapter Cache, err error) { + instanceFunc, ok := adapters[adapterName] + if !ok { + err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + return + } + adapter = instanceFunc() + err = adapter.StartAndGC(config) + if err != nil { + adapter = nil + } + return +} diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go new file mode 100644 index 00000000..470c0a43 --- /dev/null +++ b/pkg/cache/cache_test.go @@ -0,0 +1,191 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "os" + "sync" + "testing" + "time" +) + +func TestCacheIncr(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + //timeoutDuration := 10 * time.Second + + bm.Put("edwardhey", 0, time.Second*20) + wg := sync.WaitGroup{} + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + bm.Incr("edwardhey") + }() + } + wg.Wait() + if bm.Get("edwardhey").(int) != 10 { + t.Error("Incr err") + } +} + +func TestCache(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + time.Sleep(30 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test GetMulti + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } + + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if vv[0].(string) != "author" { + t.Error("GetMulti ERROR") + } + if vv[1].(string) != "author1" { + t.Error("GetMulti ERROR") + } +} + +func TestFileCache(t *testing.T) { + bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if vv[0].(string) != "author" { + t.Error("GetMulti ERROR") + } + if vv[1].(string) != "author1" { + t.Error("GetMulti ERROR") + } + + os.RemoveAll("cache") +} diff --git a/pkg/cache/conv.go b/pkg/cache/conv.go new file mode 100644 index 00000000..87800586 --- /dev/null +++ b/pkg/cache/conv.go @@ -0,0 +1,100 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "fmt" + "strconv" +) + +// GetString convert interface to string. +func GetString(v interface{}) string { + switch result := v.(type) { + case string: + return result + case []byte: + return string(result) + default: + if v != nil { + return fmt.Sprint(result) + } + } + return "" +} + +// GetInt convert interface to int. +func GetInt(v interface{}) int { + switch result := v.(type) { + case int: + return result + case int32: + return int(result) + case int64: + return int(result) + default: + if d := GetString(v); d != "" { + value, _ := strconv.Atoi(d) + return value + } + } + return 0 +} + +// GetInt64 convert interface to int64. +func GetInt64(v interface{}) int64 { + switch result := v.(type) { + case int: + return int64(result) + case int32: + return int64(result) + case int64: + return result + default: + + if d := GetString(v); d != "" { + value, _ := strconv.ParseInt(d, 10, 64) + return value + } + } + return 0 +} + +// GetFloat64 convert interface to float64. +func GetFloat64(v interface{}) float64 { + switch result := v.(type) { + case float64: + return result + default: + if d := GetString(v); d != "" { + value, _ := strconv.ParseFloat(d, 64) + return value + } + } + return 0 +} + +// GetBool convert interface to bool. +func GetBool(v interface{}) bool { + switch result := v.(type) { + case bool: + return result + default: + if d := GetString(v); d != "" { + value, _ := strconv.ParseBool(d) + return value + } + } + return false +} diff --git a/pkg/cache/conv_test.go b/pkg/cache/conv_test.go new file mode 100644 index 00000000..b90e224a --- /dev/null +++ b/pkg/cache/conv_test.go @@ -0,0 +1,143 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "testing" +) + +func TestGetString(t *testing.T) { + var t1 = "test1" + if "test1" != GetString(t1) { + t.Error("get string from string error") + } + var t2 = []byte("test2") + if "test2" != GetString(t2) { + t.Error("get string from byte array error") + } + var t3 = 1 + if "1" != GetString(t3) { + t.Error("get string from int error") + } + var t4 int64 = 1 + if "1" != GetString(t4) { + t.Error("get string from int64 error") + } + var t5 = 1.1 + if "1.1" != GetString(t5) { + t.Error("get string from float64 error") + } + + if "" != GetString(nil) { + t.Error("get string from nil error") + } +} + +func TestGetInt(t *testing.T) { + var t1 = 1 + if 1 != GetInt(t1) { + t.Error("get int from int error") + } + var t2 int32 = 32 + if 32 != GetInt(t2) { + t.Error("get int from int32 error") + } + var t3 int64 = 64 + if 64 != GetInt(t3) { + t.Error("get int from int64 error") + } + var t4 = "128" + if 128 != GetInt(t4) { + t.Error("get int from num string error") + } + if 0 != GetInt(nil) { + t.Error("get int from nil error") + } +} + +func TestGetInt64(t *testing.T) { + var i int64 = 1 + var t1 = 1 + if i != GetInt64(t1) { + t.Error("get int64 from int error") + } + var t2 int32 = 1 + if i != GetInt64(t2) { + t.Error("get int64 from int32 error") + } + var t3 int64 = 1 + if i != GetInt64(t3) { + t.Error("get int64 from int64 error") + } + var t4 = "1" + if i != GetInt64(t4) { + t.Error("get int64 from num string error") + } + if 0 != GetInt64(nil) { + t.Error("get int64 from nil") + } +} + +func TestGetFloat64(t *testing.T) { + var f = 1.11 + var t1 float32 = 1.11 + if f != GetFloat64(t1) { + t.Error("get float64 from float32 error") + } + var t2 = 1.11 + if f != GetFloat64(t2) { + t.Error("get float64 from float64 error") + } + var t3 = "1.11" + if f != GetFloat64(t3) { + t.Error("get float64 from string error") + } + + var f2 float64 = 1 + var t4 = 1 + if f2 != GetFloat64(t4) { + t.Error("get float64 from int error") + } + + if 0 != GetFloat64(nil) { + t.Error("get float64 from nil error") + } +} + +func TestGetBool(t *testing.T) { + var t1 = true + if !GetBool(t1) { + t.Error("get bool from bool error") + } + var t2 = "true" + if !GetBool(t2) { + t.Error("get bool from string error") + } + if GetBool(nil) { + t.Error("get bool from nil error") + } +} + +func byteArrayEquals(a []byte, b []byte) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} diff --git a/pkg/cache/file.go b/pkg/cache/file.go new file mode 100644 index 00000000..6f12d3ee --- /dev/null +++ b/pkg/cache/file.go @@ -0,0 +1,258 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "bytes" + "crypto/md5" + "encoding/gob" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "strconv" + "time" +) + +// FileCacheItem is basic unit of file cache adapter. +// it contains data and expire time. +type FileCacheItem struct { + Data interface{} + Lastaccess time.Time + Expired time.Time +} + +// FileCache Config +var ( + FileCachePath = "cache" // cache directory + FileCacheFileSuffix = ".bin" // cache file suffix + FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files. + FileCacheEmbedExpiry time.Duration // cache expire time, default is no expire forever. +) + +// FileCache is cache adapter for file storage. +type FileCache struct { + CachePath string + FileSuffix string + DirectoryLevel int + EmbedExpiry int +} + +// NewFileCache Create new file cache with no config. +// the level and expiry need set in method StartAndGC as config string. +func NewFileCache() Cache { + // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} + return &FileCache{} +} + +// StartAndGC will start and begin gc for file cache. +// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"} +func (fc *FileCache) StartAndGC(config string) error { + + cfg := make(map[string]string) + err := json.Unmarshal([]byte(config), &cfg) + if err != nil { + return err + } + if _, ok := cfg["CachePath"]; !ok { + cfg["CachePath"] = FileCachePath + } + if _, ok := cfg["FileSuffix"]; !ok { + cfg["FileSuffix"] = FileCacheFileSuffix + } + if _, ok := cfg["DirectoryLevel"]; !ok { + cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel) + } + if _, ok := cfg["EmbedExpiry"]; !ok { + cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) + } + fc.CachePath = cfg["CachePath"] + fc.FileSuffix = cfg["FileSuffix"] + fc.DirectoryLevel, _ = strconv.Atoi(cfg["DirectoryLevel"]) + fc.EmbedExpiry, _ = strconv.Atoi(cfg["EmbedExpiry"]) + + fc.Init() + return nil +} + +// Init will make new dir for file cache if not exist. +func (fc *FileCache) Init() { + if ok, _ := exists(fc.CachePath); !ok { // todo : error handle + _ = os.MkdirAll(fc.CachePath, os.ModePerm) // todo : error handle + } +} + +// get cached file name. it's md5 encoded. +func (fc *FileCache) getCacheFileName(key string) string { + m := md5.New() + io.WriteString(m, key) + keyMd5 := hex.EncodeToString(m.Sum(nil)) + cachePath := fc.CachePath + switch fc.DirectoryLevel { + case 2: + cachePath = filepath.Join(cachePath, keyMd5[0:2], keyMd5[2:4]) + case 1: + cachePath = filepath.Join(cachePath, keyMd5[0:2]) + } + + if ok, _ := exists(cachePath); !ok { // todo : error handle + _ = os.MkdirAll(cachePath, os.ModePerm) // todo : error handle + } + + return filepath.Join(cachePath, fmt.Sprintf("%s%s", keyMd5, fc.FileSuffix)) +} + +// Get value from file cache. +// if non-exist or expired, return empty string. +func (fc *FileCache) Get(key string) interface{} { + fileData, err := FileGetContents(fc.getCacheFileName(key)) + if err != nil { + return "" + } + var to FileCacheItem + GobDecode(fileData, &to) + if to.Expired.Before(time.Now()) { + return "" + } + return to.Data +} + +// GetMulti gets values from file cache. +// if non-exist or expired, return empty string. +func (fc *FileCache) GetMulti(keys []string) []interface{} { + var rc []interface{} + for _, key := range keys { + rc = append(rc, fc.Get(key)) + } + return rc +} + +// Put value into file cache. +// timeout means how long to keep this file, unit of ms. +// if timeout equals fc.EmbedExpiry(default is 0), cache this item forever. +func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error { + gob.Register(val) + + item := FileCacheItem{Data: val} + if timeout == time.Duration(fc.EmbedExpiry) { + item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years + } else { + item.Expired = time.Now().Add(timeout) + } + item.Lastaccess = time.Now() + data, err := GobEncode(item) + if err != nil { + return err + } + return FilePutContents(fc.getCacheFileName(key), data) +} + +// Delete file cache value. +func (fc *FileCache) Delete(key string) error { + filename := fc.getCacheFileName(key) + if ok, _ := exists(filename); ok { + return os.Remove(filename) + } + return nil +} + +// Incr will increase cached int value. +// fc value is saving forever unless Delete. +func (fc *FileCache) Incr(key string) error { + data := fc.Get(key) + var incr int + if reflect.TypeOf(data).Name() != "int" { + incr = 0 + } else { + incr = data.(int) + 1 + } + fc.Put(key, incr, time.Duration(fc.EmbedExpiry)) + return nil +} + +// Decr will decrease cached int value. +func (fc *FileCache) Decr(key string) error { + data := fc.Get(key) + var decr int + if reflect.TypeOf(data).Name() != "int" || data.(int)-1 <= 0 { + decr = 0 + } else { + decr = data.(int) - 1 + } + fc.Put(key, decr, time.Duration(fc.EmbedExpiry)) + return nil +} + +// IsExist check value is exist. +func (fc *FileCache) IsExist(key string) bool { + ret, _ := exists(fc.getCacheFileName(key)) + return ret +} + +// ClearAll will clean cached files. +// not implemented. +func (fc *FileCache) ClearAll() error { + return nil +} + +// check file exist. +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +// FileGetContents Get bytes to file. +// if non-exist, create this file. +func FileGetContents(filename string) (data []byte, e error) { + return ioutil.ReadFile(filename) +} + +// FilePutContents Put bytes to file. +// if non-exist, create this file. +func FilePutContents(filename string, content []byte) error { + return ioutil.WriteFile(filename, content, os.ModePerm) +} + +// GobEncode Gob encodes file cache item. +func GobEncode(data interface{}) ([]byte, error) { + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + err := enc.Encode(data) + if err != nil { + return nil, err + } + return buf.Bytes(), err +} + +// GobDecode Gob decodes file cache item. +func GobDecode(data []byte, to *FileCacheItem) error { + buf := bytes.NewBuffer(data) + dec := gob.NewDecoder(buf) + return dec.Decode(&to) +} + +func init() { + Register("file", NewFileCache) +} diff --git a/pkg/cache/memcache/memcache.go b/pkg/cache/memcache/memcache.go new file mode 100644 index 00000000..19116bfa --- /dev/null +++ b/pkg/cache/memcache/memcache.go @@ -0,0 +1,188 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memcache for cache provider +// +// depend on github.com/bradfitz/gomemcache/memcache +// +// go install github.com/bradfitz/gomemcache/memcache +// +// Usage: +// import( +// _ "github.com/astaxie/beego/cache/memcache" +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`) +// +// more docs http://beego.me/docs/module/cache.md +package memcache + +import ( + "encoding/json" + "errors" + "strings" + "time" + + "github.com/astaxie/beego/cache" + "github.com/bradfitz/gomemcache/memcache" +) + +// Cache Memcache adapter. +type Cache struct { + conn *memcache.Client + conninfo []string +} + +// NewMemCache create new memcache adapter. +func NewMemCache() cache.Cache { + return &Cache{} +} + +// Get get value from memcache. +func (rc *Cache) Get(key string) interface{} { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + if item, err := rc.conn.Get(key); err == nil { + return item.Value + } + return nil +} + +// GetMulti get value from memcache. +func (rc *Cache) GetMulti(keys []string) []interface{} { + size := len(keys) + var rv []interface{} + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + for i := 0; i < size; i++ { + rv = append(rv, err) + } + return rv + } + } + mv, err := rc.conn.GetMulti(keys) + if err == nil { + for _, v := range mv { + rv = append(rv, v.Value) + } + return rv + } + for i := 0; i < size; i++ { + rv = append(rv, err) + } + return rv +} + +// Put put value to memcache. +func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + item := memcache.Item{Key: key, Expiration: int32(timeout / time.Second)} + if v, ok := val.([]byte); ok { + item.Value = v + } else if str, ok := val.(string); ok { + item.Value = []byte(str) + } else { + return errors.New("val only support string and []byte") + } + return rc.conn.Set(&item) +} + +// Delete delete value in memcache. +func (rc *Cache) Delete(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return rc.conn.Delete(key) +} + +// Incr increase counter. +func (rc *Cache) Incr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Increment(key, 1) + return err +} + +// Decr decrease counter. +func (rc *Cache) Decr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Decrement(key, 1) + return err +} + +// IsExist check value exists in memcache. +func (rc *Cache) IsExist(key string) bool { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return false + } + } + _, err := rc.conn.Get(key) + return err == nil +} + +// ClearAll clear all cached in memcache. +func (rc *Cache) ClearAll() error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return rc.conn.FlushAll() +} + +// StartAndGC start memcache adapter. +// config string is like {"conn":"connection info"}. +// if connecting error, return. +func (rc *Cache) StartAndGC(config string) error { + var cf map[string]string + json.Unmarshal([]byte(config), &cf) + if _, ok := cf["conn"]; !ok { + return errors.New("config has no conn key") + } + rc.conninfo = strings.Split(cf["conn"], ";") + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return nil +} + +// connect to memcache and keep the connection. +func (rc *Cache) connectInit() error { + rc.conn = memcache.New(rc.conninfo...) + return nil +} + +func init() { + cache.Register("memcache", NewMemCache) +} diff --git a/pkg/cache/memcache/memcache_test.go b/pkg/cache/memcache/memcache_test.go new file mode 100644 index 00000000..d9129b69 --- /dev/null +++ b/pkg/cache/memcache/memcache_test.go @@ -0,0 +1,108 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memcache + +import ( + _ "github.com/bradfitz/gomemcache/memcache" + + "strconv" + "testing" + "time" + + "github.com/astaxie/beego/cache" +) + +func TestMemcacheCache(t *testing.T) { + bm, err := cache.NewCache("memcache", `{"conn": "127.0.0.1:11211"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + time.Sleep(11 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie").([]byte); string(v) != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" { + t.Error("GetMulti ERROR") + } + if string(vv[1].([]byte)) != "author1" && string(vv[1].([]byte)) != "author" { + t.Error("GetMulti ERROR") + } + + // test clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } +} diff --git a/pkg/cache/memory.go b/pkg/cache/memory.go new file mode 100644 index 00000000..d8314e3c --- /dev/null +++ b/pkg/cache/memory.go @@ -0,0 +1,256 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "encoding/json" + "errors" + "sync" + "time" +) + +var ( + // DefaultEvery means the clock time of recycling the expired cache items in memory. + DefaultEvery = 60 // 1 minute +) + +// MemoryItem store memory cache item. +type MemoryItem struct { + val interface{} + createdTime time.Time + lifespan time.Duration +} + +func (mi *MemoryItem) isExpire() bool { + // 0 means forever + if mi.lifespan == 0 { + return false + } + return time.Now().Sub(mi.createdTime) > mi.lifespan +} + +// MemoryCache is Memory cache adapter. +// it contains a RW locker for safe map storage. +type MemoryCache struct { + sync.RWMutex + dur time.Duration + items map[string]*MemoryItem + Every int // run an expiration check Every clock time +} + +// NewMemoryCache returns a new MemoryCache. +func NewMemoryCache() Cache { + cache := MemoryCache{items: make(map[string]*MemoryItem)} + return &cache +} + +// Get cache from memory. +// if non-existed or expired, return nil. +func (bc *MemoryCache) Get(name string) interface{} { + bc.RLock() + defer bc.RUnlock() + if itm, ok := bc.items[name]; ok { + if itm.isExpire() { + return nil + } + return itm.val + } + return nil +} + +// GetMulti gets caches from memory. +// if non-existed or expired, return nil. +func (bc *MemoryCache) GetMulti(names []string) []interface{} { + var rc []interface{} + for _, name := range names { + rc = append(rc, bc.Get(name)) + } + return rc +} + +// Put cache to memory. +// if lifespan is 0, it will be forever till restart. +func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error { + bc.Lock() + defer bc.Unlock() + bc.items[name] = &MemoryItem{ + val: value, + createdTime: time.Now(), + lifespan: lifespan, + } + return nil +} + +// Delete cache in memory. +func (bc *MemoryCache) Delete(name string) error { + bc.Lock() + defer bc.Unlock() + if _, ok := bc.items[name]; !ok { + return errors.New("key not exist") + } + delete(bc.items, name) + if _, ok := bc.items[name]; ok { + return errors.New("delete key error") + } + return nil +} + +// Incr increase cache counter in memory. +// it supports int,int32,int64,uint,uint32,uint64. +func (bc *MemoryCache) Incr(key string) error { + bc.Lock() + defer bc.Unlock() + itm, ok := bc.items[key] + if !ok { + return errors.New("key not exist") + } + switch val := itm.val.(type) { + case int: + itm.val = val + 1 + case int32: + itm.val = val + 1 + case int64: + itm.val = val + 1 + case uint: + itm.val = val + 1 + case uint32: + itm.val = val + 1 + case uint64: + itm.val = val + 1 + default: + return errors.New("item val is not (u)int (u)int32 (u)int64") + } + return nil +} + +// Decr decrease counter in memory. +func (bc *MemoryCache) Decr(key string) error { + bc.Lock() + defer bc.Unlock() + itm, ok := bc.items[key] + if !ok { + return errors.New("key not exist") + } + switch val := itm.val.(type) { + case int: + itm.val = val - 1 + case int64: + itm.val = val - 1 + case int32: + itm.val = val - 1 + case uint: + if val > 0 { + itm.val = val - 1 + } else { + return errors.New("item val is less than 0") + } + case uint32: + if val > 0 { + itm.val = val - 1 + } else { + return errors.New("item val is less than 0") + } + case uint64: + if val > 0 { + itm.val = val - 1 + } else { + return errors.New("item val is less than 0") + } + default: + return errors.New("item val is not int int64 int32") + } + return nil +} + +// IsExist check cache exist in memory. +func (bc *MemoryCache) IsExist(name string) bool { + bc.RLock() + defer bc.RUnlock() + if v, ok := bc.items[name]; ok { + return !v.isExpire() + } + return false +} + +// ClearAll will delete all cache in memory. +func (bc *MemoryCache) ClearAll() error { + bc.Lock() + defer bc.Unlock() + bc.items = make(map[string]*MemoryItem) + return nil +} + +// StartAndGC start memory cache. it will check expiration in every clock time. +func (bc *MemoryCache) StartAndGC(config string) error { + var cf map[string]int + json.Unmarshal([]byte(config), &cf) + if _, ok := cf["interval"]; !ok { + cf = make(map[string]int) + cf["interval"] = DefaultEvery + } + dur := time.Duration(cf["interval"]) * time.Second + bc.Every = cf["interval"] + bc.dur = dur + go bc.vacuum() + return nil +} + +// check expiration. +func (bc *MemoryCache) vacuum() { + bc.RLock() + every := bc.Every + bc.RUnlock() + + if every < 1 { + return + } + for { + <-time.After(bc.dur) + bc.RLock() + if bc.items == nil { + bc.RUnlock() + return + } + bc.RUnlock() + if keys := bc.expiredKeys(); len(keys) != 0 { + bc.clearItems(keys) + } + } +} + +// expiredKeys returns key list which are expired. +func (bc *MemoryCache) expiredKeys() (keys []string) { + bc.RLock() + defer bc.RUnlock() + for key, itm := range bc.items { + if itm.isExpire() { + keys = append(keys, key) + } + } + return +} + +// clearItems removes all the items which key in keys. +func (bc *MemoryCache) clearItems(keys []string) { + bc.Lock() + defer bc.Unlock() + for _, key := range keys { + delete(bc.items, key) + } +} + +func init() { + Register("memory", NewMemoryCache) +} diff --git a/pkg/cache/redis/redis.go b/pkg/cache/redis/redis.go new file mode 100644 index 00000000..56faf211 --- /dev/null +++ b/pkg/cache/redis/redis.go @@ -0,0 +1,272 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for cache provider +// +// depend on github.com/gomodule/redigo/redis +// +// go install github.com/gomodule/redigo/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/cache/redis" +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("redis", `{"conn":"127.0.0.1:11211"}`) +// +// more docs http://beego.me/docs/module/cache.md +package redis + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" + "time" + + "github.com/gomodule/redigo/redis" + + "github.com/astaxie/beego/cache" + "strings" +) + +var ( + // DefaultKey the collection name of redis for cache adapter. + DefaultKey = "beecacheRedis" +) + +// Cache is Redis cache adapter. +type Cache struct { + p *redis.Pool // redis connection pool + conninfo string + dbNum int + key string + password string + maxIdle int + + //the timeout to a value less than the redis server's timeout. + timeout time.Duration +} + +// NewRedisCache create new redis cache with default collection name. +func NewRedisCache() cache.Cache { + return &Cache{key: DefaultKey} +} + +// actually do the redis cmds, args[0] must be the key name. +func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) { + if len(args) < 1 { + return nil, errors.New("missing required arguments") + } + args[0] = rc.associate(args[0]) + c := rc.p.Get() + defer c.Close() + + return c.Do(commandName, args...) +} + +// associate with config key. +func (rc *Cache) associate(originKey interface{}) string { + return fmt.Sprintf("%s:%s", rc.key, originKey) +} + +// Get cache from redis. +func (rc *Cache) Get(key string) interface{} { + if v, err := rc.do("GET", key); err == nil { + return v + } + return nil +} + +// GetMulti get cache from redis. +func (rc *Cache) GetMulti(keys []string) []interface{} { + c := rc.p.Get() + defer c.Close() + var args []interface{} + for _, key := range keys { + args = append(args, rc.associate(key)) + } + values, err := redis.Values(c.Do("MGET", args...)) + if err != nil { + return nil + } + return values +} + +// Put put cache to redis. +func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { + _, err := rc.do("SETEX", key, int64(timeout/time.Second), val) + return err +} + +// Delete delete cache in redis. +func (rc *Cache) Delete(key string) error { + _, err := rc.do("DEL", key) + return err +} + +// IsExist check cache's existence in redis. +func (rc *Cache) IsExist(key string) bool { + v, err := redis.Bool(rc.do("EXISTS", key)) + if err != nil { + return false + } + return v +} + +// Incr increase counter in redis. +func (rc *Cache) Incr(key string) error { + _, err := redis.Bool(rc.do("INCRBY", key, 1)) + return err +} + +// Decr decrease counter in redis. +func (rc *Cache) Decr(key string) error { + _, err := redis.Bool(rc.do("INCRBY", key, -1)) + return err +} + +// ClearAll clean all cache in redis. delete this redis collection. +func (rc *Cache) ClearAll() error { + cachedKeys, err := rc.Scan(rc.key + ":*") + if err != nil { + return err + } + c := rc.p.Get() + defer c.Close() + for _, str := range cachedKeys { + if _, err = c.Do("DEL", str); err != nil { + return err + } + } + return err +} + +// Scan scan all keys matching the pattern. a better choice than `keys` +func (rc *Cache) Scan(pattern string) (keys []string, err error) { + c := rc.p.Get() + defer c.Close() + var ( + cursor uint64 = 0 // start + result []interface{} + list []string + ) + for { + result, err = redis.Values(c.Do("SCAN", cursor, "MATCH", pattern, "COUNT", 1024)) + if err != nil { + return + } + list, err = redis.Strings(result[1], nil) + if err != nil { + return + } + keys = append(keys, list...) + cursor, err = redis.Uint64(result[0], nil) + if err != nil { + return + } + if cursor == 0 { // over + return + } + } +} + +// StartAndGC start redis cache adapter. +// config is like {"key":"collection key","conn":"connection info","dbNum":"0"} +// the cache item in redis are stored forever, +// so no gc operation. +func (rc *Cache) StartAndGC(config string) error { + var cf map[string]string + json.Unmarshal([]byte(config), &cf) + + if _, ok := cf["key"]; !ok { + cf["key"] = DefaultKey + } + if _, ok := cf["conn"]; !ok { + return errors.New("config has no conn key") + } + + // Format redis://@: + cf["conn"] = strings.Replace(cf["conn"], "redis://", "", 1) + if i := strings.Index(cf["conn"], "@"); i > -1 { + cf["password"] = cf["conn"][0:i] + cf["conn"] = cf["conn"][i+1:] + } + + if _, ok := cf["dbNum"]; !ok { + cf["dbNum"] = "0" + } + if _, ok := cf["password"]; !ok { + cf["password"] = "" + } + if _, ok := cf["maxIdle"]; !ok { + cf["maxIdle"] = "3" + } + if _, ok := cf["timeout"]; !ok { + cf["timeout"] = "180s" + } + rc.key = cf["key"] + rc.conninfo = cf["conn"] + rc.dbNum, _ = strconv.Atoi(cf["dbNum"]) + rc.password = cf["password"] + rc.maxIdle, _ = strconv.Atoi(cf["maxIdle"]) + + if v, err := time.ParseDuration(cf["timeout"]); err == nil { + rc.timeout = v + } else { + rc.timeout = 180 * time.Second + } + + rc.connectInit() + + c := rc.p.Get() + defer c.Close() + + return c.Err() +} + +// connect to redis. +func (rc *Cache) connectInit() { + dialFunc := func() (c redis.Conn, err error) { + c, err = redis.Dial("tcp", rc.conninfo) + if err != nil { + return nil, err + } + + if rc.password != "" { + if _, err := c.Do("AUTH", rc.password); err != nil { + c.Close() + return nil, err + } + } + + _, selecterr := c.Do("SELECT", rc.dbNum) + if selecterr != nil { + c.Close() + return nil, selecterr + } + return + } + // initialize a new pool + rc.p = &redis.Pool{ + MaxIdle: rc.maxIdle, + IdleTimeout: rc.timeout, + Dial: dialFunc, + } +} + +func init() { + cache.Register("redis", NewRedisCache) +} diff --git a/pkg/cache/redis/redis_test.go b/pkg/cache/redis/redis_test.go new file mode 100644 index 00000000..60a19180 --- /dev/null +++ b/pkg/cache/redis/redis_test.go @@ -0,0 +1,144 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package redis + +import ( + "fmt" + "testing" + "time" + + "github.com/astaxie/beego/cache" + "github.com/gomodule/redigo/redis" + "github.com/stretchr/testify/assert" +) + +func TestRedisCache(t *testing.T) { + bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + time.Sleep(11 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v, _ := redis.String(bm.Get("astaxie"), err); v != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if v, _ := redis.String(vv[0], nil); v != "author" { + t.Error("GetMulti ERROR") + } + if v, _ := redis.String(vv[1], nil); v != "author1" { + t.Error("GetMulti ERROR") + } + + // test clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } +} + +func TestCache_Scan(t *testing.T) { + timeoutDuration := 10 * time.Second + // init + bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) + if err != nil { + t.Error("init err") + } + // insert all + for i := 0; i < 10000; i++ { + if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { + t.Error("set Error", err) + } + } + // scan all for the first time + keys, err := bm.(*Cache).Scan(DefaultKey + ":*") + if err != nil { + t.Error("scan Error", err) + } + + assert.Equal(t, 10000, len(keys), "scan all error") + + // clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } + + // scan all for the second time + keys, err = bm.(*Cache).Scan(DefaultKey + ":*") + if err != nil { + t.Error("scan Error", err) + } + if len(keys) != 0 { + t.Error("scan all err") + } +} diff --git a/pkg/cache/ssdb/ssdb.go b/pkg/cache/ssdb/ssdb.go new file mode 100644 index 00000000..fa2ce04b --- /dev/null +++ b/pkg/cache/ssdb/ssdb.go @@ -0,0 +1,231 @@ +package ssdb + +import ( + "encoding/json" + "errors" + "strconv" + "strings" + "time" + + "github.com/ssdb/gossdb/ssdb" + + "github.com/astaxie/beego/cache" +) + +// Cache SSDB adapter +type Cache struct { + conn *ssdb.Client + conninfo []string +} + +//NewSsdbCache create new ssdb adapter. +func NewSsdbCache() cache.Cache { + return &Cache{} +} + +// Get get value from memcache. +func (rc *Cache) Get(key string) interface{} { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return nil + } + } + value, err := rc.conn.Get(key) + if err == nil { + return value + } + return nil +} + +// GetMulti get value from memcache. +func (rc *Cache) GetMulti(keys []string) []interface{} { + size := len(keys) + var values []interface{} + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + for i := 0; i < size; i++ { + values = append(values, err) + } + return values + } + } + res, err := rc.conn.Do("multi_get", keys) + resSize := len(res) + if err == nil { + for i := 1; i < resSize; i += 2 { + values = append(values, res[i+1]) + } + return values + } + for i := 0; i < size; i++ { + values = append(values, err) + } + return values +} + +// DelMulti get value from memcache. +func (rc *Cache) DelMulti(keys []string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Do("multi_del", keys) + return err +} + +// Put put value to memcache. only support string. +func (rc *Cache) Put(key string, value interface{}, timeout time.Duration) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + v, ok := value.(string) + if !ok { + return errors.New("value must string") + } + var resp []string + var err error + ttl := int(timeout / time.Second) + if ttl < 0 { + resp, err = rc.conn.Do("set", key, v) + } else { + resp, err = rc.conn.Do("setx", key, v, ttl) + } + if err != nil { + return err + } + if len(resp) == 2 && resp[0] == "ok" { + return nil + } + return errors.New("bad response") +} + +// Delete delete value in memcache. +func (rc *Cache) Delete(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Del(key) + return err +} + +// Incr increase counter. +func (rc *Cache) Incr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Do("incr", key, 1) + return err +} + +// Decr decrease counter. +func (rc *Cache) Decr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Do("incr", key, -1) + return err +} + +// IsExist check value exists in memcache. +func (rc *Cache) IsExist(key string) bool { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return false + } + } + resp, err := rc.conn.Do("exists", key) + if err != nil { + return false + } + if len(resp) == 2 && resp[1] == "1" { + return true + } + return false + +} + +// ClearAll clear all cached in memcache. +func (rc *Cache) ClearAll() error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + keyStart, keyEnd, limit := "", "", 50 + resp, err := rc.Scan(keyStart, keyEnd, limit) + for err == nil { + size := len(resp) + if size == 1 { + return nil + } + keys := []string{} + for i := 1; i < size; i += 2 { + keys = append(keys, resp[i]) + } + _, e := rc.conn.Do("multi_del", keys) + if e != nil { + return e + } + keyStart = resp[size-2] + resp, err = rc.Scan(keyStart, keyEnd, limit) + } + return err +} + +// Scan key all cached in ssdb. +func (rc *Cache) Scan(keyStart string, keyEnd string, limit int) ([]string, error) { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return nil, err + } + } + resp, err := rc.conn.Do("scan", keyStart, keyEnd, limit) + if err != nil { + return nil, err + } + return resp, nil +} + +// StartAndGC start memcache adapter. +// config string is like {"conn":"connection info"}. +// if connecting error, return. +func (rc *Cache) StartAndGC(config string) error { + var cf map[string]string + json.Unmarshal([]byte(config), &cf) + if _, ok := cf["conn"]; !ok { + return errors.New("config has no conn key") + } + rc.conninfo = strings.Split(cf["conn"], ";") + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return nil +} + +// connect to memcache and keep the connection. +func (rc *Cache) connectInit() error { + conninfoArray := strings.Split(rc.conninfo[0], ":") + host := conninfoArray[0] + port, e := strconv.Atoi(conninfoArray[1]) + if e != nil { + return e + } + var err error + rc.conn, err = ssdb.Connect(host, port) + return err +} + +func init() { + cache.Register("ssdb", NewSsdbCache) +} diff --git a/pkg/cache/ssdb/ssdb_test.go b/pkg/cache/ssdb/ssdb_test.go new file mode 100644 index 00000000..dd474960 --- /dev/null +++ b/pkg/cache/ssdb/ssdb_test.go @@ -0,0 +1,104 @@ +package ssdb + +import ( + "strconv" + "testing" + "time" + + "github.com/astaxie/beego/cache" +) + +func TestSsdbcacheCache(t *testing.T) { + ssdb, err := cache.NewCache("ssdb", `{"conn": "127.0.0.1:8888"}`) + if err != nil { + t.Error("init err") + } + + // test put and exist + if ssdb.IsExist("ssdb") { + t.Error("check err") + } + timeoutDuration := 10 * time.Second + //timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent + if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb") { + t.Error("check err") + } + + // Get test done + if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v := ssdb.Get("ssdb"); v != "ssdb" { + t.Error("get Error") + } + + //inc/dec test done + if err = ssdb.Put("ssdb", "2", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if err = ssdb.Incr("ssdb"); err != nil { + t.Error("incr Error", err) + } + + if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { + t.Error("get err") + } + + if err = ssdb.Decr("ssdb"); err != nil { + t.Error("decr error") + } + + // test del + if err = ssdb.Put("ssdb", "3", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { + t.Error("get err") + } + if err := ssdb.Delete("ssdb"); err == nil { + if ssdb.IsExist("ssdb") { + t.Error("delete err") + } + } + + //test string + if err = ssdb.Put("ssdb", "ssdb", -10*time.Second); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb") { + t.Error("check err") + } + if v := ssdb.Get("ssdb").(string); v != "ssdb" { + t.Error("get err") + } + + //test GetMulti done + if err = ssdb.Put("ssdb1", "ssdb1", -10*time.Second); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb1") { + t.Error("check err") + } + vv := ssdb.GetMulti([]string{"ssdb", "ssdb1"}) + if len(vv) != 2 { + t.Error("getmulti error") + } + if vv[0].(string) != "ssdb" { + t.Error("getmulti error") + } + if vv[1].(string) != "ssdb1" { + t.Error("getmulti error") + } + + // test clear all done + if err = ssdb.ClearAll(); err != nil { + t.Error("clear all err") + } + if ssdb.IsExist("ssdb") || ssdb.IsExist("ssdb1") { + t.Error("check err") + } +} diff --git a/pkg/config.go b/pkg/config.go new file mode 100644 index 00000000..b6c9a99c --- /dev/null +++ b/pkg/config.go @@ -0,0 +1,524 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "runtime" + "strings" + + "github.com/astaxie/beego/config" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/session" + "github.com/astaxie/beego/utils" +) + +// Config is the main struct for BConfig +type Config struct { + AppName string //Application name + RunMode string //Running Mode: dev | prod + RouterCaseSensitive bool + ServerName string + RecoverPanic bool + RecoverFunc func(*context.Context) + CopyRequestBody bool + EnableGzip bool + MaxMemory int64 + EnableErrorsShow bool + EnableErrorsRender bool + Listen Listen + WebConfig WebConfig + Log LogConfig +} + +// Listen holds for http and https related config +type Listen struct { + Graceful bool // Graceful means use graceful module to start the server + ServerTimeOut int64 + ListenTCP4 bool + EnableHTTP bool + HTTPAddr string + HTTPPort int + AutoTLS bool + Domains []string + TLSCacheDir string + EnableHTTPS bool + EnableMutualHTTPS bool + HTTPSAddr string + HTTPSPort int + HTTPSCertFile string + HTTPSKeyFile string + TrustCaFile string + EnableAdmin bool + AdminAddr string + AdminPort int + EnableFcgi bool + EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O +} + +// WebConfig holds web related config +type WebConfig struct { + AutoRender bool + EnableDocs bool + FlashName string + FlashSeparator string + DirectoryIndex bool + StaticDir map[string]string + StaticExtensionsToGzip []string + StaticCacheFileSize int + StaticCacheFileNum int + TemplateLeft string + TemplateRight string + ViewsPath string + EnableXSRF bool + XSRFKey string + XSRFExpire int + Session SessionConfig +} + +// SessionConfig holds session related config +type SessionConfig struct { + SessionOn bool + SessionProvider string + SessionName string + SessionGCMaxLifetime int64 + SessionProviderConfig string + SessionCookieLifeTime int + SessionAutoSetCookie bool + SessionDomain string + SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies. + SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers + SessionNameInHTTPHeader string + SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params +} + +// LogConfig holds Log related config +type LogConfig struct { + AccessLogs bool + EnableStaticLogs bool //log static files requests default: false + AccessLogsFormat string //access log format: JSON_FORMAT, APACHE_FORMAT or empty string + FileLineNum bool + Outputs map[string]string // Store Adaptor : config +} + +var ( + // BConfig is the default config for Application + BConfig *Config + // AppConfig is the instance of Config, store the config information from file + AppConfig *beegoAppConfig + // AppPath is the absolute path to the app + AppPath string + // GlobalSessions is the instance for the session manager + GlobalSessions *session.Manager + + // appConfigPath is the path to the config files + appConfigPath string + // appConfigProvider is the provider for the config, default is ini + appConfigProvider = "ini" + // WorkPath is the absolute path to project root directory + WorkPath string +) + +func init() { + BConfig = newBConfig() + var err error + if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil { + panic(err) + } + WorkPath, err = os.Getwd() + if err != nil { + panic(err) + } + var filename = "app.conf" + if os.Getenv("BEEGO_RUNMODE") != "" { + filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf" + } + appConfigPath = filepath.Join(WorkPath, "conf", filename) + if !utils.FileExists(appConfigPath) { + appConfigPath = filepath.Join(AppPath, "conf", filename) + if !utils.FileExists(appConfigPath) { + AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()} + return + } + } + if err = parseConfig(appConfigPath); err != nil { + panic(err) + } +} + +func recoverPanic(ctx *context.Context) { + if err := recover(); err != nil { + if err == ErrAbort { + return + } + if !BConfig.RecoverPanic { + panic(err) + } + if BConfig.EnableErrorsShow { + if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { + exception(fmt.Sprint(err), ctx) + return + } + } + var stack string + logs.Critical("the request url is ", ctx.Input.URL()) + logs.Critical("Handler crashed with error", err) + for i := 1; ; i++ { + _, file, line, ok := runtime.Caller(i) + if !ok { + break + } + logs.Critical(fmt.Sprintf("%s:%d", file, line)) + stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) + } + if BConfig.RunMode == DEV && BConfig.EnableErrorsRender { + showErr(err, ctx, stack) + } + if ctx.Output.Status != 0 { + ctx.ResponseWriter.WriteHeader(ctx.Output.Status) + } else { + ctx.ResponseWriter.WriteHeader(500) + } + } +} + +func newBConfig() *Config { + return &Config{ + AppName: "beego", + RunMode: PROD, + RouterCaseSensitive: true, + ServerName: "beegoServer:" + VERSION, + RecoverPanic: true, + RecoverFunc: recoverPanic, + CopyRequestBody: false, + EnableGzip: false, + MaxMemory: 1 << 26, //64MB + EnableErrorsShow: true, + EnableErrorsRender: true, + Listen: Listen{ + Graceful: false, + ServerTimeOut: 0, + ListenTCP4: false, + EnableHTTP: true, + AutoTLS: false, + Domains: []string{}, + TLSCacheDir: ".", + HTTPAddr: "", + HTTPPort: 8080, + EnableHTTPS: false, + HTTPSAddr: "", + HTTPSPort: 10443, + HTTPSCertFile: "", + HTTPSKeyFile: "", + EnableAdmin: false, + AdminAddr: "", + AdminPort: 8088, + EnableFcgi: false, + EnableStdIo: false, + }, + WebConfig: WebConfig{ + AutoRender: true, + EnableDocs: false, + FlashName: "BEEGO_FLASH", + FlashSeparator: "BEEGOFLASH", + DirectoryIndex: false, + StaticDir: map[string]string{"/static": "static"}, + StaticExtensionsToGzip: []string{".css", ".js"}, + StaticCacheFileSize: 1024 * 100, + StaticCacheFileNum: 1000, + TemplateLeft: "{{", + TemplateRight: "}}", + ViewsPath: "views", + EnableXSRF: false, + XSRFKey: "beegoxsrf", + XSRFExpire: 0, + Session: SessionConfig{ + SessionOn: false, + SessionProvider: "memory", + SessionName: "beegosessionID", + SessionGCMaxLifetime: 3600, + SessionProviderConfig: "", + SessionDisableHTTPOnly: false, + SessionCookieLifeTime: 0, //set cookie default is the browser life + SessionAutoSetCookie: true, + SessionDomain: "", + SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers + SessionNameInHTTPHeader: "Beegosessionid", + SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params + }, + }, + Log: LogConfig{ + AccessLogs: false, + EnableStaticLogs: false, + AccessLogsFormat: "APACHE_FORMAT", + FileLineNum: true, + Outputs: map[string]string{"console": ""}, + }, + } +} + +// now only support ini, next will support json. +func parseConfig(appConfigPath string) (err error) { + AppConfig, err = newAppConfig(appConfigProvider, appConfigPath) + if err != nil { + return err + } + return assignConfig(AppConfig) +} + +func assignConfig(ac config.Configer) error { + for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} { + assignSingleConfig(i, ac) + } + // set the run mode first + if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { + BConfig.RunMode = envRunMode + } else if runMode := ac.String("RunMode"); runMode != "" { + BConfig.RunMode = runMode + } + + if sd := ac.String("StaticDir"); sd != "" { + BConfig.WebConfig.StaticDir = map[string]string{} + sds := strings.Fields(sd) + for _, v := range sds { + if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { + BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[1] + } else { + BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[0] + } + } + } + + if sgz := ac.String("StaticExtensionsToGzip"); sgz != "" { + extensions := strings.Split(sgz, ",") + fileExts := []string{} + for _, ext := range extensions { + ext = strings.TrimSpace(ext) + if ext == "" { + continue + } + if !strings.HasPrefix(ext, ".") { + ext = "." + ext + } + fileExts = append(fileExts, ext) + } + if len(fileExts) > 0 { + BConfig.WebConfig.StaticExtensionsToGzip = fileExts + } + } + + if sfs, err := ac.Int("StaticCacheFileSize"); err == nil { + BConfig.WebConfig.StaticCacheFileSize = sfs + } + + if sfn, err := ac.Int("StaticCacheFileNum"); err == nil { + BConfig.WebConfig.StaticCacheFileNum = sfn + } + + if lo := ac.String("LogOutputs"); lo != "" { + // if lo is not nil or empty + // means user has set his own LogOutputs + // clear the default setting to BConfig.Log.Outputs + BConfig.Log.Outputs = make(map[string]string) + los := strings.Split(lo, ";") + for _, v := range los { + if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 { + BConfig.Log.Outputs[logType2Config[0]] = logType2Config[1] + } else { + continue + } + } + } + + //init log + logs.Reset() + for adaptor, config := range BConfig.Log.Outputs { + err := logs.SetLogger(adaptor, config) + if err != nil { + fmt.Fprintln(os.Stderr, fmt.Sprintf("%s with the config %q got err:%s", adaptor, config, err.Error())) + } + } + logs.SetLogFuncCall(BConfig.Log.FileLineNum) + + return nil +} + +func assignSingleConfig(p interface{}, ac config.Configer) { + pt := reflect.TypeOf(p) + if pt.Kind() != reflect.Ptr { + return + } + pt = pt.Elem() + if pt.Kind() != reflect.Struct { + return + } + pv := reflect.ValueOf(p).Elem() + + for i := 0; i < pt.NumField(); i++ { + pf := pv.Field(i) + if !pf.CanSet() { + continue + } + name := pt.Field(i).Name + switch pf.Kind() { + case reflect.String: + pf.SetString(ac.DefaultString(name, pf.String())) + case reflect.Int, reflect.Int64: + pf.SetInt(ac.DefaultInt64(name, pf.Int())) + case reflect.Bool: + pf.SetBool(ac.DefaultBool(name, pf.Bool())) + case reflect.Struct: + default: + //do nothing here + } + } + +} + +// LoadAppConfig allow developer to apply a config file +func LoadAppConfig(adapterName, configPath string) error { + absConfigPath, err := filepath.Abs(configPath) + if err != nil { + return err + } + + if !utils.FileExists(absConfigPath) { + return fmt.Errorf("the target config file: %s don't exist", configPath) + } + + appConfigPath = absConfigPath + appConfigProvider = adapterName + + return parseConfig(appConfigPath) +} + +type beegoAppConfig struct { + innerConfig config.Configer +} + +func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, error) { + ac, err := config.NewConfig(appConfigProvider, appConfigPath) + if err != nil { + return nil, err + } + return &beegoAppConfig{ac}, nil +} + +func (b *beegoAppConfig) Set(key, val string) error { + if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { + return b.innerConfig.Set(key, val) + } + return nil +} + +func (b *beegoAppConfig) String(key string) string { + if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" { + return v + } + return b.innerConfig.String(key) +} + +func (b *beegoAppConfig) Strings(key string) []string { + if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 { + return v + } + return b.innerConfig.Strings(key) +} + +func (b *beegoAppConfig) Int(key string) (int, error) { + if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Int(key) +} + +func (b *beegoAppConfig) Int64(key string) (int64, error) { + if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Int64(key) +} + +func (b *beegoAppConfig) Bool(key string) (bool, error) { + if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Bool(key) +} + +func (b *beegoAppConfig) Float(key string) (float64, error) { + if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Float(key) +} + +func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { + if v := b.String(key); v != "" { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { + if v := b.Strings(key); len(v) != 0 { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { + if v, err := b.Int(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { + if v, err := b.Int64(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { + if v, err := b.Bool(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { + if v, err := b.Float(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DIY(key string) (interface{}, error) { + return b.innerConfig.DIY(key) +} + +func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { + return b.innerConfig.GetSection(section) +} + +func (b *beegoAppConfig) SaveConfigFile(filename string) error { + return b.innerConfig.SaveConfigFile(filename) +} diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 00000000..bfd79e85 --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,242 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package config is used to parse config. +// Usage: +// import "github.com/astaxie/beego/config" +//Examples. +// +// cnf, err := config.NewConfig("ini", "config.conf") +// +// cnf APIS: +// +// cnf.Set(key, val string) error +// cnf.String(key string) string +// cnf.Strings(key string) []string +// cnf.Int(key string) (int, error) +// cnf.Int64(key string) (int64, error) +// cnf.Bool(key string) (bool, error) +// cnf.Float(key string) (float64, error) +// cnf.DefaultString(key string, defaultVal string) string +// cnf.DefaultStrings(key string, defaultVal []string) []string +// cnf.DefaultInt(key string, defaultVal int) int +// cnf.DefaultInt64(key string, defaultVal int64) int64 +// cnf.DefaultBool(key string, defaultVal bool) bool +// cnf.DefaultFloat(key string, defaultVal float64) float64 +// cnf.DIY(key string) (interface{}, error) +// cnf.GetSection(section string) (map[string]string, error) +// cnf.SaveConfigFile(filename string) error +//More docs http://beego.me/docs/module/config.md +package config + +import ( + "fmt" + "os" + "reflect" + "time" +) + +// Configer defines how to get and set value from configuration raw data. +type Configer interface { + Set(key, val string) error //support section::key type in given key when using ini type. + String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + Strings(key string) []string //get string slice + Int(key string) (int, error) + Int64(key string) (int64, error) + Bool(key string) (bool, error) + Float(key string) (float64, error) + DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + DefaultStrings(key string, defaultVal []string) []string //get string slice + DefaultInt(key string, defaultVal int) int + DefaultInt64(key string, defaultVal int64) int64 + DefaultBool(key string, defaultVal bool) bool + DefaultFloat(key string, defaultVal float64) float64 + DIY(key string) (interface{}, error) + GetSection(section string) (map[string]string, error) + SaveConfigFile(filename string) error +} + +// Config is the adapter interface for parsing config file to get raw data to Configer. +type Config interface { + Parse(key string) (Configer, error) + ParseData(data []byte) (Configer, error) +} + +var adapters = make(map[string]Config) + +// Register makes a config adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Config) { + if adapter == nil { + panic("config: Register adapter is nil") + } + if _, ok := adapters[name]; ok { + panic("config: Register called twice for adapter " + name) + } + adapters[name] = adapter +} + +// NewConfig adapterName is ini/json/xml/yaml. +// filename is the config file path. +func NewConfig(adapterName, filename string) (Configer, error) { + adapter, ok := adapters[adapterName] + if !ok { + return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) + } + return adapter.Parse(filename) +} + +// NewConfigData adapterName is ini/json/xml/yaml. +// data is the config data. +func NewConfigData(adapterName string, data []byte) (Configer, error) { + adapter, ok := adapters[adapterName] + if !ok { + return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) + } + return adapter.ParseData(data) +} + +// ExpandValueEnvForMap convert all string value with environment variable. +func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} { + for k, v := range m { + switch value := v.(type) { + case string: + m[k] = ExpandValueEnv(value) + case map[string]interface{}: + m[k] = ExpandValueEnvForMap(value) + case map[string]string: + for k2, v2 := range value { + value[k2] = ExpandValueEnv(v2) + } + m[k] = value + } + } + return m +} + +// ExpandValueEnv returns value of convert with environment variable. +// +// Return environment variable if value start with "${" and end with "}". +// Return default value if environment variable is empty or not exist. +// +// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue". +// Examples: +// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable. +// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/". +// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie". +func ExpandValueEnv(value string) (realValue string) { + realValue = value + + vLen := len(value) + // 3 = ${} + if vLen < 3 { + return + } + // Need start with "${" and end with "}", then return. + if value[0] != '$' || value[1] != '{' || value[vLen-1] != '}' { + return + } + + key := "" + defaultV := "" + // value start with "${" + for i := 2; i < vLen; i++ { + if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') { + key = value[2:i] + defaultV = value[i+2 : vLen-1] // other string is default value. + break + } else if value[i] == '}' { + key = value[2:i] + break + } + } + + realValue = os.Getenv(key) + if realValue == "" { + realValue = defaultV + } + + return +} + +// ParseBool returns the boolean value represented by the string. +// +// It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On, +// 0, 0.0, f, F, FALSE, false, False, NO, no, No, N,n, OFF, off, Off. +// Any other value returns an error. +func ParseBool(val interface{}) (value bool, err error) { + if val != nil { + switch v := val.(type) { + case bool: + return v, nil + case string: + switch v { + case "1", "t", "T", "true", "TRUE", "True", "YES", "yes", "Yes", "Y", "y", "ON", "on", "On": + return true, nil + case "0", "f", "F", "false", "FALSE", "False", "NO", "no", "No", "N", "n", "OFF", "off", "Off": + return false, nil + } + case int8, int32, int64: + strV := fmt.Sprintf("%d", v) + if strV == "1" { + return true, nil + } else if strV == "0" { + return false, nil + } + case float64: + if v == 1.0 { + return true, nil + } else if v == 0.0 { + return false, nil + } + } + return false, fmt.Errorf("parsing %q: invalid syntax", val) + } + return false, fmt.Errorf("parsing : invalid syntax") +} + +// ToString converts values of any type to string. +func ToString(x interface{}) string { + switch y := x.(type) { + + // Handle dates with special logic + // This needs to come above the fmt.Stringer + // test since time.Time's have a .String() + // method + case time.Time: + return y.Format("A Monday") + + // Handle type string + case string: + return y + + // Handle type with .String() method + case fmt.Stringer: + return y.String() + + // Handle type with .Error() method + case error: + return y.Error() + + } + + // Handle named string type + if v := reflect.ValueOf(x); v.Kind() == reflect.String { + return v.String() + } + + // Fallback to fmt package for anything else like numeric types + return fmt.Sprint(x) +} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 00000000..15d6ffa6 --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,55 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "os" + "testing" +) + +func TestExpandValueEnv(t *testing.T) { + + testCases := []struct { + item string + want string + }{ + {"", ""}, + {"$", "$"}, + {"{", "{"}, + {"{}", "{}"}, + {"${}", ""}, + {"${|}", ""}, + {"${}", ""}, + {"${{}}", ""}, + {"${{||}}", "}"}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}}", "}"}, + {"${pwd||{{||}}}", "{{||}}"}, + {"${GOPATH}", os.Getenv("GOPATH")}, + {"${GOPATH||}", os.Getenv("GOPATH")}, + {"${GOPATH||root}", os.Getenv("GOPATH")}, + {"${GOPATH_NOT||root}", "root"}, + {"${GOPATH_NOT||||root}", "||root"}, + } + + for _, c := range testCases { + if got := ExpandValueEnv(c.item); got != c.want { + t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got) + } + } + +} diff --git a/pkg/config/env/env.go b/pkg/config/env/env.go new file mode 100644 index 00000000..34f094fe --- /dev/null +++ b/pkg/config/env/env.go @@ -0,0 +1,87 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package env is used to parse environment. +package env + +import ( + "fmt" + "os" + "strings" + + "github.com/astaxie/beego/utils" +) + +var env *utils.BeeMap + +func init() { + env = utils.NewBeeMap() + for _, e := range os.Environ() { + splits := strings.Split(e, "=") + env.Set(splits[0], os.Getenv(splits[0])) + } +} + +// Get returns a value by key. +// If the key does not exist, the default value will be returned. +func Get(key string, defVal string) string { + if val := env.Get(key); val != nil { + return val.(string) + } + return defVal +} + +// MustGet returns a value by key. +// If the key does not exist, it will return an error. +func MustGet(key string) (string, error) { + if val := env.Get(key); val != nil { + return val.(string), nil + } + return "", fmt.Errorf("no env variable with %s", key) +} + +// Set sets a value in the ENV copy. +// This does not affect the child process environment. +func Set(key string, value string) { + env.Set(key, value) +} + +// MustSet sets a value in the ENV copy and the child process environment. +// It returns an error in case the set operation failed. +func MustSet(key string, value string) error { + err := os.Setenv(key, value) + if err != nil { + return err + } + env.Set(key, value) + return nil +} + +// GetAll returns all keys/values in the current child process environment. +func GetAll() map[string]string { + items := env.Items() + envs := make(map[string]string, env.Count()) + + for key, val := range items { + switch key := key.(type) { + case string: + switch val := val.(type) { + case string: + envs[key] = val + } + } + } + return envs +} diff --git a/pkg/config/env/env_test.go b/pkg/config/env/env_test.go new file mode 100644 index 00000000..3f1d4dba --- /dev/null +++ b/pkg/config/env/env_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package env + +import ( + "os" + "testing" +) + +func TestEnvGet(t *testing.T) { + gopath := Get("GOPATH", "") + if gopath != os.Getenv("GOPATH") { + t.Error("expected GOPATH not empty.") + } + + noExistVar := Get("NOEXISTVAR", "foo") + if noExistVar != "foo" { + t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar) + } +} + +func TestEnvMustGet(t *testing.T) { + gopath, err := MustGet("GOPATH") + if err != nil { + t.Error(err) + } + + if gopath != os.Getenv("GOPATH") { + t.Errorf("expected GOPATH to be the same, got %s.", gopath) + } + + _, err = MustGet("NOEXISTVAR") + if err == nil { + t.Error("expected error to be non-nil") + } +} + +func TestEnvSet(t *testing.T) { + Set("MYVAR", "foo") + myVar := Get("MYVAR", "bar") + if myVar != "foo" { + t.Errorf("expected MYVAR to equal foo, got %s.", myVar) + } +} + +func TestEnvMustSet(t *testing.T) { + err := MustSet("FOO", "bar") + if err != nil { + t.Error(err) + } + + fooVar := os.Getenv("FOO") + if fooVar != "bar" { + t.Errorf("expected FOO variable to equal bar, got %s.", fooVar) + } +} + +func TestEnvGetAll(t *testing.T) { + envMap := GetAll() + if len(envMap) == 0 { + t.Error("expected environment not empty.") + } +} diff --git a/pkg/config/fake.go b/pkg/config/fake.go new file mode 100644 index 00000000..d21ab820 --- /dev/null +++ b/pkg/config/fake.go @@ -0,0 +1,134 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "errors" + "strconv" + "strings" +) + +type fakeConfigContainer struct { + data map[string]string +} + +func (c *fakeConfigContainer) getData(key string) string { + return c.data[strings.ToLower(key)] +} + +func (c *fakeConfigContainer) Set(key, val string) error { + c.data[strings.ToLower(key)] = val + return nil +} + +func (c *fakeConfigContainer) String(key string) string { + return c.getData(key) +} + +func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Int(key string) (int, error) { + return strconv.Atoi(c.getData(key)) +} + +func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Int64(key string) (int64, error) { + return strconv.ParseInt(c.getData(key), 10, 64) +} + +func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Bool(key string) (bool, error) { + return ParseBool(c.getData(key)) +} + +func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Float(key string) (float64, error) { + return strconv.ParseFloat(c.getData(key), 64) +} + +func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { + if v, ok := c.data[strings.ToLower(key)]; ok { + return v, nil + } + return nil, errors.New("key not find") +} + +func (c *fakeConfigContainer) GetSection(section string) (map[string]string, error) { + return nil, errors.New("not implement in the fakeConfigContainer") +} + +func (c *fakeConfigContainer) SaveConfigFile(filename string) error { + return errors.New("not implement in the fakeConfigContainer") +} + +var _ Configer = new(fakeConfigContainer) + +// NewFakeConfig return a fake Configer +func NewFakeConfig() Configer { + return &fakeConfigContainer{ + data: make(map[string]string), + } +} diff --git a/pkg/config/ini.go b/pkg/config/ini.go new file mode 100644 index 00000000..002e5e05 --- /dev/null +++ b/pkg/config/ini.go @@ -0,0 +1,504 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "bufio" + "bytes" + "errors" + "io" + "io/ioutil" + "os" + "os/user" + "path/filepath" + "strconv" + "strings" + "sync" +) + +var ( + defaultSection = "default" // default section means if some ini items not in a section, make them in default section, + bNumComment = []byte{'#'} // number signal + bSemComment = []byte{';'} // semicolon signal + bEmpty = []byte{} + bEqual = []byte{'='} // equal signal + bDQuote = []byte{'"'} // quote signal + sectionStart = []byte{'['} // section start signal + sectionEnd = []byte{']'} // section end signal + lineBreak = "\n" +) + +// IniConfig implements Config to parse ini file. +type IniConfig struct { +} + +// Parse creates a new Config and parses the file configuration from the named file. +func (ini *IniConfig) Parse(name string) (Configer, error) { + return ini.parseFile(name) +} + +func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { + data, err := ioutil.ReadFile(name) + if err != nil { + return nil, err + } + + return ini.parseData(filepath.Dir(name), data) +} + +func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, error) { + cfg := &IniConfigContainer{ + data: make(map[string]map[string]string), + sectionComment: make(map[string]string), + keyComment: make(map[string]string), + RWMutex: sync.RWMutex{}, + } + cfg.Lock() + defer cfg.Unlock() + + var comment bytes.Buffer + buf := bufio.NewReader(bytes.NewBuffer(data)) + // check the BOM + head, err := buf.Peek(3) + if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 { + for i := 1; i <= 3; i++ { + buf.ReadByte() + } + } + section := defaultSection + tmpBuf := bytes.NewBuffer(nil) + for { + tmpBuf.Reset() + + shouldBreak := false + for { + tmp, isPrefix, err := buf.ReadLine() + if err == io.EOF { + shouldBreak = true + break + } + + //It might be a good idea to throw a error on all unknonw errors? + if _, ok := err.(*os.PathError); ok { + return nil, err + } + + tmpBuf.Write(tmp) + if isPrefix { + continue + } + + if !isPrefix { + break + } + } + if shouldBreak { + break + } + + line := tmpBuf.Bytes() + line = bytes.TrimSpace(line) + if bytes.Equal(line, bEmpty) { + continue + } + var bComment []byte + switch { + case bytes.HasPrefix(line, bNumComment): + bComment = bNumComment + case bytes.HasPrefix(line, bSemComment): + bComment = bSemComment + } + if bComment != nil { + line = bytes.TrimLeft(line, string(bComment)) + // Need append to a new line if multi-line comments. + if comment.Len() > 0 { + comment.WriteByte('\n') + } + comment.Write(line) + continue + } + + if bytes.HasPrefix(line, sectionStart) && bytes.HasSuffix(line, sectionEnd) { + section = strings.ToLower(string(line[1 : len(line)-1])) // section name case insensitive + if comment.Len() > 0 { + cfg.sectionComment[section] = comment.String() + comment.Reset() + } + if _, ok := cfg.data[section]; !ok { + cfg.data[section] = make(map[string]string) + } + continue + } + + if _, ok := cfg.data[section]; !ok { + cfg.data[section] = make(map[string]string) + } + keyValue := bytes.SplitN(line, bEqual, 2) + + key := string(bytes.TrimSpace(keyValue[0])) // key name case insensitive + key = strings.ToLower(key) + + // handle include "other.conf" + if len(keyValue) == 1 && strings.HasPrefix(key, "include") { + + includefiles := strings.Fields(key) + if includefiles[0] == "include" && len(includefiles) == 2 { + + otherfile := strings.Trim(includefiles[1], "\"") + if !filepath.IsAbs(otherfile) { + otherfile = filepath.Join(dir, otherfile) + } + + i, err := ini.parseFile(otherfile) + if err != nil { + return nil, err + } + + for sec, dt := range i.data { + if _, ok := cfg.data[sec]; !ok { + cfg.data[sec] = make(map[string]string) + } + for k, v := range dt { + cfg.data[sec][k] = v + } + } + + for sec, comm := range i.sectionComment { + cfg.sectionComment[sec] = comm + } + + for k, comm := range i.keyComment { + cfg.keyComment[k] = comm + } + + continue + } + } + + if len(keyValue) != 2 { + return nil, errors.New("read the content error: \"" + string(line) + "\", should key = val") + } + val := bytes.TrimSpace(keyValue[1]) + if bytes.HasPrefix(val, bDQuote) { + val = bytes.Trim(val, `"`) + } + + cfg.data[section][key] = ExpandValueEnv(string(val)) + if comment.Len() > 0 { + cfg.keyComment[section+"."+key] = comment.String() + comment.Reset() + } + + } + return cfg, nil +} + +// ParseData parse ini the data +// When include other.conf,other.conf is either absolute directory +// or under beego in default temporary directory(/tmp/beego[-username]). +func (ini *IniConfig) ParseData(data []byte) (Configer, error) { + dir := "beego" + currentUser, err := user.Current() + if err == nil { + dir = "beego-" + currentUser.Username + } + dir = filepath.Join(os.TempDir(), dir) + if err = os.MkdirAll(dir, os.ModePerm); err != nil { + return nil, err + } + + return ini.parseData(dir, data) +} + +// IniConfigContainer A Config represents the ini configuration. +// When set and get value, support key as section:name type. +type IniConfigContainer struct { + data map[string]map[string]string // section=> key:val + sectionComment map[string]string // section : comment + keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment. + sync.RWMutex +} + +// Bool returns the boolean value for a given key. +func (c *IniConfigContainer) Bool(key string) (bool, error) { + return ParseBool(c.getdata(key)) +} + +// DefaultBool returns the boolean value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +// Int returns the integer value for a given key. +func (c *IniConfigContainer) Int(key string) (int, error) { + return strconv.Atoi(c.getdata(key)) +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +// Int64 returns the int64 value for a given key. +func (c *IniConfigContainer) Int64(key string) (int64, error) { + return strconv.ParseInt(c.getdata(key), 10, 64) +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v +} + +// Float returns the float value for a given key. +func (c *IniConfigContainer) Float(key string) (float64, error) { + return strconv.ParseFloat(c.getdata(key), 64) +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +// String returns the string value for a given key. +func (c *IniConfigContainer) String(key string) string { + return c.getdata(key) +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +// Strings returns the []string value for a given key. +// Return nil if config value does not exist or is empty. +func (c *IniConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +// GetSection returns map for the given section +func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) { + if v, ok := c.data[section]; ok { + return v, nil + } + return nil, errors.New("not exist section") +} + +// SaveConfigFile save the config into file. +// +// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Function. +func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + + // Get section or key comments. Fixed #1607 + getCommentStr := func(section, key string) string { + var ( + comment string + ok bool + ) + if len(key) == 0 { + comment, ok = c.sectionComment[section] + } else { + comment, ok = c.keyComment[section+"."+key] + } + + if ok { + // Empty comment + if len(comment) == 0 || len(strings.TrimSpace(comment)) == 0 { + return string(bNumComment) + } + prefix := string(bNumComment) + // Add the line head character "#" + return prefix + strings.Replace(comment, lineBreak, lineBreak+prefix, -1) + } + return "" + } + + buf := bytes.NewBuffer(nil) + // Save default section at first place + if dt, ok := c.data[defaultSection]; ok { + for key, val := range dt { + if key != " " { + // Write key comments. + if v := getCommentStr(defaultSection, key); len(v) > 0 { + if _, err = buf.WriteString(v + lineBreak); err != nil { + return err + } + } + + // Write key and value. + if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil { + return err + } + } + } + + // Put a line between sections. + if _, err = buf.WriteString(lineBreak); err != nil { + return err + } + } + // Save named sections + for section, dt := range c.data { + if section != defaultSection { + // Write section comments. + if v := getCommentStr(section, ""); len(v) > 0 { + if _, err = buf.WriteString(v + lineBreak); err != nil { + return err + } + } + + // Write section name. + if _, err = buf.WriteString(string(sectionStart) + section + string(sectionEnd) + lineBreak); err != nil { + return err + } + + for key, val := range dt { + if key != " " { + // Write key comments. + if v := getCommentStr(section, key); len(v) > 0 { + if _, err = buf.WriteString(v + lineBreak); err != nil { + return err + } + } + + // Write key and value. + if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil { + return err + } + } + } + + // Put a line between sections. + if _, err = buf.WriteString(lineBreak); err != nil { + return err + } + } + } + _, err = buf.WriteTo(f) + return err +} + +// Set writes a new value for key. +// if write to one section, the key need be "section::key". +// if the section is not existed, it panics. +func (c *IniConfigContainer) Set(key, value string) error { + c.Lock() + defer c.Unlock() + if len(key) == 0 { + return errors.New("key is empty") + } + + var ( + section, k string + sectionKey = strings.Split(strings.ToLower(key), "::") + ) + + if len(sectionKey) >= 2 { + section = sectionKey[0] + k = sectionKey[1] + } else { + section = defaultSection + k = sectionKey[0] + } + + if _, ok := c.data[section]; !ok { + c.data[section] = make(map[string]string) + } + c.data[section][k] = value + return nil +} + +// DIY returns the raw value by a given key. +func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) { + if v, ok := c.data[strings.ToLower(key)]; ok { + return v, nil + } + return v, errors.New("key not find") +} + +// section.key or key +func (c *IniConfigContainer) getdata(key string) string { + if len(key) == 0 { + return "" + } + c.RLock() + defer c.RUnlock() + + var ( + section, k string + sectionKey = strings.Split(strings.ToLower(key), "::") + ) + if len(sectionKey) >= 2 { + section = sectionKey[0] + k = sectionKey[1] + } else { + section = defaultSection + k = sectionKey[0] + } + if v, ok := c.data[section]; ok { + if vv, ok := v[k]; ok { + return vv + } + } + return "" +} + +func init() { + Register("ini", &IniConfig{}) +} diff --git a/pkg/config/ini_test.go b/pkg/config/ini_test.go new file mode 100644 index 00000000..ffcdb294 --- /dev/null +++ b/pkg/config/ini_test.go @@ -0,0 +1,190 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" +) + +func TestIni(t *testing.T) { + + var ( + inicontext = ` +;comment one +#comment two +appname = beeapi +httpport = 8080 +mysqlport = 3600 +PI = 3.1415976 +runmode = "dev" +autorender = false +copyrequestbody = true +session= on +cookieon= off +newreg = OFF +needlogin = ON +enableSession = Y +enableCookie = N +flag = 1 +path1 = ${GOPATH} +path2 = ${GOPATH||/home/go} +[demo] +key1="asta" +key2 = "xie" +CaseInsensitive = true +peers = one;two;three +password = ${GOPATH} +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "pi": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "demo::key1": "asta", + "demo::key2": "xie", + "demo::CaseInsensitive": true, + "demo::peers": []string{"one", "two", "three"}, + "demo::password": os.Getenv("GOPATH"), + "null": "", + "demo2::key1": "", + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testini.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(inicontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testini.conf") + iniconf, err := NewConfig("ini", "testini.conf") + if err != nil { + t.Fatal(err) + } + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = iniconf.Int(k) + case int64: + value, err = iniconf.Int64(k) + case float64: + value, err = iniconf.Float(k) + case bool: + value, err = iniconf.Bool(k) + case []string: + value = iniconf.Strings(k) + case string: + value = iniconf.String(k) + default: + value, err = iniconf.DIY(k) + } + if err != nil { + t.Fatalf("get key %q value fail,err %s", k, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = iniconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if iniconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + +} + +func TestIniSave(t *testing.T) { + + const ( + inicontext = ` +app = app +;comment one +#comment two +# comment three +appname = beeapi +httpport = 8080 +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name = mysql +` + + saveResult = ` +app=app +#comment one +#comment two +# comment three +appname=beeapi +httpport=8080 + +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name=mysql +` + ) + cfg, err := NewConfigData("ini", []byte(inicontext)) + if err != nil { + t.Fatal(err) + } + name := "newIniConfig.ini" + if err := cfg.SaveConfigFile(name); err != nil { + t.Fatal(err) + } + defer os.Remove(name) + + if data, err := ioutil.ReadFile(name); err != nil { + t.Fatal(err) + } else { + cfgData := string(data) + datas := strings.Split(saveResult, "\n") + for _, line := range datas { + if !strings.Contains(cfgData, line+"\n") { + t.Fatalf("different after save ini config file. need contains %q", line) + } + } + + } +} diff --git a/pkg/config/json.go b/pkg/config/json.go new file mode 100644 index 00000000..c4ef25cd --- /dev/null +++ b/pkg/config/json.go @@ -0,0 +1,269 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "os" + "strconv" + "strings" + "sync" +) + +// JSONConfig is a json config parser and implements Config interface. +type JSONConfig struct { +} + +// Parse returns a ConfigContainer with parsed json config map. +func (js *JSONConfig) Parse(filename string) (Configer, error) { + file, err := os.Open(filename) + if err != nil { + return nil, err + } + defer file.Close() + content, err := ioutil.ReadAll(file) + if err != nil { + return nil, err + } + + return js.ParseData(content) +} + +// ParseData returns a ConfigContainer with json string +func (js *JSONConfig) ParseData(data []byte) (Configer, error) { + x := &JSONConfigContainer{ + data: make(map[string]interface{}), + } + err := json.Unmarshal(data, &x.data) + if err != nil { + var wrappingArray []interface{} + err2 := json.Unmarshal(data, &wrappingArray) + if err2 != nil { + return nil, err + } + x.data["rootArray"] = wrappingArray + } + + x.data = ExpandValueEnvForMap(x.data) + + return x, nil +} + +// JSONConfigContainer A Config represents the json configuration. +// Only when get value, support key as section:name type. +type JSONConfigContainer struct { + data map[string]interface{} + sync.RWMutex +} + +// Bool returns the boolean value for a given key. +func (c *JSONConfigContainer) Bool(key string) (bool, error) { + val := c.getData(key) + if val != nil { + return ParseBool(val) + } + return false, fmt.Errorf("not exist key: %q", key) +} + +// DefaultBool return the bool value if has no error +// otherwise return the defaultval +func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool { + if v, err := c.Bool(key); err == nil { + return v + } + return defaultval +} + +// Int returns the integer value for a given key. +func (c *JSONConfigContainer) Int(key string) (int, error) { + val := c.getData(key) + if val != nil { + if v, ok := val.(float64); ok { + return int(v), nil + } else if v, ok := val.(string); ok { + return strconv.Atoi(v) + } + return 0, errors.New("not valid value") + } + return 0, errors.New("not exist key:" + key) +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int { + if v, err := c.Int(key); err == nil { + return v + } + return defaultval +} + +// Int64 returns the int64 value for a given key. +func (c *JSONConfigContainer) Int64(key string) (int64, error) { + val := c.getData(key) + if val != nil { + if v, ok := val.(float64); ok { + return int64(v), nil + } + return 0, errors.New("not int64 value") + } + return 0, errors.New("not exist key:" + key) +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + if v, err := c.Int64(key); err == nil { + return v + } + return defaultval +} + +// Float returns the float value for a given key. +func (c *JSONConfigContainer) Float(key string) (float64, error) { + val := c.getData(key) + if val != nil { + if v, ok := val.(float64); ok { + return v, nil + } + return 0.0, errors.New("not float64 value") + } + return 0.0, errors.New("not exist key:" + key) +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + if v, err := c.Float(key); err == nil { + return v + } + return defaultval +} + +// String returns the string value for a given key. +func (c *JSONConfigContainer) String(key string) string { + val := c.getData(key) + if val != nil { + if v, ok := val.(string); ok { + return v + } + } + return "" +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { + // TODO FIXME should not use "" to replace non existence + if v := c.String(key); v != "" { + return v + } + return defaultval +} + +// Strings returns the []string value for a given key. +func (c *JSONConfigContainer) Strings(key string) []string { + stringVal := c.String(key) + if stringVal == "" { + return nil + } + return strings.Split(c.String(key), ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { + if v := c.Strings(key); v != nil { + return v + } + return defaultval +} + +// GetSection returns map for the given section +func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) { + if v, ok := c.data[section]; ok { + return v.(map[string]string), nil + } + return nil, errors.New("nonexist section " + section) +} + +// SaveConfigFile save the config into file +func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + b, err := json.MarshalIndent(c.data, "", " ") + if err != nil { + return err + } + _, err = f.Write(b) + return err +} + +// Set writes a new value for key. +func (c *JSONConfigContainer) Set(key, val string) error { + c.Lock() + defer c.Unlock() + c.data[key] = val + return nil +} + +// DIY returns the raw value by a given key. +func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) { + val := c.getData(key) + if val != nil { + return val, nil + } + return nil, errors.New("not exist key") +} + +// section.key or key +func (c *JSONConfigContainer) getData(key string) interface{} { + if len(key) == 0 { + return nil + } + + c.RLock() + defer c.RUnlock() + + sectionKeys := strings.Split(key, "::") + if len(sectionKeys) >= 2 { + curValue, ok := c.data[sectionKeys[0]] + if !ok { + return nil + } + for _, key := range sectionKeys[1:] { + if v, ok := curValue.(map[string]interface{}); ok { + if curValue, ok = v[key]; !ok { + return nil + } + } + } + return curValue + } + if v, ok := c.data[key]; ok { + return v + } + return nil +} + +func init() { + Register("json", &JSONConfig{}) +} diff --git a/pkg/config/json_test.go b/pkg/config/json_test.go new file mode 100644 index 00000000..16f42409 --- /dev/null +++ b/pkg/config/json_test.go @@ -0,0 +1,222 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "os" + "testing" +) + +func TestJsonStartsWithArray(t *testing.T) { + + const jsoncontextwitharray = `[ + { + "url": "user", + "serviceAPI": "http://www.test.com/user" + }, + { + "url": "employee", + "serviceAPI": "http://www.test.com/employee" + } +]` + f, err := os.Create("testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontextwitharray) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjsonWithArray.conf") + jsonconf, err := NewConfig("json", "testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + rootArray, err := jsonconf.DIY("rootArray") + if err != nil { + t.Error("array does not exist as element") + } + rootArrayCasted := rootArray.([]interface{}) + if rootArrayCasted == nil { + t.Error("array from root is nil") + } else { + elem := rootArrayCasted[0].(map[string]interface{}) + if elem["url"] != "user" || elem["serviceAPI"] != "http://www.test.com/user" { + t.Error("array[0] values are not valid") + } + + elem2 := rootArrayCasted[1].(map[string]interface{}) + if elem2["url"] != "employee" || elem2["serviceAPI"] != "http://www.test.com/employee" { + t.Error("array[1] values are not valid") + } + } +} + +func TestJson(t *testing.T) { + + var ( + jsoncontext = `{ +"appname": "beeapi", +"testnames": "foo;bar", +"httpport": 8080, +"mysqlport": 3600, +"PI": 3.1415976, +"runmode": "dev", +"autorender": false, +"copyrequestbody": true, +"session": "on", +"cookieon": "off", +"newreg": "OFF", +"needlogin": "ON", +"enableSession": "Y", +"enableCookie": "N", +"flag": 1, +"path1": "${GOPATH}", +"path2": "${GOPATH||/home/go}", +"database": { + "host": "host", + "port": "port", + "database": "database", + "username": "username", + "password": "${GOPATH}", + "conns":{ + "maxconnection":12, + "autoconnect":true, + "connectioninfo":"info", + "root": "${GOPATH}" + } + } +}` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "testnames": []string{"foo", "bar"}, + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "database::host": "host", + "database::port": "port", + "database::database": "database", + "database::password": os.Getenv("GOPATH"), + "database::conns::maxconnection": 12, + "database::conns::autoconnect": true, + "database::conns::connectioninfo": "info", + "database::conns::root": os.Getenv("GOPATH"), + "unknown": "", + } + ) + + f, err := os.Create("testjson.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjson.conf") + jsonconf, err := NewConfig("json", "testjson.conf") + if err != nil { + t.Fatal(err) + } + + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = jsonconf.Int(k) + case int64: + value, err = jsonconf.Int64(k) + case float64: + value, err = jsonconf.Float(k) + case bool: + value, err = jsonconf.Bool(k) + case []string: + value = jsonconf.Strings(k) + case string: + value = jsonconf.String(k) + default: + value, err = jsonconf.DIY(k) + } + if err != nil { + t.Fatalf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = jsonconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if jsonconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + + if db, err := jsonconf.DIY("database"); err != nil { + t.Fatal(err) + } else if m, ok := db.(map[string]interface{}); !ok { + t.Log(db) + t.Fatal("db not map[string]interface{}") + } else { + if m["host"].(string) != "host" { + t.Fatal("get host err") + } + } + + if _, err := jsonconf.Int("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int") + } + + if _, err := jsonconf.Int64("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int64") + } + + if _, err := jsonconf.Float("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Float") + } + + if _, err := jsonconf.DIY("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an interface{}") + } + + if val := jsonconf.String("unknown"); val != "" { + t.Error("unknown keys should return an empty string when expecting a String") + } + + if _, err := jsonconf.Bool("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Bool") + } + + if !jsonconf.DefaultBool("unknown", true) { + t.Error("unknown keys with default value wrong") + } +} diff --git a/pkg/config/xml/xml.go b/pkg/config/xml/xml.go new file mode 100644 index 00000000..494242d3 --- /dev/null +++ b/pkg/config/xml/xml.go @@ -0,0 +1,228 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package xml for config provider. +// +// depend on github.com/beego/x2j. +// +// go install github.com/beego/x2j. +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/xml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("xml", "config.xml") +// +//More docs http://beego.me/docs/module/config.md +package xml + +import ( + "encoding/xml" + "errors" + "fmt" + "io/ioutil" + "os" + "strconv" + "strings" + "sync" + + "github.com/astaxie/beego/config" + "github.com/beego/x2j" +) + +// Config is a xml config parser and implements Config interface. +// xml configurations should be included in tag. +// only support key/value pair as value as each item. +type Config struct{} + +// Parse returns a ConfigContainer with parsed xml config map. +func (xc *Config) Parse(filename string) (config.Configer, error) { + context, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + + return xc.ParseData(context) +} + +// ParseData xml data +func (xc *Config) ParseData(data []byte) (config.Configer, error) { + x := &ConfigContainer{data: make(map[string]interface{})} + + d, err := x2j.DocToMap(string(data)) + if err != nil { + return nil, err + } + + x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{})) + + return x, nil +} + +// ConfigContainer A Config represents the xml configuration. +type ConfigContainer struct { + data map[string]interface{} + sync.Mutex +} + +// Bool returns the boolean value for a given key. +func (c *ConfigContainer) Bool(key string) (bool, error) { + if v := c.data[key]; v != nil { + return config.ParseBool(v) + } + return false, fmt.Errorf("not exist key: %q", key) +} + +// DefaultBool return the bool value if has no error +// otherwise return the defaultval +func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +// Int returns the integer value for a given key. +func (c *ConfigContainer) Int(key string) (int, error) { + return strconv.Atoi(c.data[key].(string)) +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +// Int64 returns the int64 value for a given key. +func (c *ConfigContainer) Int64(key string) (int64, error) { + return strconv.ParseInt(c.data[key].(string), 10, 64) +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v + +} + +// Float returns the float value for a given key. +func (c *ConfigContainer) Float(key string) (float64, error) { + return strconv.ParseFloat(c.data[key].(string), 64) +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +// String returns the string value for a given key. +func (c *ConfigContainer) String(key string) string { + if v, ok := c.data[key].(string); ok { + return v + } + return "" +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +// Strings returns the []string value for a given key. +func (c *ConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +// GetSection returns map for the given section +func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { + if v, ok := c.data[section].(map[string]interface{}); ok { + mapstr := make(map[string]string) + for k, val := range v { + mapstr[k] = config.ToString(val) + } + return mapstr, nil + } + return nil, fmt.Errorf("section '%s' not found", section) +} + +// SaveConfigFile save the config into file +func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + b, err := xml.MarshalIndent(c.data, " ", " ") + if err != nil { + return err + } + _, err = f.Write(b) + return err +} + +// Set writes a new value for key. +func (c *ConfigContainer) Set(key, val string) error { + c.Lock() + defer c.Unlock() + c.data[key] = val + return nil +} + +// DIY returns the raw value by a given key. +func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { + if v, ok := c.data[key]; ok { + return v, nil + } + return nil, errors.New("not exist key") +} + +func init() { + config.Register("xml", &Config{}) +} diff --git a/pkg/config/xml/xml_test.go b/pkg/config/xml/xml_test.go new file mode 100644 index 00000000..346c866e --- /dev/null +++ b/pkg/config/xml/xml_test.go @@ -0,0 +1,125 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xml + +import ( + "fmt" + "os" + "testing" + + "github.com/astaxie/beego/config" +) + +func TestXML(t *testing.T) { + + var ( + //xml parse should incluce in tags + xmlcontext = ` + +beeapi +8080 +3600 +3.1415976 +dev +false +true +${GOPATH} +${GOPATH||/home/go} + +1 +MySection + + +` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testxml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(xmlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testxml.conf") + + xmlconf, err := config.NewConfig("xml", "testxml.conf") + if err != nil { + t.Fatal(err) + } + + var xmlsection map[string]string + xmlsection, err = xmlconf.GetSection("mysection") + if err != nil { + t.Fatal(err) + } + + if len(xmlsection) == 0 { + t.Error("section should not be empty") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = xmlconf.Int(k) + case int64: + value, err = xmlconf.Int64(k) + case float64: + value, err = xmlconf.Float(k) + case bool: + value, err = xmlconf.Bool(k) + case []string: + value = xmlconf.Strings(k) + case string: + value = xmlconf.String(k) + default: + value, err = xmlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = xmlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if xmlconf.String("name") != "astaxie" { + t.Fatal("get name error") + } +} diff --git a/pkg/config/yaml/yaml.go b/pkg/config/yaml/yaml.go new file mode 100644 index 00000000..5def2da3 --- /dev/null +++ b/pkg/config/yaml/yaml.go @@ -0,0 +1,316 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package yaml for config provider +// +// depend on github.com/beego/goyaml2 +// +// go install github.com/beego/goyaml2 +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/yaml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("yaml", "config.yaml") +// +//More docs http://beego.me/docs/module/config.md +package yaml + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "log" + "os" + "strings" + "sync" + + "github.com/astaxie/beego/config" + "github.com/beego/goyaml2" +) + +// Config is a yaml config parser and implements Config interface. +type Config struct{} + +// Parse returns a ConfigContainer with parsed yaml config map. +func (yaml *Config) Parse(filename string) (y config.Configer, err error) { + cnf, err := ReadYmlReader(filename) + if err != nil { + return + } + y = &ConfigContainer{ + data: cnf, + } + return +} + +// ParseData parse yaml data +func (yaml *Config) ParseData(data []byte) (config.Configer, error) { + cnf, err := parseYML(data) + if err != nil { + return nil, err + } + + return &ConfigContainer{ + data: cnf, + }, nil +} + +// ReadYmlReader Read yaml file to map. +// if json like, use json package, unless goyaml2 package. +func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { + buf, err := ioutil.ReadFile(path) + if err != nil { + return + } + + return parseYML(buf) +} + +// parseYML parse yaml formatted []byte to map. +func parseYML(buf []byte) (cnf map[string]interface{}, err error) { + if len(buf) < 3 { + return + } + + if string(buf[0:1]) == "{" { + log.Println("Look like a Json, try json umarshal") + err = json.Unmarshal(buf, &cnf) + if err == nil { + log.Println("It is Json Map") + return + } + } + + data, err := goyaml2.Read(bytes.NewReader(buf)) + if err != nil { + log.Println("Goyaml2 ERR>", string(buf), err) + return + } + + if data == nil { + log.Println("Goyaml2 output nil? Pls report bug\n" + string(buf)) + return + } + cnf, ok := data.(map[string]interface{}) + if !ok { + log.Println("Not a Map? >> ", string(buf), data) + cnf = nil + } + cnf = config.ExpandValueEnvForMap(cnf) + return +} + +// ConfigContainer A Config represents the yaml configuration. +type ConfigContainer struct { + data map[string]interface{} + sync.RWMutex +} + +// Bool returns the boolean value for a given key. +func (c *ConfigContainer) Bool(key string) (bool, error) { + v, err := c.getData(key) + if err != nil { + return false, err + } + return config.ParseBool(v) +} + +// DefaultBool return the bool value if has no error +// otherwise return the defaultval +func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +// Int returns the integer value for a given key. +func (c *ConfigContainer) Int(key string) (int, error) { + if v, err := c.getData(key); err != nil { + return 0, err + } else if vv, ok := v.(int); ok { + return vv, nil + } else if vv, ok := v.(int64); ok { + return int(vv), nil + } + return 0, errors.New("not int value") +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +// Int64 returns the int64 value for a given key. +func (c *ConfigContainer) Int64(key string) (int64, error) { + if v, err := c.getData(key); err != nil { + return 0, err + } else if vv, ok := v.(int64); ok { + return vv, nil + } + return 0, errors.New("not bool value") +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v +} + +// Float returns the float value for a given key. +func (c *ConfigContainer) Float(key string) (float64, error) { + if v, err := c.getData(key); err != nil { + return 0.0, err + } else if vv, ok := v.(float64); ok { + return vv, nil + } else if vv, ok := v.(int); ok { + return float64(vv), nil + } else if vv, ok := v.(int64); ok { + return float64(vv), nil + } + return 0.0, errors.New("not float64 value") +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +// String returns the string value for a given key. +func (c *ConfigContainer) String(key string) string { + if v, err := c.getData(key); err == nil { + if vv, ok := v.(string); ok { + return vv + } + } + return "" +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +// Strings returns the []string value for a given key. +func (c *ConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +// GetSection returns map for the given section +func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { + + if v, ok := c.data[section]; ok { + return v.(map[string]string), nil + } + return nil, errors.New("not exist section") +} + +// SaveConfigFile save the config into file +func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + err = goyaml2.Write(f, c.data) + return err +} + +// Set writes a new value for key. +func (c *ConfigContainer) Set(key, val string) error { + c.Lock() + defer c.Unlock() + c.data[key] = val + return nil +} + +// DIY returns the raw value by a given key. +func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { + return c.getData(key) +} + +func (c *ConfigContainer) getData(key string) (interface{}, error) { + + if len(key) == 0 { + return nil, errors.New("key is empty") + } + c.RLock() + defer c.RUnlock() + + keys := strings.Split(key, ".") + tmpData := c.data + for idx, k := range keys { + if v, ok := tmpData[k]; ok { + switch v.(type) { + case map[string]interface{}: + { + tmpData = v.(map[string]interface{}) + if idx == len(keys) - 1 { + return tmpData, nil + } + } + default: + { + return v, nil + } + + } + } + } + return nil, fmt.Errorf("not exist key %q", key) +} + +func init() { + config.Register("yaml", &Config{}) +} diff --git a/pkg/config/yaml/yaml_test.go b/pkg/config/yaml/yaml_test.go new file mode 100644 index 00000000..49cc1d1e --- /dev/null +++ b/pkg/config/yaml/yaml_test.go @@ -0,0 +1,115 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package yaml + +import ( + "fmt" + "os" + "testing" + + "github.com/astaxie/beego/config" +) + +func TestYaml(t *testing.T) { + + var ( + yamlcontext = ` +"appname": beeapi +"httpport": 8080 +"mysqlport": 3600 +"PI": 3.1415976 +"runmode": dev +"autorender": false +"copyrequestbody": true +"PATH": GOPATH +"path1": ${GOPATH} +"path2": ${GOPATH||/home/go} +"empty": "" +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "PATH": "GOPATH", + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + f, err := os.Create("testyaml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(yamlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testyaml.conf") + yamlconf, err := config.NewConfig("yaml", "testyaml.conf") + if err != nil { + t.Fatal(err) + } + + if yamlconf.String("appname") != "beeapi" { + t.Fatal("appname not equal to beeapi") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = yamlconf.Int(k) + case int64: + value, err = yamlconf.Int64(k) + case float64: + value, err = yamlconf.Float(k) + case bool: + value, err = yamlconf.Bool(k) + case []string: + value = yamlconf.Strings(k) + case string: + value = yamlconf.String(k) + default: + value, err = yamlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = yamlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if yamlconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + +} diff --git a/pkg/config_test.go b/pkg/config_test.go new file mode 100644 index 00000000..5f71f1c3 --- /dev/null +++ b/pkg/config_test.go @@ -0,0 +1,146 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/astaxie/beego/config" +) + +func TestDefaults(t *testing.T) { + if BConfig.WebConfig.FlashName != "BEEGO_FLASH" { + t.Errorf("FlashName was not set to default.") + } + + if BConfig.WebConfig.FlashSeparator != "BEEGOFLASH" { + t.Errorf("FlashName was not set to default.") + } +} + +func TestAssignConfig_01(t *testing.T) { + _BConfig := &Config{} + _BConfig.AppName = "beego_test" + jcf := &config.JSONConfig{} + ac, _ := jcf.ParseData([]byte(`{"AppName":"beego_json"}`)) + assignSingleConfig(_BConfig, ac) + if _BConfig.AppName != "beego_json" { + t.Log(_BConfig) + t.FailNow() + } +} + +func TestAssignConfig_02(t *testing.T) { + _BConfig := &Config{} + bs, _ := json.Marshal(newBConfig()) + + jsonMap := M{} + json.Unmarshal(bs, &jsonMap) + + configMap := M{} + for k, v := range jsonMap { + if reflect.TypeOf(v).Kind() == reflect.Map { + for k1, v1 := range v.(M) { + if reflect.TypeOf(v1).Kind() == reflect.Map { + for k2, v2 := range v1.(M) { + configMap[k2] = v2 + } + } else { + configMap[k1] = v1 + } + } + } else { + configMap[k] = v + } + } + configMap["MaxMemory"] = 1024 + configMap["Graceful"] = true + configMap["XSRFExpire"] = 32 + configMap["SessionProviderConfig"] = "file" + configMap["FileLineNum"] = true + + jcf := &config.JSONConfig{} + bs, _ = json.Marshal(configMap) + ac, _ := jcf.ParseData(bs) + + for _, i := range []interface{}{_BConfig, &_BConfig.Listen, &_BConfig.WebConfig, &_BConfig.Log, &_BConfig.WebConfig.Session} { + assignSingleConfig(i, ac) + } + + if _BConfig.MaxMemory != 1024 { + t.Log(_BConfig.MaxMemory) + t.FailNow() + } + + if !_BConfig.Listen.Graceful { + t.Log(_BConfig.Listen.Graceful) + t.FailNow() + } + + if _BConfig.WebConfig.XSRFExpire != 32 { + t.Log(_BConfig.WebConfig.XSRFExpire) + t.FailNow() + } + + if _BConfig.WebConfig.Session.SessionProviderConfig != "file" { + t.Log(_BConfig.WebConfig.Session.SessionProviderConfig) + t.FailNow() + } + + if !_BConfig.Log.FileLineNum { + t.Log(_BConfig.Log.FileLineNum) + t.FailNow() + } + +} + +func TestAssignConfig_03(t *testing.T) { + jcf := &config.JSONConfig{} + ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`)) + ac.Set("AppName", "test_app") + ac.Set("RunMode", "online") + ac.Set("StaticDir", "download:down download2:down2") + ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") + ac.Set("StaticCacheFileSize", "87456") + ac.Set("StaticCacheFileNum", "1254") + assignConfig(ac) + + t.Logf("%#v", BConfig) + + if BConfig.AppName != "test_app" { + t.FailNow() + } + + if BConfig.RunMode != "online" { + t.FailNow() + } + if BConfig.WebConfig.StaticDir["/download"] != "down" { + t.FailNow() + } + if BConfig.WebConfig.StaticDir["/download2"] != "down2" { + t.FailNow() + } + if BConfig.WebConfig.StaticCacheFileSize != 87456 { + t.FailNow() + } + if BConfig.WebConfig.StaticCacheFileNum != 1254 { + t.FailNow() + } + if len(BConfig.WebConfig.StaticExtensionsToGzip) != 5 { + t.FailNow() + } +} diff --git a/pkg/context/acceptencoder.go b/pkg/context/acceptencoder.go new file mode 100644 index 00000000..b4e2492c --- /dev/null +++ b/pkg/context/acceptencoder.go @@ -0,0 +1,232 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "compress/zlib" + "io" + "net/http" + "os" + "strconv" + "strings" + "sync" +) + +var ( + //Default size==20B same as nginx + defaultGzipMinLength = 20 + //Content will only be compressed if content length is either unknown or greater than gzipMinLength. + gzipMinLength = defaultGzipMinLength + //The compression level used for deflate compression. (0-9). + gzipCompressLevel int + //List of HTTP methods to compress. If not set, only GET requests are compressed. + includedMethods map[string]bool + getMethodOnly bool +) + +// InitGzip init the gzipcompress +func InitGzip(minLength, compressLevel int, methods []string) { + if minLength >= 0 { + gzipMinLength = minLength + } + gzipCompressLevel = compressLevel + if gzipCompressLevel < flate.NoCompression || gzipCompressLevel > flate.BestCompression { + gzipCompressLevel = flate.BestSpeed + } + getMethodOnly = (len(methods) == 0) || (len(methods) == 1 && strings.ToUpper(methods[0]) == "GET") + includedMethods = make(map[string]bool, len(methods)) + for _, v := range methods { + includedMethods[strings.ToUpper(v)] = true + } +} + +type resetWriter interface { + io.Writer + Reset(w io.Writer) +} + +type nopResetWriter struct { + io.Writer +} + +func (n nopResetWriter) Reset(w io.Writer) { + //do nothing +} + +type acceptEncoder struct { + name string + levelEncode func(int) resetWriter + customCompressLevelPool *sync.Pool + bestCompressionPool *sync.Pool +} + +func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter { + if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { + return nopResetWriter{wr} + } + var rwr resetWriter + switch level { + case flate.BestSpeed: + rwr = ac.customCompressLevelPool.Get().(resetWriter) + case flate.BestCompression: + rwr = ac.bestCompressionPool.Get().(resetWriter) + default: + rwr = ac.levelEncode(level) + } + rwr.Reset(wr) + return rwr +} + +func (ac acceptEncoder) put(wr resetWriter, level int) { + if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { + return + } + wr.Reset(nil) + + //notice + //compressionLevel==BestCompression DOES NOT MATTER + //sync.Pool will not memory leak + + switch level { + case gzipCompressLevel: + ac.customCompressLevelPool.Put(wr) + case flate.BestCompression: + ac.bestCompressionPool.Put(wr) + } +} + +var ( + noneCompressEncoder = acceptEncoder{"", nil, nil, nil} + gzipCompressEncoder = acceptEncoder{ + name: "gzip", + levelEncode: func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr }, + customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, gzipCompressLevel); return wr }}, + bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }}, + } + + //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed + //deflate + //The "zlib" format defined in RFC 1950 [31] in combination with + //the "deflate" compression mechanism described in RFC 1951 [29]. + deflateCompressEncoder = acceptEncoder{ + name: "deflate", + levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, + customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, gzipCompressLevel); return wr }}, + bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }}, + } +) + +var ( + encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore + "gzip": gzipCompressEncoder, + "deflate": deflateCompressEncoder, + "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip + "identity": noneCompressEncoder, // identity means none-compress + } +) + +// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) +func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { + return writeLevel(encoding, writer, file, flate.BestCompression) +} + +// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) +func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { + if encoding == "" || len(content) < gzipMinLength { + _, err := writer.Write(content) + return false, "", err + } + return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel) +} + +// writeLevel reads from reader,writes to writer by specific encoding and compress level +// the compress level is defined by deflate package +func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) { + var outputWriter resetWriter + var err error + var ce = noneCompressEncoder + + if cf, ok := encoderMap[encoding]; ok { + ce = cf + } + encoding = ce.name + outputWriter = ce.encode(writer, level) + defer ce.put(outputWriter, level) + + _, err = io.Copy(outputWriter, reader) + if err != nil { + return false, "", err + } + + switch outputWriter.(type) { + case io.WriteCloser: + outputWriter.(io.WriteCloser).Close() + } + return encoding != "", encoding, nil +} + +// ParseEncoding will extract the right encoding for response +// the Accept-Encoding's sec is here: +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3 +func ParseEncoding(r *http.Request) string { + if r == nil { + return "" + } + if (getMethodOnly && r.Method == "GET") || includedMethods[r.Method] { + return parseEncoding(r) + } + return "" +} + +type q struct { + name string + value float64 +} + +func parseEncoding(r *http.Request) string { + acceptEncoding := r.Header.Get("Accept-Encoding") + if acceptEncoding == "" { + return "" + } + var lastQ q + for _, v := range strings.Split(acceptEncoding, ",") { + v = strings.TrimSpace(v) + if v == "" { + continue + } + vs := strings.Split(v, ";") + var cf acceptEncoder + var ok bool + if cf, ok = encoderMap[vs[0]]; !ok { + continue + } + if len(vs) == 1 { + return cf.name + } + if len(vs) == 2 { + f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64) + if f == 0 { + continue + } + if f > lastQ.value { + lastQ = q{cf.name, f} + } + } + } + return lastQ.name +} diff --git a/pkg/context/acceptencoder_test.go b/pkg/context/acceptencoder_test.go new file mode 100644 index 00000000..e3d61e27 --- /dev/null +++ b/pkg/context/acceptencoder_test.go @@ -0,0 +1,59 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "testing" +) + +func Test_ExtractEncoding(t *testing.T) { + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,deflate"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate,gzip"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate;q=0.3"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate;q=0.5,gzip;q=0.5,identity"}}}) != "" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"*"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x,gzip,deflate"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,x,deflate"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x"}}}) != "" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x;q=0.8"}}}) != "gzip" { + t.Fail() + } +} diff --git a/pkg/context/context.go b/pkg/context/context.go new file mode 100644 index 00000000..de248ed2 --- /dev/null +++ b/pkg/context/context.go @@ -0,0 +1,263 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package context provide the context utils +// Usage: +// +// import "github.com/astaxie/beego/context" +// +// ctx := context.Context{Request:req,ResponseWriter:rw} +// +// more docs http://beego.me/docs/module/context.md +package context + +import ( + "bufio" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "net" + "net/http" + "strconv" + "strings" + "time" + + "github.com/astaxie/beego/utils" +) + +//commonly used mime-types +const ( + ApplicationJSON = "application/json" + ApplicationXML = "application/xml" + ApplicationYAML = "application/x-yaml" + TextXML = "text/xml" +) + +// NewContext return the Context with Input and Output +func NewContext() *Context { + return &Context{ + Input: NewInput(), + Output: NewOutput(), + } +} + +// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. +// BeegoInput and BeegoOutput provides some api to operate request and response more easily. +type Context struct { + Input *BeegoInput + Output *BeegoOutput + Request *http.Request + ResponseWriter *Response + _xsrfToken string +} + +// Reset init Context, BeegoInput and BeegoOutput +func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { + ctx.Request = r + if ctx.ResponseWriter == nil { + ctx.ResponseWriter = &Response{} + } + ctx.ResponseWriter.reset(rw) + ctx.Input.Reset(ctx) + ctx.Output.Reset(ctx) + ctx._xsrfToken = "" +} + +// Redirect does redirection to localurl with http header status code. +func (ctx *Context) Redirect(status int, localurl string) { + http.Redirect(ctx.ResponseWriter, ctx.Request, localurl, status) +} + +// Abort stops this request. +// if beego.ErrorMaps exists, panic body. +func (ctx *Context) Abort(status int, body string) { + ctx.Output.SetStatus(status) + panic(body) +} + +// WriteString Write string to response body. +// it sends response body. +func (ctx *Context) WriteString(content string) { + ctx.ResponseWriter.Write([]byte(content)) +} + +// GetCookie Get cookie from request by a given key. +// It's alias of BeegoInput.Cookie. +func (ctx *Context) GetCookie(key string) string { + return ctx.Input.Cookie(key) +} + +// SetCookie Set cookie for response. +// It's alias of BeegoOutput.Cookie. +func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { + ctx.Output.Cookie(name, value, others...) +} + +// GetSecureCookie Get secure cookie from request by a given key. +func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { + val := ctx.Input.Cookie(key) + if val == "" { + return "", false + } + + parts := strings.SplitN(val, "|", 3) + + if len(parts) != 3 { + return "", false + } + + vs := parts[0] + timestamp := parts[1] + sig := parts[2] + + h := hmac.New(sha256.New, []byte(Secret)) + fmt.Fprintf(h, "%s%s", vs, timestamp) + + if fmt.Sprintf("%02x", h.Sum(nil)) != sig { + return "", false + } + res, _ := base64.URLEncoding.DecodeString(vs) + return string(res), true +} + +// SetSecureCookie Set Secure cookie for response. +func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { + vs := base64.URLEncoding.EncodeToString([]byte(value)) + timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) + h := hmac.New(sha256.New, []byte(Secret)) + fmt.Fprintf(h, "%s%s", vs, timestamp) + sig := fmt.Sprintf("%02x", h.Sum(nil)) + cookie := strings.Join([]string{vs, timestamp, sig}, "|") + ctx.Output.Cookie(name, cookie, others...) +} + +// XSRFToken creates a xsrf token string and returns. +func (ctx *Context) XSRFToken(key string, expire int64) string { + if ctx._xsrfToken == "" { + token, ok := ctx.GetSecureCookie(key, "_xsrf") + if !ok { + token = string(utils.RandomCreateBytes(32)) + ctx.SetSecureCookie(key, "_xsrf", token, expire) + } + ctx._xsrfToken = token + } + return ctx._xsrfToken +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (ctx *Context) CheckXSRFCookie() bool { + token := ctx.Input.Query("_xsrf") + if token == "" { + token = ctx.Request.Header.Get("X-Xsrftoken") + } + if token == "" { + token = ctx.Request.Header.Get("X-Csrftoken") + } + if token == "" { + ctx.Abort(422, "422") + return false + } + if ctx._xsrfToken != token { + ctx.Abort(417, "417") + return false + } + return true +} + +// RenderMethodResult renders the return value of a controller method to the output +func (ctx *Context) RenderMethodResult(result interface{}) { + if result != nil { + renderer, ok := result.(Renderer) + if !ok { + err, ok := result.(error) + if ok { + renderer = errorRenderer(err) + } else { + renderer = jsonRenderer(result) + } + } + renderer.Render(ctx) + } +} + +//Response is a wrapper for the http.ResponseWriter +//started set to true if response was written to then don't execute other handler +type Response struct { + http.ResponseWriter + Started bool + Status int + Elapsed time.Duration +} + +func (r *Response) reset(rw http.ResponseWriter) { + r.ResponseWriter = rw + r.Status = 0 + r.Started = false +} + +// Write writes the data to the connection as part of an HTTP reply, +// and sets `started` to true. +// started means the response has sent out. +func (r *Response) Write(p []byte) (int, error) { + r.Started = true + return r.ResponseWriter.Write(p) +} + +// WriteHeader sends an HTTP response header with status code, +// and sets `started` to true. +func (r *Response) WriteHeader(code int) { + if r.Status > 0 { + //prevent multiple response.WriteHeader calls + return + } + r.Status = code + r.Started = true + r.ResponseWriter.WriteHeader(code) +} + +// Hijack hijacker for http +func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj, ok := r.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("webserver doesn't support hijacking") + } + return hj.Hijack() +} + +// Flush http.Flusher +func (r *Response) Flush() { + if f, ok := r.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// CloseNotify http.CloseNotifier +func (r *Response) CloseNotify() <-chan bool { + if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { + return cn.CloseNotify() + } + return nil +} + +// Pusher http.Pusher +func (r *Response) Pusher() (pusher http.Pusher) { + if pusher, ok := r.ResponseWriter.(http.Pusher); ok { + return pusher + } + return nil +} diff --git a/pkg/context/context_test.go b/pkg/context/context_test.go new file mode 100644 index 00000000..7c0535e0 --- /dev/null +++ b/pkg/context/context_test.go @@ -0,0 +1,47 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestXsrfReset_01(t *testing.T) { + r := &http.Request{} + c := NewContext() + c.Request = r + c.ResponseWriter = &Response{} + c.ResponseWriter.reset(httptest.NewRecorder()) + c.Output.Reset(c) + c.Input.Reset(c) + c.XSRFToken("key", 16) + if c._xsrfToken == "" { + t.FailNow() + } + token := c._xsrfToken + c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r) + if c._xsrfToken != "" { + t.FailNow() + } + c.XSRFToken("key", 16) + if c._xsrfToken == "" { + t.FailNow() + } + if token == c._xsrfToken { + t.FailNow() + } +} diff --git a/pkg/context/input.go b/pkg/context/input.go new file mode 100644 index 00000000..385549c1 --- /dev/null +++ b/pkg/context/input.go @@ -0,0 +1,689 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "compress/gzip" + "errors" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "reflect" + "regexp" + "strconv" + "strings" + "sync" + + "github.com/astaxie/beego/session" +) + +// Regexes for checking the accept headers +// TODO make sure these are correct +var ( + acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`) + acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`) + acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`) + acceptsYAMLRegex = regexp.MustCompile(`(application/x-yaml)(?:,|$)`) + maxParam = 50 +) + +// BeegoInput operates the http request header, data, cookie and body. +// it also contains router params and current session. +type BeegoInput struct { + Context *Context + CruSession session.Store + pnames []string + pvalues []string + data map[interface{}]interface{} // store some values in this context when calling context in filter or controller. + dataLock sync.RWMutex + RequestBody []byte + RunMethod string + RunController reflect.Type +} + +// NewInput return BeegoInput generated by Context. +func NewInput() *BeegoInput { + return &BeegoInput{ + pnames: make([]string, 0, maxParam), + pvalues: make([]string, 0, maxParam), + data: make(map[interface{}]interface{}), + } +} + +// Reset init the BeegoInput +func (input *BeegoInput) Reset(ctx *Context) { + input.Context = ctx + input.CruSession = nil + input.pnames = input.pnames[:0] + input.pvalues = input.pvalues[:0] + input.dataLock.Lock() + input.data = nil + input.dataLock.Unlock() + input.RequestBody = []byte{} +} + +// Protocol returns request protocol name, such as HTTP/1.1 . +func (input *BeegoInput) Protocol() string { + return input.Context.Request.Proto +} + +// URI returns full request url with query string, fragment. +func (input *BeegoInput) URI() string { + return input.Context.Request.RequestURI +} + +// URL returns request url path (without query string, fragment). +func (input *BeegoInput) URL() string { + return input.Context.Request.URL.EscapedPath() +} + +// Site returns base site url as scheme://domain type. +func (input *BeegoInput) Site() string { + return input.Scheme() + "://" + input.Domain() +} + +// Scheme returns request scheme as "http" or "https". +func (input *BeegoInput) Scheme() string { + if scheme := input.Header("X-Forwarded-Proto"); scheme != "" { + return scheme + } + if input.Context.Request.URL.Scheme != "" { + return input.Context.Request.URL.Scheme + } + if input.Context.Request.TLS == nil { + return "http" + } + return "https" +} + +// Domain returns host name. +// Alias of Host method. +func (input *BeegoInput) Domain() string { + return input.Host() +} + +// Host returns host name. +// if no host info in request, return localhost. +func (input *BeegoInput) Host() string { + if input.Context.Request.Host != "" { + if hostPart, _, err := net.SplitHostPort(input.Context.Request.Host); err == nil { + return hostPart + } + return input.Context.Request.Host + } + return "localhost" +} + +// Method returns http request method. +func (input *BeegoInput) Method() string { + return input.Context.Request.Method +} + +// Is returns boolean of this request is on given method, such as Is("POST"). +func (input *BeegoInput) Is(method string) bool { + return input.Method() == method +} + +// IsGet Is this a GET method request? +func (input *BeegoInput) IsGet() bool { + return input.Is("GET") +} + +// IsPost Is this a POST method request? +func (input *BeegoInput) IsPost() bool { + return input.Is("POST") +} + +// IsHead Is this a Head method request? +func (input *BeegoInput) IsHead() bool { + return input.Is("HEAD") +} + +// IsOptions Is this a OPTIONS method request? +func (input *BeegoInput) IsOptions() bool { + return input.Is("OPTIONS") +} + +// IsPut Is this a PUT method request? +func (input *BeegoInput) IsPut() bool { + return input.Is("PUT") +} + +// IsDelete Is this a DELETE method request? +func (input *BeegoInput) IsDelete() bool { + return input.Is("DELETE") +} + +// IsPatch Is this a PATCH method request? +func (input *BeegoInput) IsPatch() bool { + return input.Is("PATCH") +} + +// IsAjax returns boolean of this request is generated by ajax. +func (input *BeegoInput) IsAjax() bool { + return input.Header("X-Requested-With") == "XMLHttpRequest" +} + +// IsSecure returns boolean of this request is in https. +func (input *BeegoInput) IsSecure() bool { + return input.Scheme() == "https" +} + +// IsWebsocket returns boolean of this request is in webSocket. +func (input *BeegoInput) IsWebsocket() bool { + return input.Header("Upgrade") == "websocket" +} + +// IsUpload returns boolean of whether file uploads in this request or not.. +func (input *BeegoInput) IsUpload() bool { + return strings.Contains(input.Header("Content-Type"), "multipart/form-data") +} + +// AcceptsHTML Checks if request accepts html response +func (input *BeegoInput) AcceptsHTML() bool { + return acceptsHTMLRegex.MatchString(input.Header("Accept")) +} + +// AcceptsXML Checks if request accepts xml response +func (input *BeegoInput) AcceptsXML() bool { + return acceptsXMLRegex.MatchString(input.Header("Accept")) +} + +// AcceptsJSON Checks if request accepts json response +func (input *BeegoInput) AcceptsJSON() bool { + return acceptsJSONRegex.MatchString(input.Header("Accept")) +} + +// AcceptsYAML Checks if request accepts json response +func (input *BeegoInput) AcceptsYAML() bool { + return acceptsYAMLRegex.MatchString(input.Header("Accept")) +} + +// IP returns request client ip. +// if in proxy, return first proxy id. +// if error, return RemoteAddr. +func (input *BeegoInput) IP() string { + ips := input.Proxy() + if len(ips) > 0 && ips[0] != "" { + rip, _, err := net.SplitHostPort(ips[0]) + if err != nil { + rip = ips[0] + } + return rip + } + if ip, _, err := net.SplitHostPort(input.Context.Request.RemoteAddr); err == nil { + return ip + } + return input.Context.Request.RemoteAddr +} + +// Proxy returns proxy client ips slice. +func (input *BeegoInput) Proxy() []string { + if ips := input.Header("X-Forwarded-For"); ips != "" { + return strings.Split(ips, ",") + } + return []string{} +} + +// Referer returns http referer header. +func (input *BeegoInput) Referer() string { + return input.Header("Referer") +} + +// Refer returns http referer header. +func (input *BeegoInput) Refer() string { + return input.Referer() +} + +// SubDomains returns sub domain string. +// if aa.bb.domain.com, returns aa.bb . +func (input *BeegoInput) SubDomains() string { + parts := strings.Split(input.Host(), ".") + if len(parts) >= 3 { + return strings.Join(parts[:len(parts)-2], ".") + } + return "" +} + +// Port returns request client port. +// when error or empty, return 80. +func (input *BeegoInput) Port() int { + if _, portPart, err := net.SplitHostPort(input.Context.Request.Host); err == nil { + port, _ := strconv.Atoi(portPart) + return port + } + return 80 +} + +// UserAgent returns request client user agent string. +func (input *BeegoInput) UserAgent() string { + return input.Header("User-Agent") +} + +// ParamsLen return the length of the params +func (input *BeegoInput) ParamsLen() int { + return len(input.pnames) +} + +// Param returns router param by a given key. +func (input *BeegoInput) Param(key string) string { + for i, v := range input.pnames { + if v == key && i <= len(input.pvalues) { + // we cannot use url.PathEscape(input.pvalues[i]) + // for example, if the value is /a/b + // after url.PathEscape(input.pvalues[i]), the value is %2Fa%2Fb + // However, the value is used in ControllerRegister.ServeHTTP + // and split by "/", so function crash... + return input.pvalues[i] + } + } + return "" +} + +// Params returns the map[key]value. +func (input *BeegoInput) Params() map[string]string { + m := make(map[string]string) + for i, v := range input.pnames { + if i <= len(input.pvalues) { + m[v] = input.pvalues[i] + } + } + return m +} + +// SetParam will set the param with key and value +func (input *BeegoInput) SetParam(key, val string) { + // check if already exists + for i, v := range input.pnames { + if v == key && i <= len(input.pvalues) { + input.pvalues[i] = val + return + } + } + input.pvalues = append(input.pvalues, val) + input.pnames = append(input.pnames, key) +} + +// ResetParams clears any of the input's Params +// This function is used to clear parameters so they may be reset between filter +// passes. +func (input *BeegoInput) ResetParams() { + input.pnames = input.pnames[:0] + input.pvalues = input.pvalues[:0] +} + +// Query returns input data item string by a given string. +func (input *BeegoInput) Query(key string) string { + if val := input.Param(key); val != "" { + return val + } + if input.Context.Request.Form == nil { + input.dataLock.Lock() + if input.Context.Request.Form == nil { + input.Context.Request.ParseForm() + } + input.dataLock.Unlock() + } + input.dataLock.RLock() + defer input.dataLock.RUnlock() + return input.Context.Request.Form.Get(key) +} + +// Header returns request header item string by a given string. +// if non-existed, return empty string. +func (input *BeegoInput) Header(key string) string { + return input.Context.Request.Header.Get(key) +} + +// Cookie returns request cookie item string by a given key. +// if non-existed, return empty string. +func (input *BeegoInput) Cookie(key string) string { + ck, err := input.Context.Request.Cookie(key) + if err != nil { + return "" + } + return ck.Value +} + +// Session returns current session item value by a given key. +// if non-existed, return nil. +func (input *BeegoInput) Session(key interface{}) interface{} { + return input.CruSession.Get(key) +} + +// CopyBody returns the raw request body data as bytes. +func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { + if input.Context.Request.Body == nil { + return []byte{} + } + + var requestbody []byte + safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory} + if input.Header("Content-Encoding") == "gzip" { + reader, err := gzip.NewReader(safe) + if err != nil { + return nil + } + requestbody, _ = ioutil.ReadAll(reader) + } else { + requestbody, _ = ioutil.ReadAll(safe) + } + + input.Context.Request.Body.Close() + bf := bytes.NewBuffer(requestbody) + input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, ioutil.NopCloser(bf), MaxMemory) + input.RequestBody = requestbody + return requestbody +} + +// Data return the implicit data in the input +func (input *BeegoInput) Data() map[interface{}]interface{} { + input.dataLock.Lock() + defer input.dataLock.Unlock() + if input.data == nil { + input.data = make(map[interface{}]interface{}) + } + return input.data +} + +// GetData returns the stored data in this context. +func (input *BeegoInput) GetData(key interface{}) interface{} { + input.dataLock.Lock() + defer input.dataLock.Unlock() + if v, ok := input.data[key]; ok { + return v + } + return nil +} + +// SetData stores data with given key in this context. +// This data are only available in this context. +func (input *BeegoInput) SetData(key, val interface{}) { + input.dataLock.Lock() + defer input.dataLock.Unlock() + if input.data == nil { + input.data = make(map[interface{}]interface{}) + } + input.data[key] = val +} + +// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type +func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { + // Parse the body depending on the content type. + if strings.Contains(input.Header("Content-Type"), "multipart/form-data") { + if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil { + return errors.New("Error parsing request body:" + err.Error()) + } + } else if err := input.Context.Request.ParseForm(); err != nil { + return errors.New("Error parsing request body:" + err.Error()) + } + return nil +} + +// Bind data from request.Form[key] to dest +// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie +// var id int beegoInput.Bind(&id, "id") id ==123 +// var isok bool beegoInput.Bind(&isok, "isok") isok ==true +// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2 +// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2] +// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array] +// user struct{Name} beegoInput.Bind(&user, "user") user == {Name:"astaxie"} +func (input *BeegoInput) Bind(dest interface{}, key string) error { + value := reflect.ValueOf(dest) + if value.Kind() != reflect.Ptr { + return errors.New("beego: non-pointer passed to Bind: " + key) + } + value = value.Elem() + if !value.CanSet() { + return errors.New("beego: non-settable variable passed to Bind: " + key) + } + typ := value.Type() + // Get real type if dest define with interface{}. + // e.g var dest interface{} dest=1.0 + if value.Kind() == reflect.Interface { + typ = value.Elem().Type() + } + rv := input.bind(key, typ) + if !rv.IsValid() { + return errors.New("beego: reflect value is empty") + } + value.Set(rv) + return nil +} + +func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value { + if input.Context.Request.Form == nil { + input.Context.Request.ParseForm() + } + rv := reflect.Zero(typ) + switch typ.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindInt(val, typ) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindUint(val, typ) + case reflect.Float32, reflect.Float64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindFloat(val, typ) + case reflect.String: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindString(val, typ) + case reflect.Bool: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindBool(val, typ) + case reflect.Slice: + rv = input.bindSlice(&input.Context.Request.Form, key, typ) + case reflect.Struct: + rv = input.bindStruct(&input.Context.Request.Form, key, typ) + case reflect.Ptr: + rv = input.bindPoint(key, typ) + case reflect.Map: + rv = input.bindMap(&input.Context.Request.Form, key, typ) + } + return rv +} + +func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value { + rv := reflect.Zero(typ) + switch typ.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + rv = input.bindInt(val, typ) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + rv = input.bindUint(val, typ) + case reflect.Float32, reflect.Float64: + rv = input.bindFloat(val, typ) + case reflect.String: + rv = input.bindString(val, typ) + case reflect.Bool: + rv = input.bindBool(val, typ) + case reflect.Slice: + rv = input.bindSlice(&url.Values{"": {val}}, "", typ) + case reflect.Struct: + rv = input.bindStruct(&url.Values{"": {val}}, "", typ) + case reflect.Ptr: + rv = input.bindPoint(val, typ) + case reflect.Map: + rv = input.bindMap(&url.Values{"": {val}}, "", typ) + } + return rv +} + +func (input *BeegoInput) bindInt(val string, typ reflect.Type) reflect.Value { + intValue, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetInt(intValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindUint(val string, typ reflect.Type) reflect.Value { + uintValue, err := strconv.ParseUint(val, 10, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetUint(uintValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindFloat(val string, typ reflect.Type) reflect.Value { + floatValue, err := strconv.ParseFloat(val, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetFloat(floatValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindString(val string, typ reflect.Type) reflect.Value { + return reflect.ValueOf(val) +} + +func (input *BeegoInput) bindBool(val string, typ reflect.Type) reflect.Value { + val = strings.TrimSpace(strings.ToLower(val)) + switch val { + case "true", "on", "1": + return reflect.ValueOf(true) + } + return reflect.ValueOf(false) +} + +type sliceValue struct { + index int // Index extracted from brackets. If -1, no index was provided. + value reflect.Value // the bound value for this slice element. +} + +func (input *BeegoInput) bindSlice(params *url.Values, key string, typ reflect.Type) reflect.Value { + maxIndex := -1 + numNoIndex := 0 + sliceValues := []sliceValue{} + for reqKey, vals := range *params { + if !strings.HasPrefix(reqKey, key+"[") { + continue + } + // Extract the index, and the index where a sub-key starts. (e.g. field[0].subkey) + index := -1 + leftBracket, rightBracket := len(key), strings.Index(reqKey[len(key):], "]")+len(key) + if rightBracket > leftBracket+1 { + index, _ = strconv.Atoi(reqKey[leftBracket+1 : rightBracket]) + } + subKeyIndex := rightBracket + 1 + + // Handle the indexed case. + if index > -1 { + if index > maxIndex { + maxIndex = index + } + sliceValues = append(sliceValues, sliceValue{ + index: index, + value: input.bind(reqKey[:subKeyIndex], typ.Elem()), + }) + continue + } + + // It's an un-indexed element. (e.g. element[]) + numNoIndex += len(vals) + for _, val := range vals { + // Unindexed values can only be direct-bound. + sliceValues = append(sliceValues, sliceValue{ + index: -1, + value: input.bindValue(val, typ.Elem()), + }) + } + } + resultArray := reflect.MakeSlice(typ, maxIndex+1, maxIndex+1+numNoIndex) + for _, sv := range sliceValues { + if sv.index != -1 { + resultArray.Index(sv.index).Set(sv.value) + } else { + resultArray = reflect.Append(resultArray, sv.value) + } + } + return resultArray +} + +func (input *BeegoInput) bindStruct(params *url.Values, key string, typ reflect.Type) reflect.Value { + result := reflect.New(typ).Elem() + fieldValues := make(map[string]reflect.Value) + for reqKey, val := range *params { + var fieldName string + if strings.HasPrefix(reqKey, key+".") { + fieldName = reqKey[len(key)+1:] + } else if strings.HasPrefix(reqKey, key+"[") && reqKey[len(reqKey)-1] == ']' { + fieldName = reqKey[len(key)+1 : len(reqKey)-1] + } else { + continue + } + + if _, ok := fieldValues[fieldName]; !ok { + // Time to bind this field. Get it and make sure we can set it. + fieldValue := result.FieldByName(fieldName) + if !fieldValue.IsValid() { + continue + } + if !fieldValue.CanSet() { + continue + } + boundVal := input.bindValue(val[0], fieldValue.Type()) + fieldValue.Set(boundVal) + fieldValues[fieldName] = boundVal + } + } + + return result +} + +func (input *BeegoInput) bindPoint(key string, typ reflect.Type) reflect.Value { + return input.bind(key, typ.Elem()).Addr() +} + +func (input *BeegoInput) bindMap(params *url.Values, key string, typ reflect.Type) reflect.Value { + var ( + result = reflect.MakeMap(typ) + keyType = typ.Key() + valueType = typ.Elem() + ) + for paramName, values := range *params { + if !strings.HasPrefix(paramName, key+"[") || paramName[len(paramName)-1] != ']' { + continue + } + + key := paramName[len(key)+1 : len(paramName)-1] + result.SetMapIndex(input.bindValue(key, keyType), input.bindValue(values[0], valueType)) + } + return result +} diff --git a/pkg/context/input_test.go b/pkg/context/input_test.go new file mode 100644 index 00000000..3a6c2e7b --- /dev/null +++ b/pkg/context/input_test.go @@ -0,0 +1,217 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func TestBind(t *testing.T) { + type testItem struct { + field string + empty interface{} + want interface{} + } + type Human struct { + ID int + Nick string + Pwd string + Ms bool + } + + cases := []struct { + request string + valueGp []testItem + }{ + {"/?p=str", []testItem{{"p", interface{}(""), interface{}("str")}}}, + + {"/?p=", []testItem{{"p", "", ""}}}, + {"/?p=str", []testItem{{"p", "", "str"}}}, + + {"/?p=123", []testItem{{"p", 0, 123}}}, + {"/?p=123", []testItem{{"p", uint(0), uint(123)}}}, + + {"/?p=1.0", []testItem{{"p", 0.0, 1.0}}}, + {"/?p=1", []testItem{{"p", false, true}}}, + + {"/?p=true", []testItem{{"p", false, true}}}, + {"/?p=ON", []testItem{{"p", false, true}}}, + {"/?p=on", []testItem{{"p", false, true}}}, + {"/?p=1", []testItem{{"p", false, true}}}, + {"/?p=2", []testItem{{"p", false, false}}}, + {"/?p=false", []testItem{{"p", false, false}}}, + + {"/?p[a]=1&p[b]=2&p[c]=3", []testItem{{"p", map[string]int{}, map[string]int{"a": 1, "b": 2, "c": 3}}}}, + {"/?p[a]=v1&p[b]=v2&p[c]=v3", []testItem{{"p", map[string]string{}, map[string]string{"a": "v1", "b": "v2", "c": "v3"}}}}, + + {"/?p[]=8&p[]=9&p[]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10&p[5]=14", []testItem{{"p", []int{}, []int{8, 9, 10, 0, 0, 14}}}}, + {"/?p[0]=8.0&p[1]=9.0&p[2]=10.0", []testItem{{"p", []float64{}, []float64{8.0, 9.0, 10.0}}}}, + + {"/?p[]=10&p[]=9&p[]=8", []testItem{{"p", []string{}, []string{"10", "9", "8"}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []string{}, []string{"8", "9", "10"}}}}, + + {"/?p[0]=true&p[1]=false&p[2]=true&p[5]=1&p[6]=ON&p[7]=other", []testItem{{"p", []bool{}, []bool{true, false, true, false, false, true, true, false}}}}, + + {"/?human.Nick=astaxie", []testItem{{"human", Human{}, Human{Nick: "astaxie"}}}}, + {"/?human.ID=888&human.Nick=astaxie&human.Ms=true&human[Pwd]=pass", []testItem{{"human", Human{}, Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass"}}}}, + {"/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02", + []testItem{{"human", []Human{}, []Human{ + {ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"}, + {ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"}, + }}}}, + + { + "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&human.Nick=astaxie", + []testItem{ + {"id", 0, 123}, + {"isok", false, true}, + {"ft", 0.0, 1.2}, + {"ol", []int{}, []int{1, 2}}, + {"ul", []string{}, []string{"str", "array"}}, + {"human", Human{}, Human{Nick: "astaxie"}}, + }, + }, + } + for _, c := range cases { + r, _ := http.NewRequest("GET", c.request, nil) + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Reset(httptest.NewRecorder(), r) + + for _, item := range c.valueGp { + got := item.empty + err := beegoInput.Bind(&got, item.field) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, item.want) { + t.Fatalf("Bind %q error,should be:\n%#v \ngot:\n%#v", item.field, item.want, got) + } + } + + } +} + +func TestSubDomain(t *testing.T) { + r, _ := http.NewRequest("GET", "http://www.example.com/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil) + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Reset(httptest.NewRecorder(), r) + + subdomain := beegoInput.SubDomains() + if subdomain != "www" { + t.Fatal("Subdomain parse error, got" + subdomain) + } + + r, _ = http.NewRequest("GET", "http://localhost/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "" { + t.Fatal("Subdomain parse error, should be empty, got " + beegoInput.SubDomains()) + } + + r, _ = http.NewRequest("GET", "http://aa.bb.example.com/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "aa.bb" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } + + /* TODO Fix this + r, _ = http.NewRequest("GET", "http://127.0.0.1/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } + */ + + r, _ = http.NewRequest("GET", "http://example.com/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } + + r, _ = http.NewRequest("GET", "http://aa.bb.cc.dd.example.com/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "aa.bb.cc.dd" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } +} + +func TestParams(t *testing.T) { + inp := NewInput() + + inp.SetParam("p1", "val1_ver1") + inp.SetParam("p2", "val2_ver1") + inp.SetParam("p3", "val3_ver1") + if l := inp.ParamsLen(); l != 3 { + t.Fatalf("Input.ParamsLen wrong value: %d, expected %d", l, 3) + } + + if val := inp.Param("p1"); val != "val1_ver1" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver1") + } + if val := inp.Param("p3"); val != "val3_ver1" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val3_ver1") + } + vals := inp.Params() + expected := map[string]string{ + "p1": "val1_ver1", + "p2": "val2_ver1", + "p3": "val3_ver1", + } + if !reflect.DeepEqual(vals, expected) { + t.Fatalf("Input.Params wrong value: %s, expected %s", vals, expected) + } + + // overwriting existing params + inp.SetParam("p1", "val1_ver2") + inp.SetParam("p2", "val2_ver2") + expected = map[string]string{ + "p1": "val1_ver2", + "p2": "val2_ver2", + "p3": "val3_ver1", + } + vals = inp.Params() + if !reflect.DeepEqual(vals, expected) { + t.Fatalf("Input.Params wrong value: %s, expected %s", vals, expected) + } + + if l := inp.ParamsLen(); l != 3 { + t.Fatalf("Input.ParamsLen wrong value: %d, expected %d", l, 3) + } + + if val := inp.Param("p1"); val != "val1_ver2" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2") + } + + if val := inp.Param("p2"); val != "val2_ver2" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2") + } + +} +func BenchmarkQuery(b *testing.B) { + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Request, _ = http.NewRequest("POST", "http://www.example.com/?q=foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + beegoInput.Query("q") + } + }) +} diff --git a/pkg/context/output.go b/pkg/context/output.go new file mode 100644 index 00000000..238dcf45 --- /dev/null +++ b/pkg/context/output.go @@ -0,0 +1,408 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "encoding/json" + "encoding/xml" + "errors" + "fmt" + "html/template" + "io" + "mime" + "net/http" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + yaml "gopkg.in/yaml.v2" +) + +// BeegoOutput does work for sending response header. +type BeegoOutput struct { + Context *Context + Status int + EnableGzip bool +} + +// NewOutput returns new BeegoOutput. +// it contains nothing now. +func NewOutput() *BeegoOutput { + return &BeegoOutput{} +} + +// Reset init BeegoOutput +func (output *BeegoOutput) Reset(ctx *Context) { + output.Context = ctx + output.Status = 0 +} + +// Header sets response header item string via given key. +func (output *BeegoOutput) Header(key, val string) { + output.Context.ResponseWriter.Header().Set(key, val) +} + +// Body sets response body content. +// if EnableGzip, compress content string. +// it sends out response body directly. +func (output *BeegoOutput) Body(content []byte) error { + var encoding string + var buf = &bytes.Buffer{} + if output.EnableGzip { + encoding = ParseEncoding(output.Context.Request) + } + if b, n, _ := WriteBody(encoding, buf, content); b { + output.Header("Content-Encoding", n) + output.Header("Content-Length", strconv.Itoa(buf.Len())) + } else { + output.Header("Content-Length", strconv.Itoa(len(content))) + } + // Write status code if it has been set manually + // Set it to 0 afterwards to prevent "multiple response.WriteHeader calls" + if output.Status != 0 { + output.Context.ResponseWriter.WriteHeader(output.Status) + output.Status = 0 + } else { + output.Context.ResponseWriter.Started = true + } + io.Copy(output.Context.ResponseWriter, buf) + return nil +} + +// Cookie sets cookie value via given key. +// others are ordered as cookie's max age time, path,domain, secure and httponly. +func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) { + var b bytes.Buffer + fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value)) + + //fix cookie not work in IE + if len(others) > 0 { + var maxAge int64 + + switch v := others[0].(type) { + case int: + maxAge = int64(v) + case int32: + maxAge = int64(v) + case int64: + maxAge = v + } + + switch { + case maxAge > 0: + fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(maxAge)*time.Second).UTC().Format(time.RFC1123), maxAge) + case maxAge < 0: + fmt.Fprintf(&b, "; Max-Age=0") + } + } + + // the settings below + // Path, Domain, Secure, HttpOnly + // can use nil skip set + + // default "/" + if len(others) > 1 { + if v, ok := others[1].(string); ok && len(v) > 0 { + fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v)) + } + } else { + fmt.Fprintf(&b, "; Path=%s", "/") + } + + // default empty + if len(others) > 2 { + if v, ok := others[2].(string); ok && len(v) > 0 { + fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(v)) + } + } + + // default empty + if len(others) > 3 { + var secure bool + switch v := others[3].(type) { + case bool: + secure = v + default: + if others[3] != nil { + secure = true + } + } + if secure { + fmt.Fprintf(&b, "; Secure") + } + } + + // default false. for session cookie default true + if len(others) > 4 { + if v, ok := others[4].(bool); ok && v { + fmt.Fprintf(&b, "; HttpOnly") + } + } + + output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String()) +} + +var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") + +func sanitizeName(n string) string { + return cookieNameSanitizer.Replace(n) +} + +var cookieValueSanitizer = strings.NewReplacer("\n", " ", "\r", " ", ";", " ") + +func sanitizeValue(v string) string { + return cookieValueSanitizer.Replace(v) +} + +func jsonRenderer(value interface{}) Renderer { + return rendererFunc(func(ctx *Context) { + ctx.Output.JSON(value, false, false) + }) +} + +func errorRenderer(err error) Renderer { + return rendererFunc(func(ctx *Context) { + ctx.Output.SetStatus(500) + ctx.Output.Body([]byte(err.Error())) + }) +} + +// JSON writes json to response body. +// if encoding is true, it converts utf-8 to \u0000 type. +func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error { + output.Header("Content-Type", "application/json; charset=utf-8") + var content []byte + var err error + if hasIndent { + content, err = json.MarshalIndent(data, "", " ") + } else { + content, err = json.Marshal(data) + } + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + if encoding { + content = []byte(stringsToJSON(string(content))) + } + return output.Body(content) +} + +// YAML writes yaml to response body. +func (output *BeegoOutput) YAML(data interface{}) error { + output.Header("Content-Type", "application/x-yaml; charset=utf-8") + var content []byte + var err error + content, err = yaml.Marshal(data) + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + return output.Body(content) +} + +// JSONP writes jsonp to response body. +func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { + output.Header("Content-Type", "application/javascript; charset=utf-8") + var content []byte + var err error + if hasIndent { + content, err = json.MarshalIndent(data, "", " ") + } else { + content, err = json.Marshal(data) + } + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + callback := output.Context.Input.Query("callback") + if callback == "" { + return errors.New(`"callback" parameter required`) + } + callback = template.JSEscapeString(callback) + callbackContent := bytes.NewBufferString(" if(window." + callback + ")" + callback) + callbackContent.WriteString("(") + callbackContent.Write(content) + callbackContent.WriteString(");\r\n") + return output.Body(callbackContent.Bytes()) +} + +// XML writes xml string to response body. +func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { + output.Header("Content-Type", "application/xml; charset=utf-8") + var content []byte + var err error + if hasIndent { + content, err = xml.MarshalIndent(data, "", " ") + } else { + content, err = xml.Marshal(data) + } + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + return output.Body(content) +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) { + accept := output.Context.Input.Header("Accept") + switch accept { + case ApplicationYAML: + output.YAML(data) + case ApplicationXML, TextXML: + output.XML(data, hasIndent) + default: + output.JSON(data, hasIndent, len(hasEncode) > 0 && hasEncode[0]) + } +} + +// Download forces response for download file. +// it prepares the download response header automatically. +func (output *BeegoOutput) Download(file string, filename ...string) { + // check get file error, file not found or other error. + if _, err := os.Stat(file); err != nil { + http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) + return + } + + var fName string + if len(filename) > 0 && filename[0] != "" { + fName = filename[0] + } else { + fName = filepath.Base(file) + } + //https://tools.ietf.org/html/rfc6266#section-4.3 + fn := url.PathEscape(fName) + if fName == fn { + fn = "filename=" + fn + } else { + /** + The parameters "filename" and "filename*" differ only in that + "filename*" uses the encoding defined in [RFC5987], allowing the use + of characters not present in the ISO-8859-1 character set + ([ISO-8859-1]). + */ + fn = "filename=" + fName + "; filename*=utf-8''" + fn + } + output.Header("Content-Disposition", "attachment; "+fn) + output.Header("Content-Description", "File Transfer") + output.Header("Content-Type", "application/octet-stream") + output.Header("Content-Transfer-Encoding", "binary") + output.Header("Expires", "0") + output.Header("Cache-Control", "must-revalidate") + output.Header("Pragma", "public") + http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) +} + +// ContentType sets the content type from ext string. +// MIME type is given in mime package. +func (output *BeegoOutput) ContentType(ext string) { + if !strings.HasPrefix(ext, ".") { + ext = "." + ext + } + ctype := mime.TypeByExtension(ext) + if ctype != "" { + output.Header("Content-Type", ctype) + } +} + +// SetStatus sets response status code. +// It writes response header directly. +func (output *BeegoOutput) SetStatus(status int) { + output.Status = status +} + +// IsCachable returns boolean of this request is cached. +// HTTP 304 means cached. +func (output *BeegoOutput) IsCachable() bool { + return output.Status >= 200 && output.Status < 300 || output.Status == 304 +} + +// IsEmpty returns boolean of this request is empty. +// HTTP 201,204 and 304 means empty. +func (output *BeegoOutput) IsEmpty() bool { + return output.Status == 201 || output.Status == 204 || output.Status == 304 +} + +// IsOk returns boolean of this request runs well. +// HTTP 200 means ok. +func (output *BeegoOutput) IsOk() bool { + return output.Status == 200 +} + +// IsSuccessful returns boolean of this request runs successfully. +// HTTP 2xx means ok. +func (output *BeegoOutput) IsSuccessful() bool { + return output.Status >= 200 && output.Status < 300 +} + +// IsRedirect returns boolean of this request is redirection header. +// HTTP 301,302,307 means redirection. +func (output *BeegoOutput) IsRedirect() bool { + return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307 +} + +// IsForbidden returns boolean of this request is forbidden. +// HTTP 403 means forbidden. +func (output *BeegoOutput) IsForbidden() bool { + return output.Status == 403 +} + +// IsNotFound returns boolean of this request is not found. +// HTTP 404 means not found. +func (output *BeegoOutput) IsNotFound() bool { + return output.Status == 404 +} + +// IsClientError returns boolean of this request client sends error data. +// HTTP 4xx means client error. +func (output *BeegoOutput) IsClientError() bool { + return output.Status >= 400 && output.Status < 500 +} + +// IsServerError returns boolean of this server handler errors. +// HTTP 5xx means server internal error. +func (output *BeegoOutput) IsServerError() bool { + return output.Status >= 500 && output.Status < 600 +} + +func stringsToJSON(str string) string { + var jsons bytes.Buffer + for _, r := range str { + rint := int(r) + if rint < 128 { + jsons.WriteRune(r) + } else { + jsons.WriteString("\\u") + if rint < 0x100 { + jsons.WriteString("00") + } else if rint < 0x1000 { + jsons.WriteString("0") + } + jsons.WriteString(strconv.FormatInt(int64(rint), 16)) + } + } + return jsons.String() +} + +// Session sets session item value with given key. +func (output *BeegoOutput) Session(name interface{}, value interface{}) { + output.Context.Input.CruSession.Set(name, value) +} diff --git a/pkg/context/param/conv.go b/pkg/context/param/conv.go new file mode 100644 index 00000000..c200e008 --- /dev/null +++ b/pkg/context/param/conv.go @@ -0,0 +1,78 @@ +package param + +import ( + "fmt" + "reflect" + + beecontext "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" +) + +// ConvertParams converts http method params to values that will be passed to the method controller as arguments +func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) { + result = make([]reflect.Value, 0, len(methodParams)) + for i := 0; i < len(methodParams); i++ { + reflectValue := convertParam(methodParams[i], methodType.In(i), ctx) + result = append(result, reflectValue) + } + return +} + +func convertParam(param *MethodParam, paramType reflect.Type, ctx *beecontext.Context) (result reflect.Value) { + paramValue := getParamValue(param, ctx) + if paramValue == "" { + if param.required { + ctx.Abort(400, fmt.Sprintf("Missing parameter %s", param.name)) + } else { + paramValue = param.defaultValue + } + } + + reflectValue, err := parseValue(param, paramValue, paramType) + if err != nil { + logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %v, Error: %s", param.name, paramType, paramValue, err)) + ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %v to type %s", param.name, paramValue, paramType)) + } + + return reflectValue +} + +func getParamValue(param *MethodParam, ctx *beecontext.Context) string { + switch param.in { + case body: + return string(ctx.Input.RequestBody) + case header: + return ctx.Input.Header(param.name) + case path: + return ctx.Input.Query(":" + param.name) + default: + return ctx.Input.Query(param.name) + } +} + +func parseValue(param *MethodParam, paramValue string, paramType reflect.Type) (result reflect.Value, err error) { + if paramValue == "" { + return reflect.Zero(paramType), nil + } + parser := getParser(param, paramType) + value, err := parser.parse(paramValue, paramType) + if err != nil { + return result, err + } + + return safeConvert(reflect.ValueOf(value), paramType) +} + +func safeConvert(value reflect.Value, t reflect.Type) (result reflect.Value, err error) { + defer func() { + if r := recover(); r != nil { + var ok bool + err, ok = r.(error) + if !ok { + err = fmt.Errorf("%v", r) + } + } + }() + result = value.Convert(t) + return +} diff --git a/pkg/context/param/methodparams.go b/pkg/context/param/methodparams.go new file mode 100644 index 00000000..cd6708a2 --- /dev/null +++ b/pkg/context/param/methodparams.go @@ -0,0 +1,69 @@ +package param + +import ( + "fmt" + "strings" +) + +//MethodParam keeps param information to be auto passed to controller methods +type MethodParam struct { + name string + in paramType + required bool + defaultValue string +} + +type paramType byte + +const ( + param paramType = iota + path + body + header +) + +//New creates a new MethodParam with name and specific options +func New(name string, opts ...MethodParamOption) *MethodParam { + return newParam(name, nil, opts) +} + +func newParam(name string, parser paramParser, opts []MethodParamOption) (param *MethodParam) { + param = &MethodParam{name: name} + for _, option := range opts { + option(param) + } + return +} + +//Make creates an array of MethodParmas or an empty array +func Make(list ...*MethodParam) []*MethodParam { + if len(list) > 0 { + return list + } + return nil +} + +func (mp *MethodParam) String() string { + options := []string{} + result := "param.New(\"" + mp.name + "\"" + if mp.required { + options = append(options, "param.IsRequired") + } + switch mp.in { + case path: + options = append(options, "param.InPath") + case body: + options = append(options, "param.InBody") + case header: + options = append(options, "param.InHeader") + } + if mp.defaultValue != "" { + options = append(options, fmt.Sprintf(`param.Default("%s")`, mp.defaultValue)) + } + if len(options) > 0 { + result += ", " + } + result += strings.Join(options, ", ") + result += ")" + return result +} diff --git a/pkg/context/param/options.go b/pkg/context/param/options.go new file mode 100644 index 00000000..3d5ba013 --- /dev/null +++ b/pkg/context/param/options.go @@ -0,0 +1,37 @@ +package param + +import ( + "fmt" +) + +// MethodParamOption defines a func which apply options on a MethodParam +type MethodParamOption func(*MethodParam) + +// IsRequired indicates that this param is required and can not be omitted from the http request +var IsRequired MethodParamOption = func(p *MethodParam) { + p.required = true +} + +// InHeader indicates that this param is passed via an http header +var InHeader MethodParamOption = func(p *MethodParam) { + p.in = header +} + +// InPath indicates that this param is part of the URL path +var InPath MethodParamOption = func(p *MethodParam) { + p.in = path +} + +// InBody indicates that this param is passed as an http request body +var InBody MethodParamOption = func(p *MethodParam) { + p.in = body +} + +// Default provides a default value for the http param +func Default(defaultValue interface{}) MethodParamOption { + return func(p *MethodParam) { + if defaultValue != nil { + p.defaultValue = fmt.Sprint(defaultValue) + } + } +} diff --git a/pkg/context/param/parsers.go b/pkg/context/param/parsers.go new file mode 100644 index 00000000..421aecf0 --- /dev/null +++ b/pkg/context/param/parsers.go @@ -0,0 +1,149 @@ +package param + +import ( + "encoding/json" + "reflect" + "strconv" + "strings" + "time" +) + +type paramParser interface { + parse(value string, toType reflect.Type) (interface{}, error) +} + +func getParser(param *MethodParam, t reflect.Type) paramParser { + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return intParser{} + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string + return stringParser{} + } + if param.in == body { + return jsonParser{} + } + elemParser := getParser(param, t.Elem()) + if elemParser == (jsonParser{}) { + return elemParser + } + return sliceParser(elemParser) + case reflect.Bool: + return boolParser{} + case reflect.String: + return stringParser{} + case reflect.Float32, reflect.Float64: + return floatParser{} + case reflect.Ptr: + elemParser := getParser(param, t.Elem()) + if elemParser == (jsonParser{}) { + return elemParser + } + return ptrParser(elemParser) + default: + if t.PkgPath() == "time" && t.Name() == "Time" { + return timeParser{} + } + return jsonParser{} + } +} + +type parserFunc func(value string, toType reflect.Type) (interface{}, error) + +func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error) { + return f(value, toType) +} + +type boolParser struct { +} + +func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) { + return strconv.ParseBool(value) +} + +type stringParser struct { +} + +func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) { + return value, nil +} + +type intParser struct { +} + +func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) { + return strconv.Atoi(value) +} + +type floatParser struct { +} + +func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) { + if toType.Kind() == reflect.Float32 { + res, err := strconv.ParseFloat(value, 32) + if err != nil { + return nil, err + } + return float32(res), nil + } + return strconv.ParseFloat(value, 64) +} + +type timeParser struct { +} + +func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) { + result, err = time.Parse(time.RFC3339, value) + if err != nil { + result, err = time.Parse("2006-01-02", value) + } + return +} + +type jsonParser struct { +} + +func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) { + pResult := reflect.New(toType) + v := pResult.Interface() + err := json.Unmarshal([]byte(value), v) + if err != nil { + return nil, err + } + return pResult.Elem().Interface(), nil +} + +func sliceParser(elemParser paramParser) paramParser { + return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { + values := strings.Split(value, ",") + result := reflect.MakeSlice(toType, 0, len(values)) + elemType := toType.Elem() + for _, v := range values { + parsedValue, err := elemParser.parse(v, elemType) + if err != nil { + return nil, err + } + result = reflect.Append(result, reflect.ValueOf(parsedValue)) + } + return result.Interface(), nil + }) +} + +func ptrParser(elemParser paramParser) paramParser { + return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { + parsedValue, err := elemParser.parse(value, toType.Elem()) + if err != nil { + return nil, err + } + newValPtr := reflect.New(toType.Elem()) + newVal := reflect.Indirect(newValPtr) + convertedVal, err := safeConvert(reflect.ValueOf(parsedValue), toType.Elem()) + if err != nil { + return nil, err + } + + newVal.Set(convertedVal) + return newValPtr.Interface(), nil + }) +} diff --git a/pkg/context/param/parsers_test.go b/pkg/context/param/parsers_test.go new file mode 100644 index 00000000..7065a28e --- /dev/null +++ b/pkg/context/param/parsers_test.go @@ -0,0 +1,84 @@ +package param + +import "testing" +import "reflect" +import "time" + +type testDefinition struct { + strValue string + expectedValue interface{} + expectedParser paramParser +} + +func Test_Parsers(t *testing.T) { + + //ints + checkParser(testDefinition{"1", 1, intParser{}}, t) + checkParser(testDefinition{"-1", int64(-1), intParser{}}, t) + checkParser(testDefinition{"1", uint64(1), intParser{}}, t) + + //floats + checkParser(testDefinition{"1.0", float32(1.0), floatParser{}}, t) + checkParser(testDefinition{"-1.0", float64(-1.0), floatParser{}}, t) + + //strings + checkParser(testDefinition{"AB", "AB", stringParser{}}, t) + checkParser(testDefinition{"AB", []byte{65, 66}, stringParser{}}, t) + + //bools + checkParser(testDefinition{"true", true, boolParser{}}, t) + checkParser(testDefinition{"0", false, boolParser{}}, t) + + //timeParser + checkParser(testDefinition{"2017-05-30T13:54:53Z", time.Date(2017, 5, 30, 13, 54, 53, 0, time.UTC), timeParser{}}, t) + checkParser(testDefinition{"2017-05-30", time.Date(2017, 5, 30, 0, 0, 0, 0, time.UTC), timeParser{}}, t) + + //json + checkParser(testDefinition{`{"X": 5, "Y":"Z"}`, struct { + X int + Y string + }{5, "Z"}, jsonParser{}}, t) + + //slice in query is parsed as comma delimited + checkParser(testDefinition{`1,2`, []int{1, 2}, sliceParser(intParser{})}, t) + + //slice in body is parsed as json + checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body}) + + //pointers + var someInt = 1 + checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t) + + var someStruct = struct{ X int }{5} + checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t) + +} + +func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) { + toType := reflect.TypeOf(def.expectedValue) + var mp MethodParam + if len(methodParam) == 0 { + mp = MethodParam{} + } else { + mp = methodParam[0] + } + parser := getParser(&mp, toType) + + if reflect.TypeOf(parser) != reflect.TypeOf(def.expectedParser) { + t.Errorf("Invalid parser for value %v. Expected: %v, actual: %v", def.strValue, reflect.TypeOf(def.expectedParser).Name(), reflect.TypeOf(parser).Name()) + return + } + result, err := parser.parse(def.strValue, toType) + if err != nil { + t.Errorf("Parsing error for value %v. Expected result: %v, error: %v", def.strValue, def.expectedValue, err) + return + } + convResult, err := safeConvert(reflect.ValueOf(result), toType) + if err != nil { + t.Errorf("Conversion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err) + return + } + if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) { + t.Errorf("Parsing error for value %v. Expected result: %v, actual: %v", def.strValue, def.expectedValue, result) + } +} diff --git a/pkg/context/renderer.go b/pkg/context/renderer.go new file mode 100644 index 00000000..36a7cb53 --- /dev/null +++ b/pkg/context/renderer.go @@ -0,0 +1,12 @@ +package context + +// Renderer defines an http response renderer +type Renderer interface { + Render(ctx *Context) +} + +type rendererFunc func(ctx *Context) + +func (f rendererFunc) Render(ctx *Context) { + f(ctx) +} diff --git a/pkg/context/response.go b/pkg/context/response.go new file mode 100644 index 00000000..9c3c715a --- /dev/null +++ b/pkg/context/response.go @@ -0,0 +1,27 @@ +package context + +import ( + "strconv" + + "net/http" +) + +const ( + //BadRequest indicates http error 400 + BadRequest StatusCode = http.StatusBadRequest + + //NotFound indicates http error 404 + NotFound StatusCode = http.StatusNotFound +) + +// StatusCode sets the http response status code +type StatusCode int + +func (s StatusCode) Error() string { + return strconv.Itoa(int(s)) +} + +// Render sets the http status code +func (s StatusCode) Render(ctx *Context) { + ctx.Output.SetStatus(int(s)) +} diff --git a/pkg/controller.go b/pkg/controller.go new file mode 100644 index 00000000..0e8853b3 --- /dev/null +++ b/pkg/controller.go @@ -0,0 +1,706 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "errors" + "fmt" + "html/template" + "io" + "mime/multipart" + "net/http" + "net/url" + "os" + "reflect" + "strconv" + "strings" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/context/param" + "github.com/astaxie/beego/session" +) + +var ( + // ErrAbort custom error when user stop request handler manually. + ErrAbort = errors.New("user stop run") + // GlobalControllerRouter store comments with controller. pkgpath+controller:comments + GlobalControllerRouter = make(map[string][]ControllerComments) +) + +// ControllerFilter store the filter for controller +type ControllerFilter struct { + Pattern string + Pos int + Filter FilterFunc + ReturnOnOutput bool + ResetParams bool +} + +// ControllerFilterComments store the comment for controller level filter +type ControllerFilterComments struct { + Pattern string + Pos int + Filter string // NOQA + ReturnOnOutput bool + ResetParams bool +} + +// ControllerImportComments store the import comment for controller needed +type ControllerImportComments struct { + ImportPath string + ImportAlias string +} + +// ControllerComments store the comment for the controller method +type ControllerComments struct { + Method string + Router string + Filters []*ControllerFilter + ImportComments []*ControllerImportComments + FilterComments []*ControllerFilterComments + AllowHTTPMethods []string + Params []map[string]string + MethodParams []*param.MethodParam +} + +// ControllerCommentsSlice implements the sort interface +type ControllerCommentsSlice []ControllerComments + +func (p ControllerCommentsSlice) Len() int { return len(p) } +func (p ControllerCommentsSlice) Less(i, j int) bool { return p[i].Router < p[j].Router } +func (p ControllerCommentsSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +// Controller defines some basic http request handler operations, such as +// http context, template and view, session and xsrf. +type Controller struct { + // context data + Ctx *context.Context + Data map[interface{}]interface{} + + // route controller info + controllerName string + actionName string + methodMapping map[string]func() //method:routertree + AppController interface{} + + // template data + TplName string + ViewPath string + Layout string + LayoutSections map[string]string // the key is the section name and the value is the template name + TplPrefix string + TplExt string + EnableRender bool + + // xsrf data + _xsrfToken string + XSRFExpire int + EnableXSRF bool + + // session + CruSession session.Store +} + +// ControllerInterface is an interface to uniform all controller handler. +type ControllerInterface interface { + Init(ct *context.Context, controllerName, actionName string, app interface{}) + Prepare() + Get() + Post() + Delete() + Put() + Head() + Patch() + Options() + Trace() + Finish() + Render() error + XSRFToken() string + CheckXSRFCookie() bool + HandlerFunc(fn string) bool + URLMapping() +} + +// Init generates default values of controller operations. +func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { + c.Layout = "" + c.TplName = "" + c.controllerName = controllerName + c.actionName = actionName + c.Ctx = ctx + c.TplExt = "tpl" + c.AppController = app + c.EnableRender = true + c.EnableXSRF = true + c.Data = ctx.Input.Data() + c.methodMapping = make(map[string]func()) +} + +// Prepare runs after Init before request function execution. +func (c *Controller) Prepare() {} + +// Finish runs after request function execution. +func (c *Controller) Finish() {} + +// Get adds a request function to handle GET request. +func (c *Controller) Get() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Post adds a request function to handle POST request. +func (c *Controller) Post() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Delete adds a request function to handle DELETE request. +func (c *Controller) Delete() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Put adds a request function to handle PUT request. +func (c *Controller) Put() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Head adds a request function to handle HEAD request. +func (c *Controller) Head() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Patch adds a request function to handle PATCH request. +func (c *Controller) Patch() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Options adds a request function to handle OPTIONS request. +func (c *Controller) Options() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Trace adds a request function to handle Trace request. +// this method SHOULD NOT be overridden. +// https://tools.ietf.org/html/rfc7231#section-4.3.8 +// The TRACE method requests a remote, application-level loop-back of +// the request message. The final recipient of the request SHOULD +// reflect the message received, excluding some fields described below, +// back to the client as the message body of a 200 (OK) response with a +// Content-Type of "message/http" (Section 8.3.1 of [RFC7230]). +func (c *Controller) Trace() { + ts := func(h http.Header) (hs string) { + for k, v := range h { + hs += fmt.Sprintf("\r\n%s: %s", k, v) + } + return + } + hs := fmt.Sprintf("\r\nTRACE %s %s%s\r\n", c.Ctx.Request.RequestURI, c.Ctx.Request.Proto, ts(c.Ctx.Request.Header)) + c.Ctx.Output.Header("Content-Type", "message/http") + c.Ctx.Output.Header("Content-Length", fmt.Sprint(len(hs))) + c.Ctx.Output.Header("Cache-Control", "no-cache, no-store, must-revalidate") + c.Ctx.WriteString(hs) +} + +// HandlerFunc call function with the name +func (c *Controller) HandlerFunc(fnname string) bool { + if v, ok := c.methodMapping[fnname]; ok { + v() + return true + } + return false +} + +// URLMapping register the internal Controller router. +func (c *Controller) URLMapping() {} + +// Mapping the method to function +func (c *Controller) Mapping(method string, fn func()) { + c.methodMapping[method] = fn +} + +// Render sends the response with rendered template bytes as text/html type. +func (c *Controller) Render() error { + if !c.EnableRender { + return nil + } + rb, err := c.RenderBytes() + if err != nil { + return err + } + + if c.Ctx.ResponseWriter.Header().Get("Content-Type") == "" { + c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8") + } + + return c.Ctx.Output.Body(rb) +} + +// RenderString returns the rendered template string. Do not send out response. +func (c *Controller) RenderString() (string, error) { + b, e := c.RenderBytes() + return string(b), e +} + +// RenderBytes returns the bytes of rendered template string. Do not send out response. +func (c *Controller) RenderBytes() ([]byte, error) { + buf, err := c.renderTemplate() + //if the controller has set layout, then first get the tplName's content set the content to the layout + if err == nil && c.Layout != "" { + c.Data["LayoutContent"] = template.HTML(buf.String()) + + if c.LayoutSections != nil { + for sectionName, sectionTpl := range c.LayoutSections { + if sectionTpl == "" { + c.Data[sectionName] = "" + continue + } + buf.Reset() + err = ExecuteViewPathTemplate(&buf, sectionTpl, c.viewPath(), c.Data) + if err != nil { + return nil, err + } + c.Data[sectionName] = template.HTML(buf.String()) + } + } + + buf.Reset() + ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath(), c.Data) + } + return buf.Bytes(), err +} + +func (c *Controller) renderTemplate() (bytes.Buffer, error) { + var buf bytes.Buffer + if c.TplName == "" { + c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt + } + if c.TplPrefix != "" { + c.TplName = c.TplPrefix + c.TplName + } + if BConfig.RunMode == DEV { + buildFiles := []string{c.TplName} + if c.Layout != "" { + buildFiles = append(buildFiles, c.Layout) + if c.LayoutSections != nil { + for _, sectionTpl := range c.LayoutSections { + if sectionTpl == "" { + continue + } + buildFiles = append(buildFiles, sectionTpl) + } + } + } + BuildTemplate(c.viewPath(), buildFiles...) + } + return buf, ExecuteViewPathTemplate(&buf, c.TplName, c.viewPath(), c.Data) +} + +func (c *Controller) viewPath() string { + if c.ViewPath == "" { + return BConfig.WebConfig.ViewsPath + } + return c.ViewPath +} + +// Redirect sends the redirection response to url with status code. +func (c *Controller) Redirect(url string, code int) { + LogAccess(c.Ctx, nil, code) + c.Ctx.Redirect(code, url) +} + +// SetData set the data depending on the accepted +func (c *Controller) SetData(data interface{}) { + accept := c.Ctx.Input.Header("Accept") + switch accept { + case context.ApplicationYAML: + c.Data["yaml"] = data + case context.ApplicationXML, context.TextXML: + c.Data["xml"] = data + default: + c.Data["json"] = data + } +} + +// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string. +func (c *Controller) Abort(code string) { + status, err := strconv.Atoi(code) + if err != nil { + status = 200 + } + c.CustomAbort(status, code) +} + +// CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. +func (c *Controller) CustomAbort(status int, body string) { + // first panic from ErrorMaps, it is user defined error functions. + if _, ok := ErrorMaps[body]; ok { + c.Ctx.Output.Status = status + panic(body) + } + // last panic user string + c.Ctx.ResponseWriter.WriteHeader(status) + c.Ctx.ResponseWriter.Write([]byte(body)) + panic(ErrAbort) +} + +// StopRun makes panic of USERSTOPRUN error and go to recover function if defined. +func (c *Controller) StopRun() { + panic(ErrAbort) +} + +// URLFor does another controller handler in this request function. +// it goes to this controller method if endpoint is not clear. +func (c *Controller) URLFor(endpoint string, values ...interface{}) string { + if len(endpoint) == 0 { + return "" + } + if endpoint[0] == '.' { + return URLFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...) + } + return URLFor(endpoint, values...) +} + +// ServeJSON sends a json response with encoding charset. +func (c *Controller) ServeJSON(encoding ...bool) { + var ( + hasIndent = BConfig.RunMode != PROD + hasEncoding = len(encoding) > 0 && encoding[0] + ) + + c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding) +} + +// ServeJSONP sends a jsonp response. +func (c *Controller) ServeJSONP() { + hasIndent := BConfig.RunMode != PROD + c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent) +} + +// ServeXML sends xml response. +func (c *Controller) ServeXML() { + hasIndent := BConfig.RunMode != PROD + c.Ctx.Output.XML(c.Data["xml"], hasIndent) +} + +// ServeYAML sends yaml response. +func (c *Controller) ServeYAML() { + c.Ctx.Output.YAML(c.Data["yaml"]) +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (c *Controller) ServeFormatted(encoding ...bool) { + hasIndent := BConfig.RunMode != PROD + hasEncoding := len(encoding) > 0 && encoding[0] + c.Ctx.Output.ServeFormatted(c.Data, hasIndent, hasEncoding) +} + +// Input returns the input data map from POST or PUT request body and query string. +func (c *Controller) Input() url.Values { + if c.Ctx.Request.Form == nil { + c.Ctx.Request.ParseForm() + } + return c.Ctx.Request.Form +} + +// ParseForm maps input data map to obj struct. +func (c *Controller) ParseForm(obj interface{}) error { + return ParseForm(c.Input(), obj) +} + +// GetString returns the input value by key string or the default value while it's present and input is blank +func (c *Controller) GetString(key string, def ...string) string { + if v := c.Ctx.Input.Query(key); v != "" { + return v + } + if len(def) > 0 { + return def[0] + } + return "" +} + +// GetStrings returns the input string slice by key string or the default value while it's present and input is blank +// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection. +func (c *Controller) GetStrings(key string, def ...[]string) []string { + var defv []string + if len(def) > 0 { + defv = def[0] + } + + if f := c.Input(); f == nil { + return defv + } else if vs := f[key]; len(vs) > 0 { + return vs + } + + return defv +} + +// GetInt returns input as an int or the default value while it's present and input is blank +func (c *Controller) GetInt(key string, def ...int) (int, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.Atoi(strv) +} + +// GetInt8 return input as an int8 or the default value while it's present and input is blank +func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + i64, err := strconv.ParseInt(strv, 10, 8) + return int8(i64), err +} + +// GetUint8 return input as an uint8 or the default value while it's present and input is blank +func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 8) + return uint8(u64), err +} + +// GetInt16 returns input as an int16 or the default value while it's present and input is blank +func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + i64, err := strconv.ParseInt(strv, 10, 16) + return int16(i64), err +} + +// GetUint16 returns input as an uint16 or the default value while it's present and input is blank +func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 16) + return uint16(u64), err +} + +// GetInt32 returns input as an int32 or the default value while it's present and input is blank +func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + i64, err := strconv.ParseInt(strv, 10, 32) + return int32(i64), err +} + +// GetUint32 returns input as an uint32 or the default value while it's present and input is blank +func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 32) + return uint32(u64), err +} + +// GetInt64 returns input value as int64 or the default value while it's present and input is blank. +func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseInt(strv, 10, 64) +} + +// GetUint64 returns input value as uint64 or the default value while it's present and input is blank. +func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseUint(strv, 10, 64) +} + +// GetBool returns input value as bool or the default value while it's present and input is blank. +func (c *Controller) GetBool(key string, def ...bool) (bool, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseBool(strv) +} + +// GetFloat returns input value as float64 or the default value while it's present and input is blank. +func (c *Controller) GetFloat(key string, def ...float64) (float64, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseFloat(strv, 64) +} + +// GetFile returns the file data in file upload field named as key. +// it returns the first one of multi-uploaded files. +func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, error) { + return c.Ctx.Request.FormFile(key) +} + +// GetFiles return multi-upload files +// files, err:=c.GetFiles("myfiles") +// if err != nil { +// http.Error(w, err.Error(), http.StatusNoContent) +// return +// } +// for i, _ := range files { +// //for each fileheader, get a handle to the actual file +// file, err := files[i].Open() +// defer file.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //create destination file making sure the path is writeable. +// dst, err := os.Create("upload/" + files[i].Filename) +// defer dst.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //copy the uploaded file to the destination file +// if _, err := io.Copy(dst, file); err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// } +func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { + if files, ok := c.Ctx.Request.MultipartForm.File[key]; ok { + return files, nil + } + return nil, http.ErrMissingFile +} + +// SaveToFile saves uploaded file to new path. +// it only operates the first one of mutil-upload form file field. +func (c *Controller) SaveToFile(fromfile, tofile string) error { + file, _, err := c.Ctx.Request.FormFile(fromfile) + if err != nil { + return err + } + defer file.Close() + f, err := os.OpenFile(tofile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) + if err != nil { + return err + } + defer f.Close() + io.Copy(f, file) + return nil +} + +// StartSession starts session and load old session data info this controller. +func (c *Controller) StartSession() session.Store { + if c.CruSession == nil { + c.CruSession = c.Ctx.Input.CruSession + } + return c.CruSession +} + +// SetSession puts value into session. +func (c *Controller) SetSession(name interface{}, value interface{}) { + if c.CruSession == nil { + c.StartSession() + } + c.CruSession.Set(name, value) +} + +// GetSession gets value from session. +func (c *Controller) GetSession(name interface{}) interface{} { + if c.CruSession == nil { + c.StartSession() + } + return c.CruSession.Get(name) +} + +// DelSession removes value from session. +func (c *Controller) DelSession(name interface{}) { + if c.CruSession == nil { + c.StartSession() + } + c.CruSession.Delete(name) +} + +// SessionRegenerateID regenerates session id for this session. +// the session data have no changes. +func (c *Controller) SessionRegenerateID() { + if c.CruSession != nil { + c.CruSession.SessionRelease(c.Ctx.ResponseWriter) + } + c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request) + c.Ctx.Input.CruSession = c.CruSession +} + +// DestroySession cleans session data and session cookie. +func (c *Controller) DestroySession() { + c.Ctx.Input.CruSession.Flush() + c.Ctx.Input.CruSession = nil + GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request) +} + +// IsAjax returns this request is ajax or not. +func (c *Controller) IsAjax() bool { + return c.Ctx.Input.IsAjax() +} + +// GetSecureCookie returns decoded cookie value from encoded browser cookie values. +func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) { + return c.Ctx.GetSecureCookie(Secret, key) +} + +// SetSecureCookie puts value into cookie after encoded the value. +func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) { + c.Ctx.SetSecureCookie(Secret, name, value, others...) +} + +// XSRFToken creates a CSRF token string and returns. +func (c *Controller) XSRFToken() string { + if c._xsrfToken == "" { + expire := int64(BConfig.WebConfig.XSRFExpire) + if c.XSRFExpire > 0 { + expire = int64(c.XSRFExpire) + } + c._xsrfToken = c.Ctx.XSRFToken(BConfig.WebConfig.XSRFKey, expire) + } + return c._xsrfToken +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (c *Controller) CheckXSRFCookie() bool { + if !c.EnableXSRF { + return true + } + return c.Ctx.CheckXSRFCookie() +} + +// XSRFFormHTML writes an input field contains xsrf token value. +func (c *Controller) XSRFFormHTML() string { + return `` +} + +// GetControllerAndAction gets the executing controller name and action name. +func (c *Controller) GetControllerAndAction() (string, string) { + return c.controllerName, c.actionName +} diff --git a/pkg/controller_test.go b/pkg/controller_test.go new file mode 100644 index 00000000..1e53416d --- /dev/null +++ b/pkg/controller_test.go @@ -0,0 +1,181 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "math" + "strconv" + "testing" + + "github.com/astaxie/beego/context" + "os" + "path/filepath" +) + +func TestGetInt(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt("age") + if val != 40 { + t.Errorf("TestGetInt expect 40,get %T,%v", val, val) + } +} + +func TestGetInt8(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt8("age") + if val != 40 { + t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val) + } + //Output: int8 +} + +func TestGetInt16(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt16("age") + if val != 40 { + t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val) + } +} + +func TestGetInt32(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt32("age") + if val != 40 { + t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val) + } +} + +func TestGetInt64(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt64("age") + if val != 40 { + t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val) + } +} + +func TestGetUint8(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint8, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint8("age") + if val != math.MaxUint8 { + t.Errorf("TestGetUint8 expect %v,get %T,%v", math.MaxUint8, val, val) + } +} + +func TestGetUint16(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint16, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint16("age") + if val != math.MaxUint16 { + t.Errorf("TestGetUint16 expect %v,get %T,%v", math.MaxUint16, val, val) + } +} + +func TestGetUint32(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint32, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint32("age") + if val != math.MaxUint32 { + t.Errorf("TestGetUint32 expect %v,get %T,%v", math.MaxUint32, val, val) + } +} + +func TestGetUint64(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint64, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint64("age") + if val != math.MaxUint64 { + t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val) + } +} + +func TestAdditionalViewPaths(t *testing.T) { + dir1 := "_beeTmp" + dir2 := "_beeTmp2" + defer os.RemoveAll(dir1) + defer os.RemoveAll(dir2) + + dir1file := "file1.tpl" + dir2file := "file2.tpl" + + genFile := func(dir string, name string, content string) { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + defer f.Close() + f.WriteString(content) + f.Close() + } + + } + genFile(dir1, dir1file, `
{{.Content}}
`) + genFile(dir2, dir2file, `{{.Content}}`) + + AddViewPath(dir1) + AddViewPath(dir2) + + ctrl := Controller{ + TplName: "file1.tpl", + ViewPath: dir1, + } + ctrl.Data = map[interface{}]interface{}{ + "Content": "value2", + } + if result, err := ctrl.RenderString(); err != nil { + t.Fatal(err) + } else { + if result != "
value2
" { + t.Fatalf("TestAdditionalViewPaths expect %s got %s", "
value2
", result) + } + } + + func() { + ctrl.TplName = "file2.tpl" + defer func() { + if r := recover(); r == nil { + t.Fatal("TestAdditionalViewPaths expected error") + } + }() + ctrl.RenderString() + }() + + ctrl.TplName = "file2.tpl" + ctrl.ViewPath = dir2 + ctrl.RenderString() +} diff --git a/pkg/doc.go b/pkg/doc.go new file mode 100644 index 00000000..8825bd29 --- /dev/null +++ b/pkg/doc.go @@ -0,0 +1,17 @@ +/* +Package beego provide a MVC framework +beego: an open-source, high-performance, modular, full-stack web framework + +It is used for rapid development of RESTful APIs, web apps and backend services in Go. +beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding. + + package main + import "github.com/astaxie/beego" + + func main() { + beego.Run() + } + +more information: http://beego.me +*/ +package beego diff --git a/pkg/error.go b/pkg/error.go new file mode 100644 index 00000000..f268f723 --- /dev/null +++ b/pkg/error.go @@ -0,0 +1,488 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "fmt" + "html/template" + "net/http" + "reflect" + "runtime" + "strconv" + "strings" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/utils" +) + +const ( + errorTypeHandler = iota + errorTypeController +) + +var tpl = ` + + + + + beego application error + + + + + +
+ + + + + + + + + + +
Request Method: {{.RequestMethod}}
Request URL: {{.RequestURL}}
RemoteAddr: {{.RemoteAddr }}
+
+ Stack +
{{.Stack}}
+
+
+ + + +` + +// render default application error page with error and stack string. +func showErr(err interface{}, ctx *context.Context, stack string) { + t, _ := template.New("beegoerrortemp").Parse(tpl) + data := map[string]string{ + "AppError": fmt.Sprintf("%s:%v", BConfig.AppName, err), + "RequestMethod": ctx.Input.Method(), + "RequestURL": ctx.Input.URI(), + "RemoteAddr": ctx.Input.IP(), + "Stack": stack, + "BeegoVersion": VERSION, + "GoVersion": runtime.Version(), + } + t.Execute(ctx.ResponseWriter, data) +} + +var errtpl = ` + + + + + {{.Title}} + + + +
+
+ +
+ {{.Content}} + Go Home
+ +
Powered by beego {{.BeegoVersion}} +
+
+
+ + +` + +type errorInfo struct { + controllerType reflect.Type + handler http.HandlerFunc + method string + errorType int +} + +// ErrorMaps holds map of http handlers for each error string. +// there is 10 kinds default error(40x and 50x) +var ErrorMaps = make(map[string]*errorInfo, 10) + +// show 401 unauthorized error. +func unauthorized(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 401, + "
The page you have requested can't be authorized."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The credentials you supplied are incorrect"+ + "
    There are errors in the website address"+ + "
", + ) +} + +// show 402 Payment Required +func paymentRequired(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 402, + "
The page you have requested Payment Required."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The credentials you supplied are incorrect"+ + "
    There are errors in the website address"+ + "
", + ) +} + +// show 403 forbidden error. +func forbidden(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 403, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    Your address may be blocked"+ + "
    The site may be disabled"+ + "
    You need to log in"+ + "
", + ) +} + +// show 422 missing xsrf token +func missingxsrf(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 422, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    '_xsrf' argument missing from POST"+ + "
", + ) +} + +// show 417 invalid xsrf token +func invalidxsrf(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 417, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    expected XSRF not found"+ + "
", + ) +} + +// show 404 not found error. +func notFound(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 404, + "
The page you have requested has flown the coop."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The page has moved"+ + "
    The page no longer exists"+ + "
    You were looking for your puppy and got lost"+ + "
    You like 404 pages"+ + "
", + ) +} + +// show 405 Method Not Allowed +func methodNotAllowed(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 405, + "
The method you have requested Not Allowed."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The method specified in the Request-Line is not allowed for the resource identified by the Request-URI"+ + "
    The response MUST include an Allow header containing a list of valid methods for the requested resource."+ + "
", + ) +} + +// show 500 internal server error. +func internalServerError(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 500, + "
The page you have requested is down right now."+ + "

    "+ + "
    Please try again later and report the error to the website administrator"+ + "
", + ) +} + +// show 501 Not Implemented. +func notImplemented(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 501, + "
The page you have requested is Not Implemented."+ + "

    "+ + "
    Please try again later and report the error to the website administrator"+ + "
", + ) +} + +// show 502 Bad Gateway. +func badGateway(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 502, + "
The page you have requested is down right now."+ + "

    "+ + "
    The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request."+ + "
    Please try again later and report the error to the website administrator"+ + "
", + ) +} + +// show 503 service unavailable error. +func serviceUnavailable(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 503, + "
The page you have requested is unavailable."+ + "
Perhaps you are here because:"+ + "

    "+ + "

    The page is overloaded"+ + "
    Please try again later."+ + "
", + ) +} + +// show 504 Gateway Timeout. +func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 504, + "
The page you have requested is unavailable"+ + "
Perhaps you are here because:"+ + "

    "+ + "

    The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI."+ + "
    Please try again later."+ + "
", + ) +} + +// show 413 Payload Too Large +func payloadTooLarge(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 413, + `
The page you have requested is unavailable. +
Perhaps you are here because:

+
    +
    The request entity is larger than limits defined by server. +
    Please change the request entity and try again. +
+ `, + ) +} + +func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) { + t, _ := template.New("beegoerrortemp").Parse(errtpl) + data := M{ + "Title": http.StatusText(errCode), + "BeegoVersion": VERSION, + "Content": template.HTML(errContent), + } + t.Execute(rw, data) +} + +// ErrorHandler registers http.HandlerFunc to each http err code string. +// usage: +// beego.ErrorHandler("404",NotFound) +// beego.ErrorHandler("500",InternalServerError) +func ErrorHandler(code string, h http.HandlerFunc) *App { + ErrorMaps[code] = &errorInfo{ + errorType: errorTypeHandler, + handler: h, + method: code, + } + return BeeApp +} + +// ErrorController registers ControllerInterface to each http err code string. +// usage: +// beego.ErrorController(&controllers.ErrorController{}) +func ErrorController(c ControllerInterface) *App { + reflectVal := reflect.ValueOf(c) + rt := reflectVal.Type() + ct := reflect.Indirect(reflectVal).Type() + for i := 0; i < rt.NumMethod(); i++ { + methodName := rt.Method(i).Name + if !utils.InSlice(methodName, exceptMethod) && strings.HasPrefix(methodName, "Error") { + errName := strings.TrimPrefix(methodName, "Error") + ErrorMaps[errName] = &errorInfo{ + errorType: errorTypeController, + controllerType: ct, + method: methodName, + } + } + } + return BeeApp +} + +// Exception Write HttpStatus with errCode and Exec error handler if exist. +func Exception(errCode uint64, ctx *context.Context) { + exception(strconv.FormatUint(errCode, 10), ctx) +} + +// show error string as simple text message. +// if error string is empty, show 503 or 500 error as default. +func exception(errCode string, ctx *context.Context) { + atoi := func(code string) int { + v, err := strconv.Atoi(code) + if err == nil { + return v + } + if ctx.Output.Status == 0 { + return 503 + } + return ctx.Output.Status + } + + for _, ec := range []string{errCode, "503", "500"} { + if h, ok := ErrorMaps[ec]; ok { + executeError(h, ctx, atoi(ec)) + return + } + } + //if 50x error has been removed from errorMap + ctx.ResponseWriter.WriteHeader(atoi(errCode)) + ctx.WriteString(errCode) +} + +func executeError(err *errorInfo, ctx *context.Context, code int) { + //make sure to log the error in the access log + LogAccess(ctx, nil, code) + + if err.errorType == errorTypeHandler { + ctx.ResponseWriter.WriteHeader(code) + err.handler(ctx.ResponseWriter, ctx.Request) + return + } + if err.errorType == errorTypeController { + ctx.Output.SetStatus(code) + //Invoke the request handler + vc := reflect.New(err.controllerType) + execController, ok := vc.Interface().(ControllerInterface) + if !ok { + panic("controller is not ControllerInterface") + } + //call the controller init function + execController.Init(ctx, err.controllerType.Name(), err.method, vc.Interface()) + + //call prepare function + execController.Prepare() + + execController.URLMapping() + + method := vc.MethodByName(err.method) + method.Call([]reflect.Value{}) + + //render template + if BConfig.WebConfig.AutoRender { + if err := execController.Render(); err != nil { + panic(err) + } + } + + // finish all runrouter. release resource + execController.Finish() + } +} diff --git a/pkg/error_test.go b/pkg/error_test.go new file mode 100644 index 00000000..378aa953 --- /dev/null +++ b/pkg/error_test.go @@ -0,0 +1,88 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" +) + +type errorTestController struct { + Controller +} + +const parseCodeError = "parse code error" + +func (ec *errorTestController) Get() { + errorCode, err := ec.GetInt("code") + if err != nil { + ec.Abort(parseCodeError) + } + if errorCode != 0 { + ec.CustomAbort(errorCode, ec.GetString("code")) + } + ec.Abort("404") +} + +func TestErrorCode_01(t *testing.T) { + registerDefaultErrorHandler() + for k := range ErrorMaps { + r, _ := http.NewRequest("GET", "/error?code="+k, nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + code, _ := strconv.Atoi(k) + if w.Code != code { + t.Fail() + } + if !strings.Contains(w.Body.String(), http.StatusText(code)) { + t.Fail() + } + } +} + +func TestErrorCode_02(t *testing.T) { + registerDefaultErrorHandler() + r, _ := http.NewRequest("GET", "/error?code=0", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + if w.Code != 404 { + t.Fail() + } +} + +func TestErrorCode_03(t *testing.T) { + registerDefaultErrorHandler() + r, _ := http.NewRequest("GET", "/error?code=panic", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + if w.Code != 200 { + t.Fail() + } + if w.Body.String() != parseCodeError { + t.Fail() + } +} diff --git a/pkg/filter.go b/pkg/filter.go new file mode 100644 index 00000000..9cc6e913 --- /dev/null +++ b/pkg/filter.go @@ -0,0 +1,44 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import "github.com/astaxie/beego/context" + +// FilterFunc defines a filter function which is invoked before the controller handler is executed. +type FilterFunc func(*context.Context) + +// FilterRouter defines a filter operation which is invoked before the controller handler is executed. +// It can match the URL against a pattern, and execute a filter function +// when a request with a matching URL arrives. +type FilterRouter struct { + filterFunc FilterFunc + tree *Tree + pattern string + returnOnOutput bool + resetParams bool +} + +// ValidRouter checks if the current request is matched by this filter. +// If the request is matched, the values of the URL parameters defined +// by the filter pattern are also returned. +func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { + isOk := f.tree.Match(url, ctx) + if isOk != nil { + if b, ok := isOk.(bool); ok { + return b + } + } + return false +} diff --git a/pkg/filter_test.go b/pkg/filter_test.go new file mode 100644 index 00000000..4ca4d2b8 --- /dev/null +++ b/pkg/filter_test.go @@ -0,0 +1,68 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/astaxie/beego/context" +) + +var FilterUser = func(ctx *context.Context) { + ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first"))) +} + +func TestFilter(t *testing.T) { + r, _ := http.NewRequest("GET", "/person/asta/Xie", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.InsertFilter("/person/:last/:first", BeforeRouter, FilterUser) + handler.Add("/person/:last/:first", &TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am astaXie" { + t.Errorf("user define func can't run") + } +} + +var FilterAdminUser = func(ctx *context.Context) { + ctx.Output.Body([]byte("i am admin")) +} + +// Filter pattern /admin/:all +// all url like /admin/ /admin/xie will all get filter + +func TestPatternTwo(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.InsertFilter("/admin/?:all", BeforeRouter, FilterAdminUser) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am admin" { + t.Errorf("filter /admin/ can't run") + } +} + +func TestPatternThree(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/astaxie", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.InsertFilter("/admin/:all", BeforeRouter, FilterAdminUser) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am admin" { + t.Errorf("filter /admin/astaxie can't run") + } +} diff --git a/pkg/flash.go b/pkg/flash.go new file mode 100644 index 00000000..a6485a17 --- /dev/null +++ b/pkg/flash.go @@ -0,0 +1,110 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "fmt" + "net/url" + "strings" +) + +// FlashData is a tools to maintain data when using across request. +type FlashData struct { + Data map[string]string +} + +// NewFlash return a new empty FlashData struct. +func NewFlash() *FlashData { + return &FlashData{ + Data: make(map[string]string), + } +} + +// Set message to flash +func (fd *FlashData) Set(key string, msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data[key] = msg + } else { + fd.Data[key] = fmt.Sprintf(msg, args...) + } +} + +// Success writes success message to flash. +func (fd *FlashData) Success(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["success"] = msg + } else { + fd.Data["success"] = fmt.Sprintf(msg, args...) + } +} + +// Notice writes notice message to flash. +func (fd *FlashData) Notice(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["notice"] = msg + } else { + fd.Data["notice"] = fmt.Sprintf(msg, args...) + } +} + +// Warning writes warning message to flash. +func (fd *FlashData) Warning(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["warning"] = msg + } else { + fd.Data["warning"] = fmt.Sprintf(msg, args...) + } +} + +// Error writes error message to flash. +func (fd *FlashData) Error(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["error"] = msg + } else { + fd.Data["error"] = fmt.Sprintf(msg, args...) + } +} + +// Store does the saving operation of flash data. +// the data are encoded and saved in cookie. +func (fd *FlashData) Store(c *Controller) { + c.Data["flash"] = fd.Data + var flashValue string + for key, value := range fd.Data { + flashValue += "\x00" + key + "\x23" + BConfig.WebConfig.FlashSeparator + "\x23" + value + "\x00" + } + c.Ctx.SetCookie(BConfig.WebConfig.FlashName, url.QueryEscape(flashValue), 0, "/") +} + +// ReadFromRequest parsed flash data from encoded values in cookie. +func ReadFromRequest(c *Controller) *FlashData { + flash := NewFlash() + if cookie, err := c.Ctx.Request.Cookie(BConfig.WebConfig.FlashName); err == nil { + v, _ := url.QueryUnescape(cookie.Value) + vals := strings.Split(v, "\x00") + for _, v := range vals { + if len(v) > 0 { + kv := strings.Split(v, "\x23"+BConfig.WebConfig.FlashSeparator+"\x23") + if len(kv) == 2 { + flash.Data[kv[0]] = kv[1] + } + } + } + //read one time then delete it + c.Ctx.SetCookie(BConfig.WebConfig.FlashName, "", -1, "/") + } + c.Data["flash"] = flash.Data + return flash +} diff --git a/pkg/flash_test.go b/pkg/flash_test.go new file mode 100644 index 00000000..d5e9608d --- /dev/null +++ b/pkg/flash_test.go @@ -0,0 +1,54 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +type TestFlashController struct { + Controller +} + +func (t *TestFlashController) TestWriteFlash() { + flash := NewFlash() + flash.Notice("TestFlashString") + flash.Store(&t.Controller) + // we choose to serve json because we don't want to load a template html file + t.ServeJSON(true) +} + +func TestFlashHeader(t *testing.T) { + // create fake GET request + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + // setup the handler + handler := NewControllerRegister() + handler.Add("/", &TestFlashController{}, "get:TestWriteFlash") + handler.ServeHTTP(w, r) + + // get the Set-Cookie value + sc := w.Header().Get("Set-Cookie") + // match for the expected header + res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00") + // validate the assertion + if !res { + t.Errorf("TestFlashHeader() unable to validate flash message") + } +} diff --git a/pkg/fs.go b/pkg/fs.go new file mode 100644 index 00000000..41cc6f6e --- /dev/null +++ b/pkg/fs.go @@ -0,0 +1,74 @@ +package beego + +import ( + "net/http" + "os" + "path/filepath" +) + +type FileSystem struct { +} + +func (d FileSystem) Open(name string) (http.File, error) { + return os.Open(name) +} + +// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or +// directory in the tree, including root. All errors that arise visiting files +// and directories are filtered by walkFn. +func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error { + + f, err := fs.Open(root) + if err != nil { + return err + } + info, err := f.Stat() + if err != nil { + err = walkFn(root, nil, err) + } else { + err = walk(fs, root, info, walkFn) + } + if err == filepath.SkipDir { + return nil + } + return err +} + +// walk recursively descends path, calling walkFn. +func walk(fs http.FileSystem, path string, info os.FileInfo, walkFn filepath.WalkFunc) error { + var err error + if !info.IsDir() { + return walkFn(path, info, nil) + } + + dir, err := fs.Open(path) + if err != nil { + if err1 := walkFn(path, info, err); err1 != nil { + return err1 + } + return err + } + defer dir.Close() + dirs, err := dir.Readdir(-1) + err1 := walkFn(path, info, err) + // If err != nil, walk can't walk into this directory. + // err1 != nil means walkFn want walk to skip this directory or stop walking. + // Therefore, if one of err and err1 isn't nil, walk will return. + if err != nil || err1 != nil { + // The caller's behavior is controlled by the return value, which is decided + // by walkFn. walkFn may ignore err and return nil. + // If walkFn returns SkipDir, it will be handled by the caller. + // So walk should return whatever walkFn returns. + return err1 + } + + for _, fileInfo := range dirs { + filename := filepath.Join(path, fileInfo.Name()) + if err = walk(fs, filename, fileInfo, walkFn); err != nil { + if !fileInfo.IsDir() || err != filepath.SkipDir { + return err + } + } + } + return nil +} diff --git a/pkg/grace/grace.go b/pkg/grace/grace.go new file mode 100644 index 00000000..fb0cb7bb --- /dev/null +++ b/pkg/grace/grace.go @@ -0,0 +1,166 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package grace use to hot reload +// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ +// +// Usage: +// +// import( +// "log" +// "net/http" +// "os" +// +// "github.com/astaxie/beego/grace" +// ) +// +// func handler(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("WORLD!")) +// } +// +// func main() { +// mux := http.NewServeMux() +// mux.HandleFunc("/hello", handler) +// +// err := grace.ListenAndServe("localhost:8080", mux) +// if err != nil { +// log.Println(err) +// } +// log.Println("Server on 8080 stopped") +// os.Exit(0) +// } +package grace + +import ( + "flag" + "net/http" + "os" + "strings" + "sync" + "syscall" + "time" +) + +const ( + // PreSignal is the position to add filter before signal + PreSignal = iota + // PostSignal is the position to add filter after signal + PostSignal + // StateInit represent the application inited + StateInit + // StateRunning represent the application is running + StateRunning + // StateShuttingDown represent the application is shutting down + StateShuttingDown + // StateTerminate represent the application is killed + StateTerminate +) + +var ( + regLock *sync.Mutex + runningServers map[string]*Server + runningServersOrder []string + socketPtrOffsetMap map[string]uint + runningServersForked bool + + // DefaultReadTimeOut is the HTTP read timeout + DefaultReadTimeOut time.Duration + // DefaultWriteTimeOut is the HTTP Write timeout + DefaultWriteTimeOut time.Duration + // DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit + DefaultMaxHeaderBytes int + // DefaultTimeout is the shutdown server's timeout. default is 60s + DefaultTimeout = 60 * time.Second + + isChild bool + socketOrder string + + hookableSignals []os.Signal +) + +func init() { + flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)") + flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started") + + regLock = &sync.Mutex{} + runningServers = make(map[string]*Server) + runningServersOrder = []string{} + socketPtrOffsetMap = make(map[string]uint) + + hookableSignals = []os.Signal{ + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + } +} + +// NewServer returns a new graceServer. +func NewServer(addr string, handler http.Handler) (srv *Server) { + regLock.Lock() + defer regLock.Unlock() + + if !flag.Parsed() { + flag.Parse() + } + if len(socketOrder) > 0 { + for i, addr := range strings.Split(socketOrder, ",") { + socketPtrOffsetMap[addr] = uint(i) + } + } else { + socketPtrOffsetMap[addr] = uint(len(runningServersOrder)) + } + + srv = &Server{ + sigChan: make(chan os.Signal), + isChild: isChild, + SignalHooks: map[int]map[os.Signal][]func(){ + PreSignal: { + syscall.SIGHUP: {}, + syscall.SIGINT: {}, + syscall.SIGTERM: {}, + }, + PostSignal: { + syscall.SIGHUP: {}, + syscall.SIGINT: {}, + syscall.SIGTERM: {}, + }, + }, + state: StateInit, + Network: "tcp", + terminalChan: make(chan error), //no cache channel + } + srv.Server = &http.Server{ + Addr: addr, + ReadTimeout: DefaultReadTimeOut, + WriteTimeout: DefaultWriteTimeOut, + MaxHeaderBytes: DefaultMaxHeaderBytes, + Handler: handler, + } + + runningServersOrder = append(runningServersOrder, addr) + runningServers[addr] = srv + return srv +} + +// ListenAndServe refer http.ListenAndServe +func ListenAndServe(addr string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServe() +} + +// ListenAndServeTLS refer http.ListenAndServeTLS +func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServeTLS(certFile, keyFile) +} diff --git a/pkg/grace/server.go b/pkg/grace/server.go new file mode 100644 index 00000000..008a6171 --- /dev/null +++ b/pkg/grace/server.go @@ -0,0 +1,356 @@ +package grace + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + "os" + "os/exec" + "os/signal" + "strings" + "syscall" + "time" +) + +// Server embedded http.Server +type Server struct { + *http.Server + ln net.Listener + SignalHooks map[int]map[os.Signal][]func() + sigChan chan os.Signal + isChild bool + state uint8 + Network string + terminalChan chan error +} + +// Serve accepts incoming connections on the Listener l, +// creating a new service goroutine for each. +// The service goroutines read requests and then call srv.Handler to reply to them. +func (srv *Server) Serve() (err error) { + srv.state = StateRunning + defer func() { srv.state = StateTerminate }() + + // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS + // immediately return ErrServerClosed. Make sure the program doesn't exit + // and waits instead for Shutdown to return. + if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed { + log.Println(syscall.Getpid(), "Server.Serve() error:", err) + return err + } + + log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.") + // wait for Shutdown to return + if shutdownErr := <-srv.terminalChan; shutdownErr != nil { + return shutdownErr + } + return +} + +// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve +// to handle requests on incoming connections. If srv.Addr is blank, ":http" is +// used. +func (srv *Server) ListenAndServe() (err error) { + addr := srv.Addr + if addr == "" { + addr = ":http" + } + + go srv.handleSignals() + + srv.ln, err = srv.getListener(addr) + if err != nil { + log.Println(err) + return err + } + + if srv.isChild { + process, err := os.FindProcess(os.Getppid()) + if err != nil { + log.Println(err) + return err + } + err = process.Signal(syscall.SIGTERM) + if err != nil { + return err + } + } + + log.Println(os.Getpid(), srv.Addr) + return srv.Serve() +} + +// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming TLS connections. +// +// Filenames containing a certificate and matching private key for the server must +// be provided. If the certificate is signed by a certificate authority, the +// certFile should be the concatenation of the server's certificate followed by the +// CA's certificate. +// +// If srv.Addr is blank, ":https" is used. +func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + + if srv.TLSConfig == nil { + srv.TLSConfig = &tls.Config{} + } + if srv.TLSConfig.NextProtos == nil { + srv.TLSConfig.NextProtos = []string{"http/1.1"} + } + + srv.TLSConfig.Certificates = make([]tls.Certificate, 1) + srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return + } + + go srv.handleSignals() + + ln, err := srv.getListener(addr) + if err != nil { + log.Println(err) + return err + } + srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + + if srv.isChild { + process, err := os.FindProcess(os.Getppid()) + if err != nil { + log.Println(err) + return err + } + err = process.Signal(syscall.SIGTERM) + if err != nil { + return err + } + } + + log.Println(os.Getpid(), srv.Addr) + return srv.Serve() +} + +// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming mutual TLS connections. +func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + + if srv.TLSConfig == nil { + srv.TLSConfig = &tls.Config{} + } + if srv.TLSConfig.NextProtos == nil { + srv.TLSConfig.NextProtos = []string{"http/1.1"} + } + + srv.TLSConfig.Certificates = make([]tls.Certificate, 1) + srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return + } + srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + pool := x509.NewCertPool() + data, err := ioutil.ReadFile(trustFile) + if err != nil { + log.Println(err) + return err + } + pool.AppendCertsFromPEM(data) + srv.TLSConfig.ClientCAs = pool + log.Println("Mutual HTTPS") + go srv.handleSignals() + + ln, err := srv.getListener(addr) + if err != nil { + log.Println(err) + return err + } + srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + + if srv.isChild { + process, err := os.FindProcess(os.Getppid()) + if err != nil { + log.Println(err) + return err + } + err = process.Signal(syscall.SIGTERM) + if err != nil { + return err + } + } + + log.Println(os.Getpid(), srv.Addr) + return srv.Serve() +} + +// getListener either opens a new socket to listen on, or takes the acceptor socket +// it got passed when restarted. +func (srv *Server) getListener(laddr string) (l net.Listener, err error) { + if srv.isChild { + var ptrOffset uint + if len(socketPtrOffsetMap) > 0 { + ptrOffset = socketPtrOffsetMap[laddr] + log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr]) + } + + f := os.NewFile(uintptr(3+ptrOffset), "") + l, err = net.FileListener(f) + if err != nil { + err = fmt.Errorf("net.FileListener error: %v", err) + return + } + } else { + l, err = net.Listen(srv.Network, laddr) + if err != nil { + err = fmt.Errorf("net.Listen error: %v", err) + return + } + } + return +} + +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + +// handleSignals listens for os Signals and calls any hooked in function that the +// user had registered with the signal. +func (srv *Server) handleSignals() { + var sig os.Signal + + signal.Notify( + srv.sigChan, + hookableSignals..., + ) + + pid := syscall.Getpid() + for { + sig = <-srv.sigChan + srv.signalHooks(PreSignal, sig) + switch sig { + case syscall.SIGHUP: + log.Println(pid, "Received SIGHUP. forking.") + err := srv.fork() + if err != nil { + log.Println("Fork err:", err) + } + case syscall.SIGINT: + log.Println(pid, "Received SIGINT.") + srv.shutdown() + case syscall.SIGTERM: + log.Println(pid, "Received SIGTERM.") + srv.shutdown() + default: + log.Printf("Received %v: nothing i care about...\n", sig) + } + srv.signalHooks(PostSignal, sig) + } +} + +func (srv *Server) signalHooks(ppFlag int, sig os.Signal) { + if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet { + return + } + for _, f := range srv.SignalHooks[ppFlag][sig] { + f() + } +} + +// shutdown closes the listener so that no new connections are accepted. it also +// starts a goroutine that will serverTimeout (stop all running requests) the server +// after DefaultTimeout. +func (srv *Server) shutdown() { + if srv.state != StateRunning { + return + } + + srv.state = StateShuttingDown + log.Println(syscall.Getpid(), "Waiting for connections to finish...") + ctx := context.Background() + if DefaultTimeout >= 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) + defer cancel() + } + srv.terminalChan <- srv.Server.Shutdown(ctx) +} + +func (srv *Server) fork() (err error) { + regLock.Lock() + defer regLock.Unlock() + if runningServersForked { + return + } + runningServersForked = true + + var files = make([]*os.File, len(runningServers)) + var orderArgs = make([]string, len(runningServers)) + for _, srvPtr := range runningServers { + f, _ := srvPtr.ln.(*net.TCPListener).File() + files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f + orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr + } + + log.Println(files) + path := os.Args[0] + var args []string + if len(os.Args) > 1 { + for _, arg := range os.Args[1:] { + if arg == "-graceful" { + break + } + args = append(args, arg) + } + } + args = append(args, "-graceful") + if len(runningServers) > 1 { + args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ","))) + log.Println(args) + } + cmd := exec.Command(path, args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.ExtraFiles = files + err = cmd.Start() + if err != nil { + log.Fatalf("Restart: Failed to launch, error: %v", err) + } + + return +} + +// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. +func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) { + if ppFlag != PreSignal && ppFlag != PostSignal { + err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal") + return + } + for _, s := range hookableSignals { + if s == sig { + srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f) + return + } + } + err = fmt.Errorf("Signal '%v' is not supported", sig) + return +} diff --git a/pkg/hooks.go b/pkg/hooks.go new file mode 100644 index 00000000..49c42d5a --- /dev/null +++ b/pkg/hooks.go @@ -0,0 +1,104 @@ +package beego + +import ( + "encoding/json" + "mime" + "net/http" + "path/filepath" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/session" +) + +// register MIME type with content type +func registerMime() error { + for k, v := range mimemaps { + mime.AddExtensionType(k, v) + } + return nil +} + +// register default error http handlers, 404,401,403,500 and 503. +func registerDefaultErrorHandler() error { + m := map[string]func(http.ResponseWriter, *http.Request){ + "401": unauthorized, + "402": paymentRequired, + "403": forbidden, + "404": notFound, + "405": methodNotAllowed, + "500": internalServerError, + "501": notImplemented, + "502": badGateway, + "503": serviceUnavailable, + "504": gatewayTimeout, + "417": invalidxsrf, + "422": missingxsrf, + "413": payloadTooLarge, + } + for e, h := range m { + if _, ok := ErrorMaps[e]; !ok { + ErrorHandler(e, h) + } + } + return nil +} + +func registerSession() error { + if BConfig.WebConfig.Session.SessionOn { + var err error + sessionConfig := AppConfig.String("sessionConfig") + conf := new(session.ManagerConfig) + if sessionConfig == "" { + conf.CookieName = BConfig.WebConfig.Session.SessionName + conf.EnableSetCookie = BConfig.WebConfig.Session.SessionAutoSetCookie + conf.Gclifetime = BConfig.WebConfig.Session.SessionGCMaxLifetime + conf.Secure = BConfig.Listen.EnableHTTPS + conf.CookieLifeTime = BConfig.WebConfig.Session.SessionCookieLifeTime + conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig) + conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly + conf.Domain = BConfig.WebConfig.Session.SessionDomain + conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader + conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader + conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery + } else { + if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil { + return err + } + } + if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, conf); err != nil { + return err + } + go GlobalSessions.GC() + } + return nil +} + +func registerTemplate() error { + defer lockViewPaths() + if err := AddViewPath(BConfig.WebConfig.ViewsPath); err != nil { + if BConfig.RunMode == DEV { + logs.Warn(err) + } + return err + } + return nil +} + +func registerAdmin() error { + if BConfig.Listen.EnableAdmin { + go beeAdminApp.Run() + } + return nil +} + +func registerGzip() error { + if BConfig.EnableGzip { + context.InitGzip( + AppConfig.DefaultInt("gzipMinLength", -1), + AppConfig.DefaultInt("gzipCompressLevel", -1), + AppConfig.DefaultStrings("includedMethods", []string{"GET"}), + ) + } + return nil +} diff --git a/pkg/httplib/README.md b/pkg/httplib/README.md new file mode 100644 index 00000000..97df8e6b --- /dev/null +++ b/pkg/httplib/README.md @@ -0,0 +1,97 @@ +# httplib +httplib is an libs help you to curl remote url. + +# How to use? + +## GET +you can use Get to crawl data. + + import "github.com/astaxie/beego/httplib" + + str, err := httplib.Get("http://beego.me/").String() + if err != nil { + // error + } + fmt.Println(str) + +## POST +POST data to remote url + + req := httplib.Post("http://beego.me/") + req.Param("username","astaxie") + req.Param("password","123456") + str, err := req.String() + if err != nil { + // error + } + fmt.Println(str) + +## Set timeout + +The default timeout is `60` seconds, function prototype: + + SetTimeout(connectTimeout, readWriteTimeout time.Duration) + +Example: + + // GET + httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) + + // POST + httplib.Post("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) + + +## Debug + +If you want to debug the request info, set the debug on + + httplib.Get("http://beego.me/").Debug(true) + +## Set HTTP Basic Auth + + str, err := Get("http://beego.me/").SetBasicAuth("user", "passwd").String() + if err != nil { + // error + } + fmt.Println(str) + +## Set HTTPS + +If request url is https, You can set the client support TSL: + + httplib.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) + +More info about the `tls.Config` please visit http://golang.org/pkg/crypto/tls/#Config + +## Set HTTP Version + +some servers need to specify the protocol version of HTTP + + httplib.Get("http://beego.me/").SetProtocolVersion("HTTP/1.1") + +## Set Cookie + +some http request need setcookie. So set it like this: + + cookie := &http.Cookie{} + cookie.Name = "username" + cookie.Value = "astaxie" + httplib.Get("http://beego.me/").SetCookie(cookie) + +## Upload file + +httplib support mutil file upload, use `req.PostFile()` + + req := httplib.Post("http://beego.me/") + req.Param("username","astaxie") + req.PostFile("uploadfile1", "httplib.pdf") + str, err := req.String() + if err != nil { + // error + } + fmt.Println(str) + + +See godoc for further documentation and examples. + +* [godoc.org/github.com/astaxie/beego/httplib](https://godoc.org/github.com/astaxie/beego/httplib) diff --git a/pkg/httplib/httplib.go b/pkg/httplib/httplib.go new file mode 100644 index 00000000..60aa4e8b --- /dev/null +++ b/pkg/httplib/httplib.go @@ -0,0 +1,654 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package httplib is used as http.Client +// Usage: +// +// import "github.com/astaxie/beego/httplib" +// +// b := httplib.Post("http://beego.me/") +// b.Param("username","astaxie") +// b.Param("password","123456") +// b.PostFile("uploadfile1", "httplib.pdf") +// b.PostFile("uploadfile2", "httplib.txt") +// str, err := b.String() +// if err != nil { +// t.Fatal(err) +// } +// fmt.Println(str) +// +// more docs http://beego.me/docs/module/httplib.md +package httplib + +import ( + "bytes" + "compress/gzip" + "crypto/tls" + "encoding/json" + "encoding/xml" + "io" + "io/ioutil" + "log" + "mime/multipart" + "net" + "net/http" + "net/http/cookiejar" + "net/http/httputil" + "net/url" + "os" + "path" + "strings" + "sync" + "time" + + "gopkg.in/yaml.v2" +) + +var defaultSetting = BeegoHTTPSettings{ + UserAgent: "beegoServer", + ConnectTimeout: 60 * time.Second, + ReadWriteTimeout: 60 * time.Second, + Gzip: true, + DumpBody: true, +} + +var defaultCookieJar http.CookieJar +var settingMutex sync.Mutex + +// createDefaultCookie creates a global cookiejar to store cookies. +func createDefaultCookie() { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultCookieJar, _ = cookiejar.New(nil) +} + +// SetDefaultSetting Overwrite default settings +func SetDefaultSetting(setting BeegoHTTPSettings) { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultSetting = setting +} + +// NewBeegoRequest return *BeegoHttpRequest with specific method +func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { + var resp http.Response + u, err := url.Parse(rawurl) + if err != nil { + log.Println("Httplib:", err) + } + req := http.Request{ + URL: u, + Method: method, + Header: make(http.Header), + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + } + return &BeegoHTTPRequest{ + url: rawurl, + req: &req, + params: map[string][]string{}, + files: map[string]string{}, + setting: defaultSetting, + resp: &resp, + } +} + +// Get returns *BeegoHttpRequest with GET method. +func Get(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "GET") +} + +// Post returns *BeegoHttpRequest with POST method. +func Post(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "POST") +} + +// Put returns *BeegoHttpRequest with PUT method. +func Put(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "PUT") +} + +// Delete returns *BeegoHttpRequest DELETE method. +func Delete(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "DELETE") +} + +// Head returns *BeegoHttpRequest with HEAD method. +func Head(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "HEAD") +} + +// BeegoHTTPSettings is the http.Client setting +type BeegoHTTPSettings struct { + ShowDebug bool + UserAgent string + ConnectTimeout time.Duration + ReadWriteTimeout time.Duration + TLSClientConfig *tls.Config + Proxy func(*http.Request) (*url.URL, error) + Transport http.RoundTripper + CheckRedirect func(req *http.Request, via []*http.Request) error + EnableCookie bool + Gzip bool + DumpBody bool + Retries int // if set to -1 means will retry forever + RetryDelay time.Duration +} + +// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. +type BeegoHTTPRequest struct { + url string + req *http.Request + params map[string][]string + files map[string]string + setting BeegoHTTPSettings + resp *http.Response + body []byte + dump []byte +} + +// GetRequest return the request object +func (b *BeegoHTTPRequest) GetRequest() *http.Request { + return b.req +} + +// Setting Change request settings +func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest { + b.setting = setting + return b +} + +// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. +func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest { + b.req.SetBasicAuth(username, password) + return b +} + +// SetEnableCookie sets enable/disable cookiejar +func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest { + b.setting.EnableCookie = enable + return b +} + +// SetUserAgent sets User-Agent header field +func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { + b.setting.UserAgent = useragent + return b +} + +// Debug sets show debug or not when executing request. +func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { + b.setting.ShowDebug = isdebug + return b +} + +// Retries sets Retries times. +// default is 0 means no retried. +// -1 means retried forever. +// others means retried times. +func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { + b.setting.Retries = times + return b +} + +func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { + b.setting.RetryDelay = delay + return b +} + +// DumpBody setting whether need to Dump the Body. +func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { + b.setting.DumpBody = isdump + return b +} + +// DumpRequest return the DumpRequest +func (b *BeegoHTTPRequest) DumpRequest() []byte { + return b.dump +} + +// SetTimeout sets connect time out and read-write time out for BeegoRequest. +func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { + b.setting.ConnectTimeout = connectTimeout + b.setting.ReadWriteTimeout = readWriteTimeout + return b +} + +// SetTLSClientConfig sets tls connection configurations if visiting https url. +func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest { + b.setting.TLSClientConfig = config + return b +} + +// Header add header item string in request. +func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest { + b.req.Header.Set(key, value) + return b +} + +// SetHost set the request host +func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { + b.req.Host = host + return b +} + +// SetProtocolVersion Set the protocol version for incoming requests. +// Client requests always use HTTP/1.1. +func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { + if len(vers) == 0 { + vers = "HTTP/1.1" + } + + major, minor, ok := http.ParseHTTPVersion(vers) + if ok { + b.req.Proto = vers + b.req.ProtoMajor = major + b.req.ProtoMinor = minor + } + + return b +} + +// SetCookie add cookie into request. +func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest { + b.req.Header.Add("Cookie", cookie.String()) + return b +} + +// SetTransport set the setting transport +func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest { + b.setting.Transport = transport + return b +} + +// SetProxy set the http proxy +// example: +// +// func(req *http.Request) (*url.URL, error) { +// u, _ := url.ParseRequestURI("http://127.0.0.1:8118") +// return u, nil +// } +func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest { + b.setting.Proxy = proxy + return b +} + +// SetCheckRedirect specifies the policy for handling redirects. +// +// If CheckRedirect is nil, the Client uses its default policy, +// which is to stop after 10 consecutive requests. +func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest { + b.setting.CheckRedirect = redirect + return b +} + +// Param adds query param in to request. +// params build query string as ?key1=value1&key2=value2... +func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { + if param, ok := b.params[key]; ok { + b.params[key] = append(param, value) + } else { + b.params[key] = []string{value} + } + return b +} + +// PostFile add a post file to the request +func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest { + b.files[formname] = filename + return b +} + +// Body adds request raw body. +// it supports string and []byte. +func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { + switch t := data.(type) { + case string: + bf := bytes.NewBufferString(t) + b.req.Body = ioutil.NopCloser(bf) + b.req.ContentLength = int64(len(t)) + case []byte: + bf := bytes.NewBuffer(t) + b.req.Body = ioutil.NopCloser(bf) + b.req.ContentLength = int64(len(t)) + } + return b +} + +// XMLBody adds request raw body encoding by XML. +func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + if b.req.Body == nil && obj != nil { + byts, err := xml.Marshal(obj) + if err != nil { + return b, err + } + b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.ContentLength = int64(len(byts)) + b.req.Header.Set("Content-Type", "application/xml") + } + return b, nil +} + +// YAMLBody adds request raw body encoding by YAML. +func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + if b.req.Body == nil && obj != nil { + byts, err := yaml.Marshal(obj) + if err != nil { + return b, err + } + b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.ContentLength = int64(len(byts)) + b.req.Header.Set("Content-Type", "application/x+yaml") + } + return b, nil +} + +// JSONBody adds request raw body encoding by JSON. +func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { + if b.req.Body == nil && obj != nil { + byts, err := json.Marshal(obj) + if err != nil { + return b, err + } + b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.ContentLength = int64(len(byts)) + b.req.Header.Set("Content-Type", "application/json") + } + return b, nil +} + +func (b *BeegoHTTPRequest) buildURL(paramBody string) { + // build GET url with query string + if b.req.Method == "GET" && len(paramBody) > 0 { + if strings.Contains(b.url, "?") { + b.url += "&" + paramBody + } else { + b.url = b.url + "?" + paramBody + } + return + } + + // build POST/PUT/PATCH url and body + if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil { + // with files + if len(b.files) > 0 { + pr, pw := io.Pipe() + bodyWriter := multipart.NewWriter(pw) + go func() { + for formname, filename := range b.files { + fileWriter, err := bodyWriter.CreateFormFile(formname, filename) + if err != nil { + log.Println("Httplib:", err) + } + fh, err := os.Open(filename) + if err != nil { + log.Println("Httplib:", err) + } + //iocopy + _, err = io.Copy(fileWriter, fh) + fh.Close() + if err != nil { + log.Println("Httplib:", err) + } + } + for k, v := range b.params { + for _, vv := range v { + bodyWriter.WriteField(k, vv) + } + } + bodyWriter.Close() + pw.Close() + }() + b.Header("Content-Type", bodyWriter.FormDataContentType()) + b.req.Body = ioutil.NopCloser(pr) + b.Header("Transfer-Encoding", "chunked") + return + } + + // with params + if len(paramBody) > 0 { + b.Header("Content-Type", "application/x-www-form-urlencoded") + b.Body(paramBody) + } + } +} + +func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { + if b.resp.StatusCode != 0 { + return b.resp, nil + } + resp, err := b.DoRequest() + if err != nil { + return nil, err + } + b.resp = resp + return resp, nil +} + +// DoRequest will do the client.Do +func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { + var paramBody string + if len(b.params) > 0 { + var buf bytes.Buffer + for k, v := range b.params { + for _, vv := range v { + buf.WriteString(url.QueryEscape(k)) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(vv)) + buf.WriteByte('&') + } + } + paramBody = buf.String() + paramBody = paramBody[0 : len(paramBody)-1] + } + + b.buildURL(paramBody) + urlParsed, err := url.Parse(b.url) + if err != nil { + return nil, err + } + + b.req.URL = urlParsed + + trans := b.setting.Transport + + if trans == nil { + // create default transport + trans = &http.Transport{ + TLSClientConfig: b.setting.TLSClientConfig, + Proxy: b.setting.Proxy, + Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), + MaxIdleConnsPerHost: 100, + } + } else { + // if b.transport is *http.Transport then set the settings. + if t, ok := trans.(*http.Transport); ok { + if t.TLSClientConfig == nil { + t.TLSClientConfig = b.setting.TLSClientConfig + } + if t.Proxy == nil { + t.Proxy = b.setting.Proxy + } + if t.Dial == nil { + t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) + } + } + } + + var jar http.CookieJar + if b.setting.EnableCookie { + if defaultCookieJar == nil { + createDefaultCookie() + } + jar = defaultCookieJar + } + + client := &http.Client{ + Transport: trans, + Jar: jar, + } + + if b.setting.UserAgent != "" && b.req.Header.Get("User-Agent") == "" { + b.req.Header.Set("User-Agent", b.setting.UserAgent) + } + + if b.setting.CheckRedirect != nil { + client.CheckRedirect = b.setting.CheckRedirect + } + + if b.setting.ShowDebug { + dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody) + if err != nil { + log.Println(err.Error()) + } + b.dump = dump + } + // retries default value is 0, it will run once. + // retries equal to -1, it will run forever until success + // retries is setted, it will retries fixed times. + // Sleeps for a 400ms inbetween calls to reduce spam + for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ { + resp, err = client.Do(b.req) + if err == nil { + break + } + time.Sleep(b.setting.RetryDelay) + } + return resp, err +} + +// String returns the body string in response. +// it calls Response inner. +func (b *BeegoHTTPRequest) String() (string, error) { + data, err := b.Bytes() + if err != nil { + return "", err + } + + return string(data), nil +} + +// Bytes returns the body []byte in response. +// it calls Response inner. +func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { + if b.body != nil { + return b.body, nil + } + resp, err := b.getResponse() + if err != nil { + return nil, err + } + if resp.Body == nil { + return nil, nil + } + defer resp.Body.Close() + if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" { + reader, err := gzip.NewReader(resp.Body) + if err != nil { + return nil, err + } + b.body, err = ioutil.ReadAll(reader) + return b.body, err + } + b.body, err = ioutil.ReadAll(resp.Body) + return b.body, err +} + +// ToFile saves the body data in response to one file. +// it calls Response inner. +func (b *BeegoHTTPRequest) ToFile(filename string) error { + resp, err := b.getResponse() + if err != nil { + return err + } + if resp.Body == nil { + return nil + } + defer resp.Body.Close() + err = pathExistAndMkdir(filename) + if err != nil { + return err + } + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + _, err = io.Copy(f, resp.Body) + return err +} + +//Check that the file directory exists, there is no automatically created +func pathExistAndMkdir(filename string) (err error) { + filename = path.Dir(filename) + _, err = os.Stat(filename) + if err == nil { + return nil + } + if os.IsNotExist(err) { + err = os.MkdirAll(filename, os.ModePerm) + if err == nil { + return nil + } + } + return err +} + +// ToJSON returns the map that marshals from the body bytes as json in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { + data, err := b.Bytes() + if err != nil { + return err + } + return json.Unmarshal(data, v) +} + +// ToXML returns the map that marshals from the body bytes as xml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToXML(v interface{}) error { + data, err := b.Bytes() + if err != nil { + return err + } + return xml.Unmarshal(data, v) +} + +// ToYAML returns the map that marshals from the body bytes as yaml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { + data, err := b.Bytes() + if err != nil { + return err + } + return yaml.Unmarshal(data, v) +} + +// Response executes request client gets response mannually. +func (b *BeegoHTTPRequest) Response() (*http.Response, error) { + return b.getResponse() +} + +// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. +func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { + return func(netw, addr string) (net.Conn, error) { + conn, err := net.DialTimeout(netw, addr, cTimeout) + if err != nil { + return nil, err + } + err = conn.SetDeadline(time.Now().Add(rwTimeout)) + return conn, err + } +} diff --git a/pkg/httplib/httplib_test.go b/pkg/httplib/httplib_test.go new file mode 100644 index 00000000..f6be8571 --- /dev/null +++ b/pkg/httplib/httplib_test.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "errors" + "io/ioutil" + "net" + "net/http" + "os" + "strings" + "testing" + "time" +) + +func TestResponse(t *testing.T) { + req := Get("http://httpbin.org/get") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) +} + +func TestDoRequest(t *testing.T) { + req := Get("https://goolnk.com/33BD2j") + retryAmount := 1 + req.Retries(1) + req.RetryDelay(1400 * time.Millisecond) + retryDelay := 1400 * time.Millisecond + + req.setting.CheckRedirect = func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + } + + startTime := time.Now().UnixNano() / int64(time.Millisecond) + + _, err := req.Response() + if err == nil { + t.Fatal("Response should have yielded an error") + } + + endTime := time.Now().UnixNano() / int64(time.Millisecond) + elapsedTime := endTime - startTime + delayedTime := int64(retryAmount) * retryDelay.Milliseconds() + + if elapsedTime < delayedTime { + t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) + } + +} + +func TestGet(t *testing.T) { + req := Get("http://httpbin.org/get") + b, err := req.Bytes() + if err != nil { + t.Fatal(err) + } + t.Log(b) + + s, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(s) + + if string(b) != s { + t.Fatal("request data not match") + } +} + +func TestSimplePost(t *testing.T) { + v := "smallfish" + req := Post("http://httpbin.org/post") + req.Param("username", v) + + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in post") + } +} + +//func TestPostFile(t *testing.T) { +// v := "smallfish" +// req := Post("http://httpbin.org/post") +// req.Debug(true) +// req.Param("username", v) +// req.PostFile("uploadfile", "httplib_test.go") + +// str, err := req.String() +// if err != nil { +// t.Fatal(err) +// } +// t.Log(str) + +// n := strings.Index(str, v) +// if n == -1 { +// t.Fatal(v + " not found in post") +// } +//} + +func TestSimplePut(t *testing.T) { + str, err := Put("http://httpbin.org/put").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDelete(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDeleteParam(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").Param("key", "val").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestWithCookie(t *testing.T) { + v := "smallfish" + str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + str, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in cookie") + } +} + +func TestWithBasicAuth(t *testing.T) { + str, err := Get("http://httpbin.org/basic-auth/user/passwd").SetBasicAuth("user", "passwd").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + n := strings.Index(str, "authenticated") + if n == -1 { + t.Fatal("authenticated not found in response") + } +} + +func TestWithUserAgent(t *testing.T) { + v := "beego" + str, err := Get("http://httpbin.org/headers").SetUserAgent(v).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestWithSetting(t *testing.T) { + v := "beego" + var setting BeegoHTTPSettings + setting.EnableCookie = true + setting.UserAgent = v + setting.Transport = &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 50, + IdleConnTimeout: 90 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + setting.ReadWriteTimeout = 5 * time.Second + SetDefaultSetting(setting) + + str, err := Get("http://httpbin.org/get").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestToJson(t *testing.T) { + req := Get("http://httpbin.org/ip") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) + + // httpbin will return http remote addr + type IP struct { + Origin string `json:"origin"` + } + var ip IP + err = req.ToJSON(&ip) + if err != nil { + t.Fatal(err) + } + t.Log(ip.Origin) + ips := strings.Split(ip.Origin, ",") + if len(ips) == 0 { + t.Fatal("response is not valid ip") + } + for i := range ips { + if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil { + t.Fatal("response is not valid ip") + } + } + +} + +func TestToFile(t *testing.T) { + f := "beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.Remove(f) + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestToFileDir(t *testing.T) { + f := "./files/beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll("./files") + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestHeader(t *testing.T) { + req := Get("http://httpbin.org/headers") + req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36") + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} diff --git a/pkg/log.go b/pkg/log.go new file mode 100644 index 00000000..cc4c0f81 --- /dev/null +++ b/pkg/log.go @@ -0,0 +1,127 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "strings" + + "github.com/astaxie/beego/logs" +) + +// Log levels to control the logging output. +// Deprecated: use github.com/astaxie/beego/logs instead. +const ( + LevelEmergency = iota + LevelAlert + LevelCritical + LevelError + LevelWarning + LevelNotice + LevelInformational + LevelDebug +) + +// BeeLogger references the used application logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +var BeeLogger = logs.GetBeeLogger() + +// SetLevel sets the global log level used by the simple logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLevel(l int) { + logs.SetLevel(l) +} + +// SetLogFuncCall set the CallDepth, default is 3 +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLogFuncCall(b bool) { + logs.SetLogFuncCall(b) +} + +// SetLogger sets a new logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLogger(adaptername string, config string) error { + return logs.SetLogger(adaptername, config) +} + +// Emergency logs a message at emergency level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Emergency(v ...interface{}) { + logs.Emergency(generateFmtStr(len(v)), v...) +} + +// Alert logs a message at alert level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Alert(v ...interface{}) { + logs.Alert(generateFmtStr(len(v)), v...) +} + +// Critical logs a message at critical level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Critical(v ...interface{}) { + logs.Critical(generateFmtStr(len(v)), v...) +} + +// Error logs a message at error level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Error(v ...interface{}) { + logs.Error(generateFmtStr(len(v)), v...) +} + +// Warning logs a message at warning level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Warning(v ...interface{}) { + logs.Warning(generateFmtStr(len(v)), v...) +} + +// Warn compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Warn(v ...interface{}) { + logs.Warn(generateFmtStr(len(v)), v...) +} + +// Notice logs a message at notice level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Notice(v ...interface{}) { + logs.Notice(generateFmtStr(len(v)), v...) +} + +// Informational logs a message at info level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Informational(v ...interface{}) { + logs.Informational(generateFmtStr(len(v)), v...) +} + +// Info compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Info(v ...interface{}) { + logs.Info(generateFmtStr(len(v)), v...) +} + +// Debug logs a message at debug level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Debug(v ...interface{}) { + logs.Debug(generateFmtStr(len(v)), v...) +} + +// Trace logs a message at trace level. +// compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Trace(v ...interface{}) { + logs.Trace(generateFmtStr(len(v)), v...) +} + +func generateFmtStr(n int) string { + return strings.Repeat("%v ", n) +} diff --git a/pkg/logs/README.md b/pkg/logs/README.md new file mode 100644 index 00000000..c05bcc04 --- /dev/null +++ b/pkg/logs/README.md @@ -0,0 +1,72 @@ +## logs +logs is a Go logs manager. It can use many logs adapters. The repo is inspired by `database/sql` . + + +## How to install? + + go get github.com/astaxie/beego/logs + + +## What adapters are supported? + +As of now this logs support console, file,smtp and conn. + + +## How to use it? + +First you must import it + +```golang +import ( + "github.com/astaxie/beego/logs" +) +``` + +Then init a Log (example with console adapter) + +```golang +log := logs.NewLogger(10000) +log.SetLogger("console", "") +``` + +> the first params stand for how many channel + +Use it like this: + +```golang +log.Trace("trace") +log.Info("info") +log.Warn("warning") +log.Debug("debug") +log.Critical("critical") +``` + +## File adapter + +Configure file adapter like this: + +```golang +log := NewLogger(10000) +log.SetLogger("file", `{"filename":"test.log"}`) +``` + +## Conn adapter + +Configure like this: + +```golang +log := NewLogger(1000) +log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) +log.Info("info") +``` + +## Smtp adapter + +Configure like this: + +```golang +log := NewLogger(10000) +log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`) +log.Critical("sendmail critical") +time.Sleep(time.Second * 30) +``` diff --git a/pkg/logs/accesslog.go b/pkg/logs/accesslog.go new file mode 100644 index 00000000..3ff9e20f --- /dev/null +++ b/pkg/logs/accesslog.go @@ -0,0 +1,83 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "bytes" + "strings" + "encoding/json" + "fmt" + "time" +) + +const ( + apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s" + apacheFormat = "APACHE_FORMAT" + jsonFormat = "JSON_FORMAT" +) + +// AccessLogRecord struct for holding access log data. +type AccessLogRecord struct { + RemoteAddr string `json:"remote_addr"` + RequestTime time.Time `json:"request_time"` + RequestMethod string `json:"request_method"` + Request string `json:"request"` + ServerProtocol string `json:"server_protocol"` + Host string `json:"host"` + Status int `json:"status"` + BodyBytesSent int64 `json:"body_bytes_sent"` + ElapsedTime time.Duration `json:"elapsed_time"` + HTTPReferrer string `json:"http_referrer"` + HTTPUserAgent string `json:"http_user_agent"` + RemoteUser string `json:"remote_user"` +} + +func (r *AccessLogRecord) json() ([]byte, error) { + buffer := &bytes.Buffer{} + encoder := json.NewEncoder(buffer) + disableEscapeHTML(encoder) + + err := encoder.Encode(r) + return buffer.Bytes(), err +} + +func disableEscapeHTML(i interface{}) { + if e, ok := i.(interface { + SetEscapeHTML(bool) + }); ok { + e.SetEscapeHTML(false) + } +} + +// AccessLog - Format and print access log. +func AccessLog(r *AccessLogRecord, format string) { + var msg string + switch format { + case apacheFormat: + timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05") + msg = fmt.Sprintf(apacheFormatPattern, r.RemoteAddr, timeFormatted, r.Request, r.Status, r.BodyBytesSent, + r.ElapsedTime.Seconds(), r.HTTPReferrer, r.HTTPUserAgent) + case jsonFormat: + fallthrough + default: + jsonData, err := r.json() + if err != nil { + msg = fmt.Sprintf(`{"Error": "%s"}`, err) + } else { + msg = string(jsonData) + } + } + beeLogger.writeMsg(levelLoggerImpl, strings.TrimSpace(msg)) +} diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go new file mode 100644 index 00000000..867ff4cb --- /dev/null +++ b/pkg/logs/alils/alils.go @@ -0,0 +1,186 @@ +package alils + +import ( + "encoding/json" + "strings" + "sync" + "time" + + "github.com/astaxie/beego/logs" + "github.com/gogo/protobuf/proto" +) + +const ( + // CacheSize set the flush size + CacheSize int = 64 + // Delimiter define the topic delimiter + Delimiter string = "##" +) + +// Config is the Config for Ali Log +type Config struct { + Project string `json:"project"` + Endpoint string `json:"endpoint"` + KeyID string `json:"key_id"` + KeySecret string `json:"key_secret"` + LogStore string `json:"log_store"` + Topics []string `json:"topics"` + Source string `json:"source"` + Level int `json:"level"` + FlushWhen int `json:"flush_when"` +} + +// aliLSWriter implements LoggerInterface. +// it writes messages in keep-live tcp connection. +type aliLSWriter struct { + store *LogStore + group []*LogGroup + withMap bool + groupMap map[string]*LogGroup + lock *sync.Mutex + Config +} + +// NewAliLS create a new Logger +func NewAliLS() logs.Logger { + alils := new(aliLSWriter) + alils.Level = logs.LevelTrace + return alils +} + +// Init parse config and init struct +func (c *aliLSWriter) Init(jsonConfig string) (err error) { + + json.Unmarshal([]byte(jsonConfig), c) + + if c.FlushWhen > CacheSize { + c.FlushWhen = CacheSize + } + + prj := &LogProject{ + Name: c.Project, + Endpoint: c.Endpoint, + AccessKeyID: c.KeyID, + AccessKeySecret: c.KeySecret, + } + + c.store, err = prj.GetLogStore(c.LogStore) + if err != nil { + return err + } + + // Create default Log Group + c.group = append(c.group, &LogGroup{ + Topic: proto.String(""), + Source: proto.String(c.Source), + Logs: make([]*Log, 0, c.FlushWhen), + }) + + // Create other Log Group + c.groupMap = make(map[string]*LogGroup) + for _, topic := range c.Topics { + + lg := &LogGroup{ + Topic: proto.String(topic), + Source: proto.String(c.Source), + Logs: make([]*Log, 0, c.FlushWhen), + } + + c.group = append(c.group, lg) + c.groupMap[topic] = lg + } + + if len(c.group) == 1 { + c.withMap = false + } else { + c.withMap = true + } + + c.lock = &sync.Mutex{} + + return nil +} + +// WriteMsg write message in connection. +// if connection is down, try to re-connect. +func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) { + + if level > c.Level { + return nil + } + + var topic string + var content string + var lg *LogGroup + if c.withMap { + + // Topic,LogGroup + strs := strings.SplitN(msg, Delimiter, 2) + if len(strs) == 2 { + pos := strings.LastIndex(strs[0], " ") + topic = strs[0][pos+1 : len(strs[0])] + content = strs[0][0:pos] + strs[1] + lg = c.groupMap[topic] + } + + // send to empty Topic + if lg == nil { + content = msg + lg = c.group[0] + } + } else { + content = msg + lg = c.group[0] + } + + c1 := &LogContent{ + Key: proto.String("msg"), + Value: proto.String(content), + } + + l := &Log{ + Time: proto.Uint32(uint32(when.Unix())), + Contents: []*LogContent{ + c1, + }, + } + + c.lock.Lock() + lg.Logs = append(lg.Logs, l) + c.lock.Unlock() + + if len(lg.Logs) >= c.FlushWhen { + c.flush(lg) + } + + return nil +} + +// Flush implementing method. empty. +func (c *aliLSWriter) Flush() { + + // flush all group + for _, lg := range c.group { + c.flush(lg) + } +} + +// Destroy destroy connection writer and close tcp listener. +func (c *aliLSWriter) Destroy() { +} + +func (c *aliLSWriter) flush(lg *LogGroup) { + + c.lock.Lock() + defer c.lock.Unlock() + err := c.store.PutLogs(lg) + if err != nil { + return + } + + lg.Logs = make([]*Log, 0, c.FlushWhen) +} + +func init() { + logs.Register(logs.AdapterAliLS, NewAliLS) +} diff --git a/pkg/logs/alils/config.go b/pkg/logs/alils/config.go new file mode 100755 index 00000000..e8c24448 --- /dev/null +++ b/pkg/logs/alils/config.go @@ -0,0 +1,13 @@ +package alils + +const ( + version = "0.5.0" // SDK version + signatureMethod = "hmac-sha1" // Signature method + + // OffsetNewest stands for the log head offset, i.e. the offset that will be + // assigned to the next message that will be produced to the shard. + OffsetNewest = "end" + // OffsetOldest stands for the oldest offset available on the logstore for a + // shard. + OffsetOldest = "begin" +) diff --git a/pkg/logs/alils/log.pb.go b/pkg/logs/alils/log.pb.go new file mode 100755 index 00000000..601b0d78 --- /dev/null +++ b/pkg/logs/alils/log.pb.go @@ -0,0 +1,1038 @@ +package alils + +import ( + "fmt" + "io" + "math" + + "github.com/gogo/protobuf/proto" + github_com_gogo_protobuf_proto "github.com/gogo/protobuf/proto" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +var ( + // ErrInvalidLengthLog invalid proto + ErrInvalidLengthLog = fmt.Errorf("proto: negative length found during unmarshaling") + // ErrIntOverflowLog overflow + ErrIntOverflowLog = fmt.Errorf("proto: integer overflow") +) + +// Log define the proto Log +type Log struct { + Time *uint32 `protobuf:"varint,1,req,name=Time" json:"Time,omitempty"` + Contents []*LogContent `protobuf:"bytes,2,rep,name=Contents" json:"Contents,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset the Log +func (m *Log) Reset() { *m = Log{} } + +// String return the Compact Log +func (m *Log) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*Log) ProtoMessage() {} + +// GetTime return the Log's Time +func (m *Log) GetTime() uint32 { + if m != nil && m.Time != nil { + return *m.Time + } + return 0 +} + +// GetContents return the Log's Contents +func (m *Log) GetContents() []*LogContent { + if m != nil { + return m.Contents + } + return nil +} + +// LogContent define the Log content struct +type LogContent struct { + Key *string `protobuf:"bytes,1,req,name=Key" json:"Key,omitempty"` + Value *string `protobuf:"bytes,2,req,name=Value" json:"Value,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset LogContent +func (m *LogContent) Reset() { *m = LogContent{} } + +// String return the compact text +func (m *LogContent) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogContent) ProtoMessage() {} + +// GetKey return the Key +func (m *LogContent) GetKey() string { + if m != nil && m.Key != nil { + return *m.Key + } + return "" +} + +// GetValue return the Value +func (m *LogContent) GetValue() string { + if m != nil && m.Value != nil { + return *m.Value + } + return "" +} + +// LogGroup define the logs struct +type LogGroup struct { + Logs []*Log `protobuf:"bytes,1,rep,name=Logs" json:"Logs,omitempty"` + Reserved *string `protobuf:"bytes,2,opt,name=Reserved" json:"Reserved,omitempty"` + Topic *string `protobuf:"bytes,3,opt,name=Topic" json:"Topic,omitempty"` + Source *string `protobuf:"bytes,4,opt,name=Source" json:"Source,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset LogGroup +func (m *LogGroup) Reset() { *m = LogGroup{} } + +// String return the compact text +func (m *LogGroup) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogGroup) ProtoMessage() {} + +// GetLogs return the loggroup logs +func (m *LogGroup) GetLogs() []*Log { + if m != nil { + return m.Logs + } + return nil +} + +// GetReserved return Reserved +func (m *LogGroup) GetReserved() string { + if m != nil && m.Reserved != nil { + return *m.Reserved + } + return "" +} + +// GetTopic return Topic +func (m *LogGroup) GetTopic() string { + if m != nil && m.Topic != nil { + return *m.Topic + } + return "" +} + +// GetSource return Source +func (m *LogGroup) GetSource() string { + if m != nil && m.Source != nil { + return *m.Source + } + return "" +} + +// LogGroupList define the LogGroups +type LogGroupList struct { + LogGroups []*LogGroup `protobuf:"bytes,1,rep,name=logGroups" json:"logGroups,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset LogGroupList +func (m *LogGroupList) Reset() { *m = LogGroupList{} } + +// String return compact text +func (m *LogGroupList) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogGroupList) ProtoMessage() {} + +// GetLogGroups return the LogGroups +func (m *LogGroupList) GetLogGroups() []*LogGroup { + if m != nil { + return m.LogGroups + } + return nil +} + +// Marshal the logs to byte slice +func (m *Log) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo data +func (m *Log) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.Time == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") + } + data[i] = 0x8 + i++ + i = encodeVarintLog(data, i, uint64(*m.Time)) + if len(m.Contents) > 0 { + for _, msg := range m.Contents { + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +// Marshal LogContent +func (m *LogContent) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo logcontent to data +func (m *LogContent) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.Key == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") + } + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Key))) + i += copy(data[i:], *m.Key) + + if m.Value == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") + } + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Value))) + i += copy(data[i:], *m.Value) + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +// Marshal LogGroup +func (m *LogGroup) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo LogGroup to data +func (m *LogGroup) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Logs) > 0 { + for _, msg := range m.Logs { + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.Reserved != nil { + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Reserved))) + i += copy(data[i:], *m.Reserved) + } + if m.Topic != nil { + data[i] = 0x1a + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Topic))) + i += copy(data[i:], *m.Topic) + } + if m.Source != nil { + data[i] = 0x22 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Source))) + i += copy(data[i:], *m.Source) + } + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +// Marshal LogGroupList +func (m *LogGroupList) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo LogGroupList to data +func (m *LogGroupList) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.LogGroups) > 0 { + for _, msg := range m.LogGroups { + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +func encodeFixed64Log(data []byte, offset int, v uint64) int { + data[offset] = uint8(v) + data[offset+1] = uint8(v >> 8) + data[offset+2] = uint8(v >> 16) + data[offset+3] = uint8(v >> 24) + data[offset+4] = uint8(v >> 32) + data[offset+5] = uint8(v >> 40) + data[offset+6] = uint8(v >> 48) + data[offset+7] = uint8(v >> 56) + return offset + 8 +} +func encodeFixed32Log(data []byte, offset int, v uint32) int { + data[offset] = uint8(v) + data[offset+1] = uint8(v >> 8) + data[offset+2] = uint8(v >> 16) + data[offset+3] = uint8(v >> 24) + return offset + 4 +} +func encodeVarintLog(data []byte, offset int, v uint64) int { + for v >= 1<<7 { + data[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + data[offset] = uint8(v) + return offset + 1 +} + +// Size return the log's size +func (m *Log) Size() (n int) { + var l int + _ = l + if m.Time != nil { + n += 1 + sovLog(uint64(*m.Time)) + } + if len(m.Contents) > 0 { + for _, e := range m.Contents { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +// Size return LogContent size based on Key and Value +func (m *LogContent) Size() (n int) { + var l int + _ = l + if m.Key != nil { + l = len(*m.Key) + n += 1 + l + sovLog(uint64(l)) + } + if m.Value != nil { + l = len(*m.Value) + n += 1 + l + sovLog(uint64(l)) + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +// Size return LogGroup size based on Logs +func (m *LogGroup) Size() (n int) { + var l int + _ = l + if len(m.Logs) > 0 { + for _, e := range m.Logs { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.Reserved != nil { + l = len(*m.Reserved) + n += 1 + l + sovLog(uint64(l)) + } + if m.Topic != nil { + l = len(*m.Topic) + n += 1 + l + sovLog(uint64(l)) + } + if m.Source != nil { + l = len(*m.Source) + n += 1 + l + sovLog(uint64(l)) + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +// Size return LogGroupList size +func (m *LogGroupList) Size() (n int) { + var l int + _ = l + if len(m.LogGroups) > 0 { + for _, e := range m.LogGroups { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +func sovLog(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozLog(x uint64) (n int) { + return sovLog((x << 1) ^ (x >> 63)) +} + +// Unmarshal data to log +func (m *Log) Unmarshal(data []byte) error { + var hasFields [1]uint64 + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Log: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Log: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType) + } + var v uint32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + v |= (uint32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Time = &v + hasFields[0] |= uint64(0x00000001) + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Contents", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Contents = append(m.Contents, &LogContent{}) + if err := m.Contents[len(m.Contents)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + if hasFields[0]&uint64(0x00000001) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +// Unmarshal data to LogContent +func (m *LogContent) Unmarshal(data []byte) error { + var hasFields [1]uint64 + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Content: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Content: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Key", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Key = &s + iNdEx = postIndex + hasFields[0] |= uint64(0x00000001) + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Value", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Value = &s + iNdEx = postIndex + hasFields[0] |= uint64(0x00000002) + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + if hasFields[0]&uint64(0x00000001) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") + } + if hasFields[0]&uint64(0x00000002) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +// Unmarshal data to LogGroup +func (m *LogGroup) Unmarshal(data []byte) error { + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: LogGroup: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: LogGroup: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Logs", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Logs = append(m.Logs, &Log{}) + if err := m.Logs[len(m.Logs)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Reserved", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Reserved = &s + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Topic", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Topic = &s + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Source", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Source = &s + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +// Unmarshal data to LogGroupList +func (m *LogGroupList) Unmarshal(data []byte) error { + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: LogGroupList: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: LogGroupList: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field LogGroups", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.LogGroups = append(m.LogGroups, &LogGroup{}) + if err := m.LogGroups[len(m.LogGroups)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +func skipLog(data []byte) (n int, err error) { + l := len(data) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if data[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLengthLog + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipLog(data[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} diff --git a/pkg/logs/alils/log_config.go b/pkg/logs/alils/log_config.go new file mode 100755 index 00000000..e8564efb --- /dev/null +++ b/pkg/logs/alils/log_config.go @@ -0,0 +1,42 @@ +package alils + +// InputDetail define log detail +type InputDetail struct { + LogType string `json:"logType"` + LogPath string `json:"logPath"` + FilePattern string `json:"filePattern"` + LocalStorage bool `json:"localStorage"` + TimeFormat string `json:"timeFormat"` + LogBeginRegex string `json:"logBeginRegex"` + Regex string `json:"regex"` + Keys []string `json:"key"` + FilterKeys []string `json:"filterKey"` + FilterRegex []string `json:"filterRegex"` + TopicFormat string `json:"topicFormat"` +} + +// OutputDetail define the output detail +type OutputDetail struct { + Endpoint string `json:"endpoint"` + LogStoreName string `json:"logstoreName"` +} + +// LogConfig define Log Config +type LogConfig struct { + Name string `json:"configName"` + InputType string `json:"inputType"` + InputDetail InputDetail `json:"inputDetail"` + OutputType string `json:"outputType"` + OutputDetail OutputDetail `json:"outputDetail"` + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +// GetAppliedMachineGroup returns applied machine group of this config. +func (c *LogConfig) GetAppliedMachineGroup(confName string) (groupNames []string, err error) { + groupNames, err = c.project.GetAppliedMachineGroups(c.Name) + return +} diff --git a/pkg/logs/alils/log_project.go b/pkg/logs/alils/log_project.go new file mode 100755 index 00000000..59db8cbf --- /dev/null +++ b/pkg/logs/alils/log_project.go @@ -0,0 +1,819 @@ +/* +Package alils implements the SDK(v0.5.0) of Simple Log Service(abbr. SLS). + +For more description about SLS, please read this article: +http://gitlab.alibaba-inc.com/sls/doc. +*/ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" +) + +// Error message in SLS HTTP response. +type errorMessage struct { + Code string `json:"errorCode"` + Message string `json:"errorMessage"` +} + +// LogProject Define the Ali Project detail +type LogProject struct { + Name string // Project name + Endpoint string // IP or hostname of SLS endpoint + AccessKeyID string + AccessKeySecret string +} + +// NewLogProject creates a new SLS project. +func NewLogProject(name, endpoint, AccessKeyID, accessKeySecret string) (p *LogProject, err error) { + p = &LogProject{ + Name: name, + Endpoint: endpoint, + AccessKeyID: AccessKeyID, + AccessKeySecret: accessKeySecret, + } + return p, nil +} + +// ListLogStore returns all logstore names of project p. +func (p *LogProject) ListLogStore() (storeNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores") + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Count int + LogStores []string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + storeNames = body.LogStores + + return +} + +// GetLogStore returns logstore according by logstore name. +func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/logstores/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + s = &LogStore{} + err = json.Unmarshal(buf, s) + if err != nil { + return + } + s.project = p + return +} + +// CreateLogStore creates a new logstore in SLS, +// where name is logstore name, +// and ttl is time-to-live(in day) of logs, +// and shardCnt is the number of shards. +func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) { + + type Body struct { + Name string `json:"logstoreName"` + TTL int `json:"ttl"` + ShardCount int `json:"shardCount"` + } + + store := &Body{ + Name: name, + TTL: ttl, + ShardCount: shardCnt, + } + + body, err := json.Marshal(store) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/logstores", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to create logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteLogStore deletes a logstore according by logstore name. +func (p *LogProject) DeleteLogStore(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/logstores/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// UpdateLogStore updates a logstore according by logstore name, +// obviously we can't modify the logstore name itself. +func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) { + + type Body struct { + Name string `json:"logstoreName"` + TTL int `json:"ttl"` + ShardCount int `json:"shardCount"` + } + + store := &Body{ + Name: name, + TTL: ttl, + ShardCount: shardCnt, + } + + body, err := json.Marshal(store) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/logstores", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// ListMachineGroup returns machine group name list and the total number of machine groups. +// The offset starts from 0 and the size is the max number of machine groups could be returned. +func (p *LogProject) ListMachineGroup(offset, size int) (m []string, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + if size <= 0 { + size = 500 + } + + uri := fmt.Sprintf("/machinegroups?offset=%v&size=%v", offset, size) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + MachineGroups []string + Count int + Total int + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + m = body.MachineGroups + total = body.Total + + return +} + +// GetMachineGroup retruns machine group according by machine group name. +func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/machinegroups/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get machine group:%v", name) + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + m = &MachineGroup{} + err = json.Unmarshal(buf, m) + if err != nil { + return + } + m.project = p + return +} + +// CreateMachineGroup creates a new machine group in SLS. +func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) { + + body, err := json.Marshal(m) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/machinegroups", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to create machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// UpdateMachineGroup updates a machine group. +func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) { + + body, err := json.Marshal(m) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/machinegroups/"+m.Name, h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteMachineGroup deletes machine group according machine group name. +func (p *LogProject) DeleteMachineGroup(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/machinegroups/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// ListConfig returns config names list and the total number of configs. +// The offset starts from 0 and the size is the max number of configs could be returned. +func (p *LogProject) ListConfig(offset, size int) (cfgNames []string, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + if size <= 0 { + size = 100 + } + + uri := fmt.Sprintf("/configs?offset=%v&size=%v", offset, size) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Total int + Configs []string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + cfgNames = body.Configs + total = body.Total + return +} + +// GetConfig returns config according by config name. +func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/configs/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + c = &LogConfig{} + err = json.Unmarshal(buf, c) + if err != nil { + return + } + c.project = p + return +} + +// UpdateConfig updates a config. +func (p *LogProject) UpdateConfig(c *LogConfig) (err error) { + + body, err := json.Marshal(c) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/configs/"+c.Name, h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// CreateConfig creates a new config in SLS. +func (p *LogProject) CreateConfig(c *LogConfig) (err error) { + + body, err := json.Marshal(c) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/configs", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteConfig deletes a config according by config name. +func (p *LogProject) DeleteConfig(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/configs/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// GetAppliedMachineGroups returns applied machine group names list according config name. +func (p *LogProject) GetAppliedMachineGroups(confName string) (groupNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/configs/%v/machinegroups", confName) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get applied machine groups") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Count int + Machinegroups []string + } + + body := &Body{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + groupNames = body.Machinegroups + return +} + +// GetAppliedConfigs returns applied config names list according machine group name groupName. +func (p *LogProject) GetAppliedConfigs(groupName string) (confNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs", groupName) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to applied configs") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Cfg struct { + Count int `json:"count"` + Configs []string `json:"configs"` + } + + body := &Cfg{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + confNames = body.Configs + return +} + +// ApplyConfigToMachineGroup applies config to machine group. +func (p *LogProject) ApplyConfigToMachineGroup(confName, groupName string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName) + r, err := request(p, "PUT", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to apply config to machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// RemoveConfigFromMachineGroup removes config from machine group. +func (p *LogProject) RemoveConfigFromMachineGroup(confName, groupName string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName) + r, err := request(p, "DELETE", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to remove config from machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} diff --git a/pkg/logs/alils/log_store.go b/pkg/logs/alils/log_store.go new file mode 100755 index 00000000..fa502736 --- /dev/null +++ b/pkg/logs/alils/log_store.go @@ -0,0 +1,271 @@ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" + "strconv" + + lz4 "github.com/cloudflare/golz4" + "github.com/gogo/protobuf/proto" +) + +// LogStore Store the logs +type LogStore struct { + Name string `json:"logstoreName"` + TTL int + ShardCount int + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +// Shard define the Log Shard +type Shard struct { + ShardID int `json:"shardID"` +} + +// ListShards returns shard id list of this logstore. +func (s *LogStore) ListShards() (shardIDs []int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores/%v/shards", s.Name) + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + var shards []*Shard + err = json.Unmarshal(buf, &shards) + if err != nil { + return + } + + for _, v := range shards { + shardIDs = append(shardIDs, v.ShardID) + } + return +} + +// PutLogs put logs into logstore. +// The callers should transform user logs into LogGroup. +func (s *LogStore) PutLogs(lg *LogGroup) (err error) { + body, err := proto.Marshal(lg) + if err != nil { + return + } + + // Compresse body with lz4 + out := make([]byte, lz4.CompressBound(body)) + n, err := lz4.Compress(body, out) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-compresstype": "lz4", + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/x-protobuf", + } + + uri := fmt.Sprintf("/logstores/%v", s.Name) + r, err := request(s.project, "POST", uri, h, out[:n]) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to put logs") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// GetCursor gets log cursor of one shard specified by shardID. +// The from can be in three form: a) unix timestamp in seccond, b) "begin", c) "end". +// For more detail please read: http://gitlab.alibaba-inc.com/sls/doc/blob/master/api/shard.md#logstore +func (s *LogStore) GetCursor(shardID int, from string) (cursor string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores/%v/shards/%v?type=cursor&from=%v", + s.Name, shardID, from) + + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get cursor") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Cursor string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + cursor = body.Cursor + return +} + +// GetLogsBytes gets logs binary data from shard specified by shardID according cursor. +// The logGroupMaxCount is the max number of logGroup could be returned. +// The nextCursor is the next curosr can be used to read logs at next time. +func (s *LogStore) GetLogsBytes(shardID int, cursor string, + logGroupMaxCount int) (out []byte, nextCursor string, err error) { + + h := map[string]string{ + "x-sls-bodyrawsize": "0", + "Accept": "application/x-protobuf", + "Accept-Encoding": "lz4", + } + + uri := fmt.Sprintf("/logstores/%v/shards/%v?type=logs&cursor=%v&count=%v", + s.Name, shardID, cursor, logGroupMaxCount) + + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get cursor") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + v, ok := r.Header["X-Sls-Compresstype"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-compresstype' header") + return + } + if v[0] != "lz4" { + err = fmt.Errorf("unexpected compress type:%v", v[0]) + return + } + + v, ok = r.Header["X-Sls-Cursor"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-cursor' header") + return + } + nextCursor = v[0] + + v, ok = r.Header["X-Sls-Bodyrawsize"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header") + return + } + bodyRawSize, err := strconv.Atoi(v[0]) + if err != nil { + return + } + + out = make([]byte, bodyRawSize) + err = lz4.Uncompress(buf, out) + if err != nil { + return + } + + return +} + +// LogsBytesDecode decodes logs binary data retruned by GetLogsBytes API +func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) { + + gl = &LogGroupList{} + err = proto.Unmarshal(data, gl) + if err != nil { + return + } + + return +} + +// GetLogs gets logs from shard specified by shardID according cursor. +// The logGroupMaxCount is the max number of logGroup could be returned. +// The nextCursor is the next curosr can be used to read logs at next time. +func (s *LogStore) GetLogs(shardID int, cursor string, + logGroupMaxCount int) (gl *LogGroupList, nextCursor string, err error) { + + out, nextCursor, err := s.GetLogsBytes(shardID, cursor, logGroupMaxCount) + if err != nil { + return + } + + gl, err = LogsBytesDecode(out) + if err != nil { + return + } + + return +} diff --git a/pkg/logs/alils/machine_group.go b/pkg/logs/alils/machine_group.go new file mode 100755 index 00000000..b6c69a14 --- /dev/null +++ b/pkg/logs/alils/machine_group.go @@ -0,0 +1,91 @@ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" +) + +// MachineGroupAttribute define the Attribute +type MachineGroupAttribute struct { + ExternalName string `json:"externalName"` + TopicName string `json:"groupTopic"` +} + +// MachineGroup define the machine Group +type MachineGroup struct { + Name string `json:"groupName"` + Type string `json:"groupType"` + MachineIDType string `json:"machineIdentifyType"` + MachineIDList []string `json:"machineList"` + + Attribute MachineGroupAttribute `json:"groupAttribute"` + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +// Machine define the Machine +type Machine struct { + IP string + UniqueID string `json:"machine-uniqueid"` + UserdefinedID string `json:"userdefined-id"` +} + +// MachineList define the Machine List +type MachineList struct { + Total int + Machines []*Machine +} + +// ListMachines returns machine list of this machine group. +func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/machines", m.Name) + r, err := request(m.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to remove config from machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + body := &MachineList{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + ms = body.Machines + total = body.Total + + return +} + +// GetAppliedConfigs returns applied configs of this machine group. +func (m *MachineGroup) GetAppliedConfigs() (confNames []string, err error) { + confNames, err = m.project.GetAppliedConfigs(m.Name) + return +} diff --git a/pkg/logs/alils/request.go b/pkg/logs/alils/request.go new file mode 100755 index 00000000..50d9c43c --- /dev/null +++ b/pkg/logs/alils/request.go @@ -0,0 +1,62 @@ +package alils + +import ( + "bytes" + "crypto/md5" + "fmt" + "net/http" +) + +// request sends a request to SLS. +func request(project *LogProject, method, uri string, headers map[string]string, + body []byte) (resp *http.Response, err error) { + + // The caller should provide 'x-sls-bodyrawsize' header + if _, ok := headers["x-sls-bodyrawsize"]; !ok { + err = fmt.Errorf("Can't find 'x-sls-bodyrawsize' header") + return + } + + // SLS public request headers + headers["Host"] = project.Name + "." + project.Endpoint + headers["Date"] = nowRFC1123() + headers["x-sls-apiversion"] = version + headers["x-sls-signaturemethod"] = signatureMethod + if body != nil { + bodyMD5 := fmt.Sprintf("%X", md5.Sum(body)) + headers["Content-MD5"] = bodyMD5 + + if _, ok := headers["Content-Type"]; !ok { + err = fmt.Errorf("Can't find 'Content-Type' header") + return + } + } + + // Calc Authorization + // Authorization = "SLS :" + digest, err := signature(project, method, uri, headers) + if err != nil { + return + } + auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyID, digest) + headers["Authorization"] = auth + + // Initialize http request + reader := bytes.NewReader(body) + urlStr := fmt.Sprintf("http://%v.%v%v", project.Name, project.Endpoint, uri) + req, err := http.NewRequest(method, urlStr, reader) + if err != nil { + return + } + for k, v := range headers { + req.Header.Add(k, v) + } + + // Get ready to do request + resp, err = http.DefaultClient.Do(req) + if err != nil { + return + } + + return +} diff --git a/pkg/logs/alils/signature.go b/pkg/logs/alils/signature.go new file mode 100755 index 00000000..2d611307 --- /dev/null +++ b/pkg/logs/alils/signature.go @@ -0,0 +1,111 @@ +package alils + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "fmt" + "net/url" + "sort" + "strings" + "time" +) + +// GMT location +var gmtLoc = time.FixedZone("GMT", 0) + +// NowRFC1123 returns now time in RFC1123 format with GMT timezone, +// eg. "Mon, 02 Jan 2006 15:04:05 GMT". +func nowRFC1123() string { + return time.Now().In(gmtLoc).Format(time.RFC1123) +} + +// signature calculates a request's signature digest. +func signature(project *LogProject, method, uri string, + headers map[string]string) (digest string, err error) { + var contentMD5, contentType, date, canoHeaders, canoResource string + var slsHeaderKeys sort.StringSlice + + // SignString = VERB + "\n" + // + CONTENT-MD5 + "\n" + // + CONTENT-TYPE + "\n" + // + DATE + "\n" + // + CanonicalizedSLSHeaders + "\n" + // + CanonicalizedResource + + if val, ok := headers["Content-MD5"]; ok { + contentMD5 = val + } + + if val, ok := headers["Content-Type"]; ok { + contentType = val + } + + date, ok := headers["Date"] + if !ok { + err = fmt.Errorf("Can't find 'Date' header") + return + } + + // Calc CanonicalizedSLSHeaders + slsHeaders := make(map[string]string, len(headers)) + for k, v := range headers { + l := strings.TrimSpace(strings.ToLower(k)) + if strings.HasPrefix(l, "x-sls-") { + slsHeaders[l] = strings.TrimSpace(v) + slsHeaderKeys = append(slsHeaderKeys, l) + } + } + + sort.Sort(slsHeaderKeys) + for i, k := range slsHeaderKeys { + canoHeaders += k + ":" + slsHeaders[k] + if i+1 < len(slsHeaderKeys) { + canoHeaders += "\n" + } + } + + // Calc CanonicalizedResource + u, err := url.Parse(uri) + if err != nil { + return + } + + canoResource += url.QueryEscape(u.Path) + if u.RawQuery != "" { + var keys sort.StringSlice + + vals := u.Query() + for k := range vals { + keys = append(keys, k) + } + + sort.Sort(keys) + canoResource += "?" + for i, k := range keys { + if i > 0 { + canoResource += "&" + } + + for _, v := range vals[k] { + canoResource += k + "=" + v + } + } + } + + signStr := method + "\n" + + contentMD5 + "\n" + + contentType + "\n" + + date + "\n" + + canoHeaders + "\n" + + canoResource + + // Signature = base64(hmac-sha1(UTF8-Encoding-Of(SignString),AccessKeySecret)) + mac := hmac.New(sha1.New, []byte(project.AccessKeySecret)) + _, err = mac.Write([]byte(signStr)) + if err != nil { + return + } + digest = base64.StdEncoding.EncodeToString(mac.Sum(nil)) + return +} diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go new file mode 100644 index 00000000..74c458ab --- /dev/null +++ b/pkg/logs/conn.go @@ -0,0 +1,119 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "encoding/json" + "io" + "net" + "time" +) + +// connWriter implements LoggerInterface. +// it writes messages in keep-live tcp connection. +type connWriter struct { + lg *logWriter + innerWriter io.WriteCloser + ReconnectOnMsg bool `json:"reconnectOnMsg"` + Reconnect bool `json:"reconnect"` + Net string `json:"net"` + Addr string `json:"addr"` + Level int `json:"level"` +} + +// NewConn create new ConnWrite returning as LoggerInterface. +func NewConn() Logger { + conn := new(connWriter) + conn.Level = LevelTrace + return conn +} + +// Init init connection writer with json config. +// json config only need key "level". +func (c *connWriter) Init(jsonConfig string) error { + return json.Unmarshal([]byte(jsonConfig), c) +} + +// WriteMsg write message in connection. +// if connection is down, try to re-connect. +func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > c.Level { + return nil + } + if c.needToConnectOnMsg() { + err := c.connect() + if err != nil { + return err + } + } + + if c.ReconnectOnMsg { + defer c.innerWriter.Close() + } + + _, err := c.lg.writeln(when, msg) + if err != nil { + return err + } + return nil +} + +// Flush implementing method. empty. +func (c *connWriter) Flush() { + +} + +// Destroy destroy connection writer and close tcp listener. +func (c *connWriter) Destroy() { + if c.innerWriter != nil { + c.innerWriter.Close() + } +} + +func (c *connWriter) connect() error { + if c.innerWriter != nil { + c.innerWriter.Close() + c.innerWriter = nil + } + + conn, err := net.Dial(c.Net, c.Addr) + if err != nil { + return err + } + + if tcpConn, ok := conn.(*net.TCPConn); ok { + tcpConn.SetKeepAlive(true) + } + + c.innerWriter = conn + c.lg = newLogWriter(conn) + return nil +} + +func (c *connWriter) needToConnectOnMsg() bool { + if c.Reconnect { + return true + } + + if c.innerWriter == nil { + return true + } + + return c.ReconnectOnMsg +} + +func init() { + Register(AdapterConn, NewConn) +} diff --git a/pkg/logs/conn_test.go b/pkg/logs/conn_test.go new file mode 100644 index 00000000..bb377d41 --- /dev/null +++ b/pkg/logs/conn_test.go @@ -0,0 +1,79 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "net" + "os" + "testing" +) + +// ConnTCPListener takes a TCP listener and accepts n TCP connections +// Returns connections using connChan +func connTCPListener(t *testing.T, n int, ln net.Listener, connChan chan<- net.Conn) { + + // Listen and accept n incoming connections + for i := 0; i < n; i++ { + conn, err := ln.Accept() + if err != nil { + t.Log("Error accepting connection: ", err.Error()) + os.Exit(1) + } + + // Send accepted connection to channel + connChan <- conn + } + ln.Close() + close(connChan) +} + +func TestConn(t *testing.T) { + log := NewLogger(1000) + log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) + log.Informational("informational") +} + +func TestReconnect(t *testing.T) { + // Setup connection listener + newConns := make(chan net.Conn) + connNum := 2 + ln, err := net.Listen("tcp", ":6002") + if err != nil { + t.Log("Error listening:", err.Error()) + os.Exit(1) + } + go connTCPListener(t, connNum, ln, newConns) + + // Setup logger + log := NewLogger(1000) + log.SetPrefix("test") + log.SetLogger(AdapterConn, `{"net":"tcp","reconnect":true,"level":6,"addr":":6002"}`) + log.Informational("informational 1") + + // Refuse first connection + first := <-newConns + first.Close() + + // Send another log after conn closed + log.Informational("informational 2") + + // Check if there was a second connection attempt + select { + case second := <-newConns: + second.Close() + default: + t.Error("Did not reconnect") + } +} diff --git a/pkg/logs/console.go b/pkg/logs/console.go new file mode 100644 index 00000000..3dcaee1d --- /dev/null +++ b/pkg/logs/console.go @@ -0,0 +1,99 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "encoding/json" + "os" + "strings" + "time" + + "github.com/shiena/ansicolor" +) + +// brush is a color join function +type brush func(string) string + +// newBrush return a fix color Brush +func newBrush(color string) brush { + pre := "\033[" + reset := "\033[0m" + return func(text string) string { + return pre + color + "m" + text + reset + } +} + +var colors = []brush{ + newBrush("1;37"), // Emergency white + newBrush("1;36"), // Alert cyan + newBrush("1;35"), // Critical magenta + newBrush("1;31"), // Error red + newBrush("1;33"), // Warning yellow + newBrush("1;32"), // Notice green + newBrush("1;34"), // Informational blue + newBrush("1;44"), // Debug Background blue +} + +// consoleWriter implements LoggerInterface and writes messages to terminal. +type consoleWriter struct { + lg *logWriter + Level int `json:"level"` + Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color +} + +// NewConsole create ConsoleWriter returning as LoggerInterface. +func NewConsole() Logger { + cw := &consoleWriter{ + lg: newLogWriter(ansicolor.NewAnsiColorWriter(os.Stdout)), + Level: LevelDebug, + Colorful: true, + } + return cw +} + +// Init init console logger. +// jsonConfig like '{"level":LevelTrace}'. +func (c *consoleWriter) Init(jsonConfig string) error { + if len(jsonConfig) == 0 { + return nil + } + return json.Unmarshal([]byte(jsonConfig), c) +} + +// WriteMsg write message in console. +func (c *consoleWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > c.Level { + return nil + } + if c.Colorful { + msg = strings.Replace(msg, levelPrefix[level], colors[level](levelPrefix[level]), 1) + } + c.lg.writeln(when, msg) + return nil +} + +// Destroy implementing method. empty. +func (c *consoleWriter) Destroy() { + +} + +// Flush implementing method. empty. +func (c *consoleWriter) Flush() { + +} + +func init() { + Register(AdapterConsole, NewConsole) +} diff --git a/pkg/logs/console_test.go b/pkg/logs/console_test.go new file mode 100644 index 00000000..4bc45f57 --- /dev/null +++ b/pkg/logs/console_test.go @@ -0,0 +1,64 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + "time" +) + +// Try each log level in decreasing order of priority. +func testConsoleCalls(bl *BeeLogger) { + bl.Emergency("emergency") + bl.Alert("alert") + bl.Critical("critical") + bl.Error("error") + bl.Warning("warning") + bl.Notice("notice") + bl.Informational("informational") + bl.Debug("debug") +} + +// Test console logging by visually comparing the lines being output with and +// without a log level specification. +func TestConsole(t *testing.T) { + log1 := NewLogger(10000) + log1.EnableFuncCallDepth(true) + log1.SetLogger("console", "") + testConsoleCalls(log1) + + log2 := NewLogger(100) + log2.SetLogger("console", `{"level":3}`) + testConsoleCalls(log2) +} + +// Test console without color +func TestConsoleNoColor(t *testing.T) { + log := NewLogger(100) + log.SetLogger("console", `{"color":false}`) + testConsoleCalls(log) +} + +// Test console async +func TestConsoleAsync(t *testing.T) { + log := NewLogger(100) + log.SetLogger("console") + log.Async() + //log.Close() + testConsoleCalls(log) + for len(log.msgChan) != 0 { + time.Sleep(1 * time.Millisecond) + } +} diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go new file mode 100644 index 00000000..2b7b1710 --- /dev/null +++ b/pkg/logs/es/es.go @@ -0,0 +1,102 @@ +package es + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/elastic/go-elasticsearch/v6" + "github.com/elastic/go-elasticsearch/v6/esapi" + + "github.com/astaxie/beego/logs" +) + +// NewES return a LoggerInterface +func NewES() logs.Logger { + cw := &esLogger{ + Level: logs.LevelDebug, + } + return cw +} + +// esLogger will log msg into ES +// before you using this implementation, +// please import this package +// usually means that you can import this package in your main package +// for example, anonymous: +// import _ "github.com/astaxie/beego/logs/es" +type esLogger struct { + *elasticsearch.Client + DSN string `json:"dsn"` + Level int `json:"level"` +} + +// {"dsn":"http://localhost:9200/","level":1} +func (el *esLogger) Init(jsonconfig string) error { + err := json.Unmarshal([]byte(jsonconfig), el) + if err != nil { + return err + } + if el.DSN == "" { + return errors.New("empty dsn") + } else if u, err := url.Parse(el.DSN); err != nil { + return err + } else if u.Path == "" { + return errors.New("missing prefix") + } else { + conn, err := elasticsearch.NewClient(elasticsearch.Config{ + Addresses: []string{el.DSN}, + }) + if err != nil { + return err + } + el.Client = conn + } + return nil +} + +// WriteMsg will write the msg and level into es +func (el *esLogger) WriteMsg(when time.Time, msg string, level int) error { + if level > el.Level { + return nil + } + + idx := LogDocument{ + Timestamp: when.Format(time.RFC3339), + Msg: msg, + } + + body, err := json.Marshal(idx) + if err != nil { + return err + } + req := esapi.IndexRequest{ + Index: fmt.Sprintf("%04d.%02d.%02d", when.Year(), when.Month(), when.Day()), + DocumentType: "logs", + Body: strings.NewReader(string(body)), + } + _, err = req.Do(context.Background(), el.Client) + return err +} + +// Destroy is a empty method +func (el *esLogger) Destroy() { +} + +// Flush is a empty method +func (el *esLogger) Flush() { + +} + +type LogDocument struct { + Timestamp string `json:"timestamp"` + Msg string `json:"msg"` +} + +func init() { + logs.Register(logs.AdapterEs, NewES) +} diff --git a/pkg/logs/file.go b/pkg/logs/file.go new file mode 100644 index 00000000..222db989 --- /dev/null +++ b/pkg/logs/file.go @@ -0,0 +1,409 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "sync" + "time" +) + +// fileLogWriter implements LoggerInterface. +// It writes messages by lines limit, file size limit, or time frequency. +type fileLogWriter struct { + sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize + // The opened file + Filename string `json:"filename"` + fileWriter *os.File + + // Rotate at line + MaxLines int `json:"maxlines"` + maxLinesCurLines int + + MaxFiles int `json:"maxfiles"` + MaxFilesCurFiles int + + // Rotate at size + MaxSize int `json:"maxsize"` + maxSizeCurSize int + + // Rotate daily + Daily bool `json:"daily"` + MaxDays int64 `json:"maxdays"` + dailyOpenDate int + dailyOpenTime time.Time + + // Rotate hourly + Hourly bool `json:"hourly"` + MaxHours int64 `json:"maxhours"` + hourlyOpenDate int + hourlyOpenTime time.Time + + Rotate bool `json:"rotate"` + + Level int `json:"level"` + + Perm string `json:"perm"` + + RotatePerm string `json:"rotateperm"` + + fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix +} + +// newFileWriter create a FileLogWriter returning as LoggerInterface. +func newFileWriter() Logger { + w := &fileLogWriter{ + Daily: true, + MaxDays: 7, + Hourly: false, + MaxHours: 168, + Rotate: true, + RotatePerm: "0440", + Level: LevelTrace, + Perm: "0660", + MaxLines: 10000000, + MaxFiles: 999, + MaxSize: 1 << 28, + } + return w +} + +// Init file logger with json config. +// jsonConfig like: +// { +// "filename":"logs/beego.log", +// "maxLines":10000, +// "maxsize":1024, +// "daily":true, +// "maxDays":15, +// "rotate":true, +// "perm":"0600" +// } +func (w *fileLogWriter) Init(jsonConfig string) error { + err := json.Unmarshal([]byte(jsonConfig), w) + if err != nil { + return err + } + if len(w.Filename) == 0 { + return errors.New("jsonconfig must have filename") + } + w.suffix = filepath.Ext(w.Filename) + w.fileNameOnly = strings.TrimSuffix(w.Filename, w.suffix) + if w.suffix == "" { + w.suffix = ".log" + } + err = w.startLogger() + return err +} + +// start file logger. create log file and set to locker-inside file writer. +func (w *fileLogWriter) startLogger() error { + file, err := w.createLogFile() + if err != nil { + return err + } + if w.fileWriter != nil { + w.fileWriter.Close() + } + w.fileWriter = file + return w.initFd() +} + +func (w *fileLogWriter) needRotateDaily(size int, day int) bool { + return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || + (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || + (w.Daily && day != w.dailyOpenDate) +} + +func (w *fileLogWriter) needRotateHourly(size int, hour int) bool { + return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || + (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || + (w.Hourly && hour != w.hourlyOpenDate) + +} + +// WriteMsg write logger message into file. +func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > w.Level { + return nil + } + hd, d, h := formatTimeHeader(when) + msg = string(hd) + msg + "\n" + if w.Rotate { + w.RLock() + if w.needRotateHourly(len(msg), h) { + w.RUnlock() + w.Lock() + if w.needRotateHourly(len(msg), h) { + if err := w.doRotate(when); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() + } else if w.needRotateDaily(len(msg), d) { + w.RUnlock() + w.Lock() + if w.needRotateDaily(len(msg), d) { + if err := w.doRotate(when); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() + } else { + w.RUnlock() + } + } + + w.Lock() + _, err := w.fileWriter.Write([]byte(msg)) + if err == nil { + w.maxLinesCurLines++ + w.maxSizeCurSize += len(msg) + } + w.Unlock() + return err +} + +func (w *fileLogWriter) createLogFile() (*os.File, error) { + // Open the log file + perm, err := strconv.ParseInt(w.Perm, 8, 64) + if err != nil { + return nil, err + } + + filepath := path.Dir(w.Filename) + os.MkdirAll(filepath, os.FileMode(perm)) + + fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm)) + if err == nil { + // Make sure file perm is user set perm cause of `os.OpenFile` will obey umask + os.Chmod(w.Filename, os.FileMode(perm)) + } + return fd, err +} + +func (w *fileLogWriter) initFd() error { + fd := w.fileWriter + fInfo, err := fd.Stat() + if err != nil { + return fmt.Errorf("get stat err: %s", err) + } + w.maxSizeCurSize = int(fInfo.Size()) + w.dailyOpenTime = time.Now() + w.dailyOpenDate = w.dailyOpenTime.Day() + w.hourlyOpenTime = time.Now() + w.hourlyOpenDate = w.hourlyOpenTime.Hour() + w.maxLinesCurLines = 0 + if w.Hourly { + go w.hourlyRotate(w.hourlyOpenTime) + } else if w.Daily { + go w.dailyRotate(w.dailyOpenTime) + } + if fInfo.Size() > 0 && w.MaxLines > 0 { + count, err := w.lines() + if err != nil { + return err + } + w.maxLinesCurLines = count + } + return nil +} + +func (w *fileLogWriter) dailyRotate(openTime time.Time) { + y, m, d := openTime.Add(24 * time.Hour).Date() + nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location()) + tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) + <-tm.C + w.Lock() + if w.needRotateDaily(0, time.Now().Day()) { + if err := w.doRotate(time.Now()); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() +} + +func (w *fileLogWriter) hourlyRotate(openTime time.Time) { + y, m, d := openTime.Add(1 * time.Hour).Date() + h, _, _ := openTime.Add(1 * time.Hour).Clock() + nextHour := time.Date(y, m, d, h, 0, 0, 0, openTime.Location()) + tm := time.NewTimer(time.Duration(nextHour.UnixNano() - openTime.UnixNano() + 100)) + <-tm.C + w.Lock() + if w.needRotateHourly(0, time.Now().Hour()) { + if err := w.doRotate(time.Now()); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() +} + +func (w *fileLogWriter) lines() (int, error) { + fd, err := os.Open(w.Filename) + if err != nil { + return 0, err + } + defer fd.Close() + + buf := make([]byte, 32768) // 32k + count := 0 + lineSep := []byte{'\n'} + + for { + c, err := fd.Read(buf) + if err != nil && err != io.EOF { + return count, err + } + + count += bytes.Count(buf[:c], lineSep) + + if err == io.EOF { + break + } + } + + return count, nil +} + +// DoRotate means it need to write file in new file. +// new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size) +func (w *fileLogWriter) doRotate(logTime time.Time) error { + // file exists + // Find the next available number + num := w.MaxFilesCurFiles + 1 + fName := "" + format := "" + var openTime time.Time + rotatePerm, err := strconv.ParseInt(w.RotatePerm, 8, 64) + if err != nil { + return err + } + + _, err = os.Lstat(w.Filename) + if err != nil { + //even if the file is not exist or other ,we should RESTART the logger + goto RESTART_LOGGER + } + + if w.Hourly { + format = "2006010215" + openTime = w.hourlyOpenTime + } else if w.Daily { + format = "2006-01-02" + openTime = w.dailyOpenTime + } + + // only when one of them be setted, then the file would be splited + if w.MaxLines > 0 || w.MaxSize > 0 { + for ; err == nil && num <= w.MaxFiles; num++ { + fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format(format), num, w.suffix) + _, err = os.Lstat(fName) + } + } else { + fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", openTime.Format(format), num, w.suffix) + _, err = os.Lstat(fName) + w.MaxFilesCurFiles = num + } + + // return error if the last file checked still existed + if err == nil { + return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename) + } + + // close fileWriter before rename + w.fileWriter.Close() + + // Rename the file to its new found name + // even if occurs error,we MUST guarantee to restart new logger + err = os.Rename(w.Filename, fName) + if err != nil { + goto RESTART_LOGGER + } + + err = os.Chmod(fName, os.FileMode(rotatePerm)) + +RESTART_LOGGER: + + startLoggerErr := w.startLogger() + go w.deleteOldLog() + + if startLoggerErr != nil { + return fmt.Errorf("Rotate StartLogger: %s", startLoggerErr) + } + if err != nil { + return fmt.Errorf("Rotate: %s", err) + } + return nil +} + +func (w *fileLogWriter) deleteOldLog() { + dir := filepath.Dir(w.Filename) + absolutePath, err := filepath.EvalSymlinks(w.Filename) + if err == nil { + dir = filepath.Dir(absolutePath) + } + filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(os.Stderr, "Unable to delete old log '%s', error: %v\n", path, r) + } + }() + + if info == nil { + return + } + if w.Hourly { + if !info.IsDir() && info.ModTime().Add(1 * time.Hour * time.Duration(w.MaxHours)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } else if w.Daily { + if !info.IsDir() && info.ModTime().Add(24 * time.Hour * time.Duration(w.MaxDays)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } + return + }) +} + +// Destroy close the file description, close file writer. +func (w *fileLogWriter) Destroy() { + w.fileWriter.Close() +} + +// Flush flush file logger. +// there are no buffering messages in file logger in memory. +// flush file means sync file from disk. +func (w *fileLogWriter) Flush() { + w.fileWriter.Sync() +} + +func init() { + Register(AdapterFile, newFileWriter) +} diff --git a/pkg/logs/file_test.go b/pkg/logs/file_test.go new file mode 100644 index 00000000..e7c2ca9a --- /dev/null +++ b/pkg/logs/file_test.go @@ -0,0 +1,420 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "bufio" + "fmt" + "io/ioutil" + "os" + "strconv" + "testing" + "time" +) + +func TestFilePerm(t *testing.T) { + log := NewLogger(10000) + // use 0666 as test perm cause the default umask is 022 + log.SetLogger("file", `{"filename":"test.log", "perm": "0666"}`) + log.Debug("debug") + log.Informational("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + file, err := os.Stat("test.log") + if err != nil { + t.Fatal(err) + } + if file.Mode() != 0666 { + t.Fatal("unexpected log file permission") + } + os.Remove("test.log") +} + +func TestFile1(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test.log"}`) + log.Debug("debug") + log.Informational("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + f, err := os.Open("test.log") + if err != nil { + t.Fatal(err) + } + b := bufio.NewReader(f) + lineNum := 0 + for { + line, _, err := b.ReadLine() + if err != nil { + break + } + if len(line) > 0 { + lineNum++ + } + } + var expected = LevelDebug + 1 + if lineNum != expected { + t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") + } + os.Remove("test.log") +} + +func TestFile2(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", fmt.Sprintf(`{"filename":"test2.log","level":%d}`, LevelError)) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + f, err := os.Open("test2.log") + if err != nil { + t.Fatal(err) + } + b := bufio.NewReader(f) + lineNum := 0 + for { + line, _, err := b.ReadLine() + if err != nil { + break + } + if len(line) > 0 { + lineNum++ + } + } + var expected = LevelError + 1 + if lineNum != expected { + t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") + } + os.Remove("test2.log") +} + +func TestFileDailyRotate_01(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" + b, err := exists(rotateName) + if !b || err != nil { + os.Remove("test3.log") + t.Fatal("rotate not generated") + } + os.Remove(rotateName) + os.Remove("test3.log") +} + +func TestFileDailyRotate_02(t *testing.T) { + fn1 := "rotate_day.log" + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" + testFileRotate(t, fn1, fn2, true, false) +} + +func TestFileDailyRotate_03(t *testing.T) { + fn1 := "rotate_day.log" + fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" + os.Create(fn) + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" + testFileRotate(t, fn1, fn2, true, false) + os.Remove(fn) +} + +func TestFileDailyRotate_04(t *testing.T) { + fn1 := "rotate_day.log" + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" + testFileDailyRotate(t, fn1, fn2) +} + +func TestFileDailyRotate_05(t *testing.T) { + fn1 := "rotate_day.log" + fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" + os.Create(fn) + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" + testFileDailyRotate(t, fn1, fn2) + os.Remove(fn) +} +func TestFileDailyRotate_06(t *testing.T) { //test file mode + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" + s, _ := os.Lstat(rotateName) + if s.Mode() != 0440 { + os.Remove(rotateName) + os.Remove("test3.log") + t.Fatal("rotate file mode error") + } + os.Remove(rotateName) + os.Remove("test3.log") +} + +func TestFileHourlyRotate_01(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log" + b, err := exists(rotateName) + if !b || err != nil { + os.Remove("test3.log") + t.Fatal("rotate not generated") + } + os.Remove(rotateName) + os.Remove("test3.log") +} + +func TestFileHourlyRotate_02(t *testing.T) { + fn1 := "rotate_hour.log" + fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" + testFileRotate(t, fn1, fn2, false, true) +} + +func TestFileHourlyRotate_03(t *testing.T) { + fn1 := "rotate_hour.log" + fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log" + os.Create(fn) + fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" + testFileRotate(t, fn1, fn2, false, true) + os.Remove(fn) +} + +func TestFileHourlyRotate_04(t *testing.T) { + fn1 := "rotate_hour.log" + fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" + testFileHourlyRotate(t, fn1, fn2) +} + +func TestFileHourlyRotate_05(t *testing.T) { + fn1 := "rotate_hour.log" + fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log" + os.Create(fn) + fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" + testFileHourlyRotate(t, fn1, fn2) + os.Remove(fn) +} + +func TestFileHourlyRotate_06(t *testing.T) { //test file mode + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log" + s, _ := os.Lstat(rotateName) + if s.Mode() != 0440 { + os.Remove(rotateName) + os.Remove("test3.log") + t.Fatal("rotate file mode error") + } + os.Remove(rotateName) + os.Remove("test3.log") +} + +func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { + fw := &fileLogWriter{ + Daily: daily, + MaxDays: 7, + Hourly: hourly, + MaxHours: 168, + Rotate: true, + Level: LevelTrace, + Perm: "0660", + RotatePerm: "0440", + } + + if daily { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) + fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) + fw.dailyOpenDate = fw.dailyOpenTime.Day() + } + + if hourly { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) + fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) + fw.hourlyOpenDate = fw.hourlyOpenTime.Day() + } + + fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) + + for _, file := range []string{fn1, fn2} { + _, err := os.Stat(file) + if err != nil { + t.Log(err) + t.FailNow() + } + os.Remove(file) + } + fw.Destroy() +} + +func testFileDailyRotate(t *testing.T, fn1, fn2 string) { + fw := &fileLogWriter{ + Daily: true, + MaxDays: 7, + Rotate: true, + Level: LevelTrace, + Perm: "0660", + RotatePerm: "0440", + } + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) + fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) + fw.dailyOpenDate = fw.dailyOpenTime.Day() + today, _ := time.ParseInLocation("2006-01-02", time.Now().Format("2006-01-02"), fw.dailyOpenTime.Location()) + today = today.Add(-1 * time.Second) + fw.dailyRotate(today) + for _, file := range []string{fn1, fn2} { + _, err := os.Stat(file) + if err != nil { + t.FailNow() + } + content, err := ioutil.ReadFile(file) + if err != nil { + t.FailNow() + } + if len(content) > 0 { + t.FailNow() + } + os.Remove(file) + } + fw.Destroy() +} + +func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { + fw := &fileLogWriter{ + Hourly: true, + MaxHours: 168, + Rotate: true, + Level: LevelTrace, + Perm: "0660", + RotatePerm: "0440", + } + fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) + fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) + fw.hourlyOpenDate = fw.hourlyOpenTime.Hour() + hour, _ := time.ParseInLocation("2006010215", time.Now().Format("2006010215"), fw.hourlyOpenTime.Location()) + hour = hour.Add(-1 * time.Second) + fw.hourlyRotate(hour) + for _, file := range []string{fn1, fn2} { + _, err := os.Stat(file) + if err != nil { + t.FailNow() + } + content, err := ioutil.ReadFile(file) + if err != nil { + t.FailNow() + } + if len(content) > 0 { + t.FailNow() + } + os.Remove(file) + } + fw.Destroy() +} +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +func BenchmarkFile(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + os.Remove("test4.log") +} + +func BenchmarkFileAsynchronous(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + log.Async() + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + os.Remove("test4.log") +} + +func BenchmarkFileCallDepth(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + log.EnableFuncCallDepth(true) + log.SetLogFuncCallDepth(2) + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + os.Remove("test4.log") +} + +func BenchmarkFileAsynchronousCallDepth(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + log.EnableFuncCallDepth(true) + log.SetLogFuncCallDepth(2) + log.Async() + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + os.Remove("test4.log") +} + +func BenchmarkFileOnGoroutine(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + for i := 0; i < b.N; i++ { + go log.Debug("debug") + } + os.Remove("test4.log") +} diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go new file mode 100644 index 00000000..88ba0f9a --- /dev/null +++ b/pkg/logs/jianliao.go @@ -0,0 +1,72 @@ +package logs + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "time" +) + +// JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook +type JLWriter struct { + AuthorName string `json:"authorname"` + Title string `json:"title"` + WebhookURL string `json:"webhookurl"` + RedirectURL string `json:"redirecturl,omitempty"` + ImageURL string `json:"imageurl,omitempty"` + Level int `json:"level"` +} + +// newJLWriter create jiaoliao writer. +func newJLWriter() Logger { + return &JLWriter{Level: LevelTrace} +} + +// Init JLWriter with json config string +func (s *JLWriter) Init(jsonconfig string) error { + return json.Unmarshal([]byte(jsonconfig), s) +} + +// WriteMsg write message in smtp writer. +// it will send an email with subject and only this message. +func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > s.Level { + return nil + } + + text := fmt.Sprintf("%s %s", when.Format("2006-01-02 15:04:05"), msg) + + form := url.Values{} + form.Add("authorName", s.AuthorName) + form.Add("title", s.Title) + form.Add("text", text) + if s.RedirectURL != "" { + form.Add("redirectUrl", s.RedirectURL) + } + if s.ImageURL != "" { + form.Add("imageUrl", s.ImageURL) + } + + resp, err := http.PostForm(s.WebhookURL, form) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) + } + return nil +} + +// Flush implementing method. empty. +func (s *JLWriter) Flush() { +} + +// Destroy implementing method. empty. +func (s *JLWriter) Destroy() { +} + +func init() { + Register(AdapterJianLiao, newJLWriter) +} diff --git a/pkg/logs/log.go b/pkg/logs/log.go new file mode 100644 index 00000000..39c006d2 --- /dev/null +++ b/pkg/logs/log.go @@ -0,0 +1,669 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package logs provide a general log interface +// Usage: +// +// import "github.com/astaxie/beego/logs" +// +// log := NewLogger(10000) +// log.SetLogger("console", "") +// +// > the first params stand for how many channel +// +// Use it like this: +// +// log.Trace("trace") +// log.Info("info") +// log.Warn("warning") +// log.Debug("debug") +// log.Critical("critical") +// +// more docs http://beego.me/docs/module/logs.md +package logs + +import ( + "fmt" + "log" + "os" + "path" + "runtime" + "strconv" + "strings" + "sync" + "time" +) + +// RFC5424 log message levels. +const ( + LevelEmergency = iota + LevelAlert + LevelCritical + LevelError + LevelWarning + LevelNotice + LevelInformational + LevelDebug +) + +// levelLogLogger is defined to implement log.Logger +// the real log level will be LevelEmergency +const levelLoggerImpl = -1 + +// Name for adapter with beego official support +const ( + AdapterConsole = "console" + AdapterFile = "file" + AdapterMultiFile = "multifile" + AdapterMail = "smtp" + AdapterConn = "conn" + AdapterEs = "es" + AdapterJianLiao = "jianliao" + AdapterSlack = "slack" + AdapterAliLS = "alils" +) + +// Legacy log level constants to ensure backwards compatibility. +const ( + LevelInfo = LevelInformational + LevelTrace = LevelDebug + LevelWarn = LevelWarning +) + +type newLoggerFunc func() Logger + +// Logger defines the behavior of a log provider. +type Logger interface { + Init(config string) error + WriteMsg(when time.Time, msg string, level int) error + Destroy() + Flush() +} + +var adapters = make(map[string]newLoggerFunc) +var levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"} + +// Register makes a log provide available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, log newLoggerFunc) { + if log == nil { + panic("logs: Register provide is nil") + } + if _, dup := adapters[name]; dup { + panic("logs: Register called twice for provider " + name) + } + adapters[name] = log +} + +// BeeLogger is default logger in beego application. +// it can contain several providers and log message into all providers. +type BeeLogger struct { + lock sync.Mutex + level int + init bool + enableFuncCallDepth bool + loggerFuncCallDepth int + asynchronous bool + prefix string + msgChanLen int64 + msgChan chan *logMsg + signalChan chan string + wg sync.WaitGroup + outputs []*nameLogger +} + +const defaultAsyncMsgLen = 1e3 + +type nameLogger struct { + Logger + name string +} + +type logMsg struct { + level int + msg string + when time.Time +} + +var logMsgPool *sync.Pool + +// NewLogger returns a new BeeLogger. +// channelLen means the number of messages in chan(used where asynchronous is true). +// if the buffering chan is full, logger adapters write to file or other way. +func NewLogger(channelLens ...int64) *BeeLogger { + bl := new(BeeLogger) + bl.level = LevelDebug + bl.loggerFuncCallDepth = 2 + bl.msgChanLen = append(channelLens, 0)[0] + if bl.msgChanLen <= 0 { + bl.msgChanLen = defaultAsyncMsgLen + } + bl.signalChan = make(chan string, 1) + bl.setLogger(AdapterConsole) + return bl +} + +// Async set the log to asynchronous and start the goroutine +func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { + bl.lock.Lock() + defer bl.lock.Unlock() + if bl.asynchronous { + return bl + } + bl.asynchronous = true + if len(msgLen) > 0 && msgLen[0] > 0 { + bl.msgChanLen = msgLen[0] + } + bl.msgChan = make(chan *logMsg, bl.msgChanLen) + logMsgPool = &sync.Pool{ + New: func() interface{} { + return &logMsg{} + }, + } + bl.wg.Add(1) + go bl.startLogger() + return bl +} + +// SetLogger provides a given logger adapter into BeeLogger with config string. +// config need to be correct JSON as string: {"interval":360}. +func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { + config := append(configs, "{}")[0] + for _, l := range bl.outputs { + if l.name == adapterName { + return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName) + } + } + + logAdapter, ok := adapters[adapterName] + if !ok { + return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) + } + + lg := logAdapter() + err := lg.Init(config) + if err != nil { + fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) + return err + } + bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg}) + return nil +} + +// SetLogger provides a given logger adapter into BeeLogger with config string. +// config need to be correct JSON as string: {"interval":360}. +func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error { + bl.lock.Lock() + defer bl.lock.Unlock() + if !bl.init { + bl.outputs = []*nameLogger{} + bl.init = true + } + return bl.setLogger(adapterName, configs...) +} + +// DelLogger remove a logger adapter in BeeLogger. +func (bl *BeeLogger) DelLogger(adapterName string) error { + bl.lock.Lock() + defer bl.lock.Unlock() + outputs := []*nameLogger{} + for _, lg := range bl.outputs { + if lg.name == adapterName { + lg.Destroy() + } else { + outputs = append(outputs, lg) + } + } + if len(outputs) == len(bl.outputs) { + return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) + } + bl.outputs = outputs + return nil +} + +func (bl *BeeLogger) writeToLoggers(when time.Time, msg string, level int) { + for _, l := range bl.outputs { + err := l.WriteMsg(when, msg, level) + if err != nil { + fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err) + } + } +} + +func (bl *BeeLogger) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + // writeMsg will always add a '\n' character + if p[len(p)-1] == '\n' { + p = p[0 : len(p)-1] + } + // set levelLoggerImpl to ensure all log message will be write out + err = bl.writeMsg(levelLoggerImpl, string(p)) + if err == nil { + return len(p), err + } + return 0, err +} + +func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error { + if !bl.init { + bl.lock.Lock() + bl.setLogger(AdapterConsole) + bl.lock.Unlock() + } + + if len(v) > 0 { + msg = fmt.Sprintf(msg, v...) + } + + msg = bl.prefix + " " + msg + + when := time.Now() + if bl.enableFuncCallDepth { + _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) + if !ok { + file = "???" + line = 0 + } + _, filename := path.Split(file) + msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + msg + } + + //set level info in front of filename info + if logLevel == levelLoggerImpl { + // set to emergency to ensure all log will be print out correctly + logLevel = LevelEmergency + } else { + msg = levelPrefix[logLevel] + " " + msg + } + + if bl.asynchronous { + lm := logMsgPool.Get().(*logMsg) + lm.level = logLevel + lm.msg = msg + lm.when = when + if bl.outputs != nil { + bl.msgChan <- lm + } else { + logMsgPool.Put(lm) + } + } else { + bl.writeToLoggers(when, msg, logLevel) + } + return nil +} + +// SetLevel Set log message level. +// If message level (such as LevelDebug) is higher than logger level (such as LevelWarning), +// log providers will not even be sent the message. +func (bl *BeeLogger) SetLevel(l int) { + bl.level = l +} + +// GetLevel Get Current log message level. +func (bl *BeeLogger) GetLevel() int { + return bl.level +} + +// SetLogFuncCallDepth set log funcCallDepth +func (bl *BeeLogger) SetLogFuncCallDepth(d int) { + bl.loggerFuncCallDepth = d +} + +// GetLogFuncCallDepth return log funcCallDepth for wrapper +func (bl *BeeLogger) GetLogFuncCallDepth() int { + return bl.loggerFuncCallDepth +} + +// EnableFuncCallDepth enable log funcCallDepth +func (bl *BeeLogger) EnableFuncCallDepth(b bool) { + bl.enableFuncCallDepth = b +} + +// set prefix +func (bl *BeeLogger) SetPrefix(s string) { + bl.prefix = s +} + +// start logger chan reading. +// when chan is not empty, write logs. +func (bl *BeeLogger) startLogger() { + gameOver := false + for { + select { + case bm := <-bl.msgChan: + bl.writeToLoggers(bm.when, bm.msg, bm.level) + logMsgPool.Put(bm) + case sg := <-bl.signalChan: + // Now should only send "flush" or "close" to bl.signalChan + bl.flush() + if sg == "close" { + for _, l := range bl.outputs { + l.Destroy() + } + bl.outputs = nil + gameOver = true + } + bl.wg.Done() + } + if gameOver { + break + } + } +} + +// Emergency Log EMERGENCY level message. +func (bl *BeeLogger) Emergency(format string, v ...interface{}) { + if LevelEmergency > bl.level { + return + } + bl.writeMsg(LevelEmergency, format, v...) +} + +// Alert Log ALERT level message. +func (bl *BeeLogger) Alert(format string, v ...interface{}) { + if LevelAlert > bl.level { + return + } + bl.writeMsg(LevelAlert, format, v...) +} + +// Critical Log CRITICAL level message. +func (bl *BeeLogger) Critical(format string, v ...interface{}) { + if LevelCritical > bl.level { + return + } + bl.writeMsg(LevelCritical, format, v...) +} + +// Error Log ERROR level message. +func (bl *BeeLogger) Error(format string, v ...interface{}) { + if LevelError > bl.level { + return + } + bl.writeMsg(LevelError, format, v...) +} + +// Warning Log WARNING level message. +func (bl *BeeLogger) Warning(format string, v ...interface{}) { + if LevelWarn > bl.level { + return + } + bl.writeMsg(LevelWarn, format, v...) +} + +// Notice Log NOTICE level message. +func (bl *BeeLogger) Notice(format string, v ...interface{}) { + if LevelNotice > bl.level { + return + } + bl.writeMsg(LevelNotice, format, v...) +} + +// Informational Log INFORMATIONAL level message. +func (bl *BeeLogger) Informational(format string, v ...interface{}) { + if LevelInfo > bl.level { + return + } + bl.writeMsg(LevelInfo, format, v...) +} + +// Debug Log DEBUG level message. +func (bl *BeeLogger) Debug(format string, v ...interface{}) { + if LevelDebug > bl.level { + return + } + bl.writeMsg(LevelDebug, format, v...) +} + +// Warn Log WARN level message. +// compatibility alias for Warning() +func (bl *BeeLogger) Warn(format string, v ...interface{}) { + if LevelWarn > bl.level { + return + } + bl.writeMsg(LevelWarn, format, v...) +} + +// Info Log INFO level message. +// compatibility alias for Informational() +func (bl *BeeLogger) Info(format string, v ...interface{}) { + if LevelInfo > bl.level { + return + } + bl.writeMsg(LevelInfo, format, v...) +} + +// Trace Log TRACE level message. +// compatibility alias for Debug() +func (bl *BeeLogger) Trace(format string, v ...interface{}) { + if LevelDebug > bl.level { + return + } + bl.writeMsg(LevelDebug, format, v...) +} + +// Flush flush all chan data. +func (bl *BeeLogger) Flush() { + if bl.asynchronous { + bl.signalChan <- "flush" + bl.wg.Wait() + bl.wg.Add(1) + return + } + bl.flush() +} + +// Close close logger, flush all chan data and destroy all adapters in BeeLogger. +func (bl *BeeLogger) Close() { + if bl.asynchronous { + bl.signalChan <- "close" + bl.wg.Wait() + close(bl.msgChan) + } else { + bl.flush() + for _, l := range bl.outputs { + l.Destroy() + } + bl.outputs = nil + } + close(bl.signalChan) +} + +// Reset close all outputs, and set bl.outputs to nil +func (bl *BeeLogger) Reset() { + bl.Flush() + for _, l := range bl.outputs { + l.Destroy() + } + bl.outputs = nil +} + +func (bl *BeeLogger) flush() { + if bl.asynchronous { + for { + if len(bl.msgChan) > 0 { + bm := <-bl.msgChan + bl.writeToLoggers(bm.when, bm.msg, bm.level) + logMsgPool.Put(bm) + continue + } + break + } + } + for _, l := range bl.outputs { + l.Flush() + } +} + +// beeLogger references the used application logger. +var beeLogger = NewLogger() + +// GetBeeLogger returns the default BeeLogger +func GetBeeLogger() *BeeLogger { + return beeLogger +} + +var beeLoggerMap = struct { + sync.RWMutex + logs map[string]*log.Logger +}{ + logs: map[string]*log.Logger{}, +} + +// GetLogger returns the default BeeLogger +func GetLogger(prefixes ...string) *log.Logger { + prefix := append(prefixes, "")[0] + if prefix != "" { + prefix = fmt.Sprintf(`[%s] `, strings.ToUpper(prefix)) + } + beeLoggerMap.RLock() + l, ok := beeLoggerMap.logs[prefix] + if ok { + beeLoggerMap.RUnlock() + return l + } + beeLoggerMap.RUnlock() + beeLoggerMap.Lock() + defer beeLoggerMap.Unlock() + l, ok = beeLoggerMap.logs[prefix] + if !ok { + l = log.New(beeLogger, prefix, 0) + beeLoggerMap.logs[prefix] = l + } + return l +} + +// Reset will remove all the adapter +func Reset() { + beeLogger.Reset() +} + +// Async set the beelogger with Async mode and hold msglen messages +func Async(msgLen ...int64) *BeeLogger { + return beeLogger.Async(msgLen...) +} + +// SetLevel sets the global log level used by the simple logger. +func SetLevel(l int) { + beeLogger.SetLevel(l) +} + +// SetPrefix sets the prefix +func SetPrefix(s string) { + beeLogger.SetPrefix(s) +} + +// EnableFuncCallDepth enable log funcCallDepth +func EnableFuncCallDepth(b bool) { + beeLogger.enableFuncCallDepth = b +} + +// SetLogFuncCall set the CallDepth, default is 4 +func SetLogFuncCall(b bool) { + beeLogger.EnableFuncCallDepth(b) + beeLogger.SetLogFuncCallDepth(4) +} + +// SetLogFuncCallDepth set log funcCallDepth +func SetLogFuncCallDepth(d int) { + beeLogger.loggerFuncCallDepth = d +} + +// SetLogger sets a new logger. +func SetLogger(adapter string, config ...string) error { + return beeLogger.SetLogger(adapter, config...) +} + +// Emergency logs a message at emergency level. +func Emergency(f interface{}, v ...interface{}) { + beeLogger.Emergency(formatLog(f, v...)) +} + +// Alert logs a message at alert level. +func Alert(f interface{}, v ...interface{}) { + beeLogger.Alert(formatLog(f, v...)) +} + +// Critical logs a message at critical level. +func Critical(f interface{}, v ...interface{}) { + beeLogger.Critical(formatLog(f, v...)) +} + +// Error logs a message at error level. +func Error(f interface{}, v ...interface{}) { + beeLogger.Error(formatLog(f, v...)) +} + +// Warning logs a message at warning level. +func Warning(f interface{}, v ...interface{}) { + beeLogger.Warn(formatLog(f, v...)) +} + +// Warn compatibility alias for Warning() +func Warn(f interface{}, v ...interface{}) { + beeLogger.Warn(formatLog(f, v...)) +} + +// Notice logs a message at notice level. +func Notice(f interface{}, v ...interface{}) { + beeLogger.Notice(formatLog(f, v...)) +} + +// Informational logs a message at info level. +func Informational(f interface{}, v ...interface{}) { + beeLogger.Info(formatLog(f, v...)) +} + +// Info compatibility alias for Warning() +func Info(f interface{}, v ...interface{}) { + beeLogger.Info(formatLog(f, v...)) +} + +// Debug logs a message at debug level. +func Debug(f interface{}, v ...interface{}) { + beeLogger.Debug(formatLog(f, v...)) +} + +// Trace logs a message at trace level. +// compatibility alias for Warning() +func Trace(f interface{}, v ...interface{}) { + beeLogger.Trace(formatLog(f, v...)) +} + +func formatLog(f interface{}, v ...interface{}) string { + var msg string + switch f.(type) { + case string: + msg = f.(string) + if len(v) == 0 { + return msg + } + if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") { + //format string + } else { + //do not contain format char + msg += strings.Repeat(" %v", len(v)) + } + default: + msg = fmt.Sprint(f) + if len(v) == 0 { + return msg + } + msg += strings.Repeat(" %v", len(v)) + } + return fmt.Sprintf(msg, v...) +} diff --git a/pkg/logs/logger.go b/pkg/logs/logger.go new file mode 100644 index 00000000..a28bff6f --- /dev/null +++ b/pkg/logs/logger.go @@ -0,0 +1,176 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "io" + "runtime" + "sync" + "time" +) + +type logWriter struct { + sync.Mutex + writer io.Writer +} + +func newLogWriter(wr io.Writer) *logWriter { + return &logWriter{writer: wr} +} + +func (lg *logWriter) writeln(when time.Time, msg string) (int, error) { + lg.Lock() + h, _, _ := formatTimeHeader(when) + n, err := lg.writer.Write(append(append(h, msg...), '\n')) + lg.Unlock() + return n, err +} + +const ( + y1 = `0123456789` + y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789` + y3 = `0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999` + y4 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789` + mo1 = `000000000111` + mo2 = `123456789012` + d1 = `0000000001111111111222222222233` + d2 = `1234567890123456789012345678901` + h1 = `000000000011111111112222` + h2 = `012345678901234567890123` + mi1 = `000000000011111111112222222222333333333344444444445555555555` + mi2 = `012345678901234567890123456789012345678901234567890123456789` + s1 = `000000000011111111112222222222333333333344444444445555555555` + s2 = `012345678901234567890123456789012345678901234567890123456789` + ns1 = `0123456789` +) + +func formatTimeHeader(when time.Time) ([]byte, int, int) { + y, mo, d := when.Date() + h, mi, s := when.Clock() + ns := when.Nanosecond() / 1000000 + //len("2006/01/02 15:04:05.123 ")==24 + var buf [24]byte + + buf[0] = y1[y/1000%10] + buf[1] = y2[y/100] + buf[2] = y3[y-y/100*100] + buf[3] = y4[y-y/100*100] + buf[4] = '/' + buf[5] = mo1[mo-1] + buf[6] = mo2[mo-1] + buf[7] = '/' + buf[8] = d1[d-1] + buf[9] = d2[d-1] + buf[10] = ' ' + buf[11] = h1[h] + buf[12] = h2[h] + buf[13] = ':' + buf[14] = mi1[mi] + buf[15] = mi2[mi] + buf[16] = ':' + buf[17] = s1[s] + buf[18] = s2[s] + buf[19] = '.' + buf[20] = ns1[ns/100] + buf[21] = ns1[ns%100/10] + buf[22] = ns1[ns%10] + + buf[23] = ' ' + + return buf[0:], d, h +} + +var ( + green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109}) + white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109}) + yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109}) + red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109}) + blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109}) + magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109}) + cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109}) + + w32Green = string([]byte{27, 91, 52, 50, 109}) + w32White = string([]byte{27, 91, 52, 55, 109}) + w32Yellow = string([]byte{27, 91, 52, 51, 109}) + w32Red = string([]byte{27, 91, 52, 49, 109}) + w32Blue = string([]byte{27, 91, 52, 52, 109}) + w32Magenta = string([]byte{27, 91, 52, 53, 109}) + w32Cyan = string([]byte{27, 91, 52, 54, 109}) + + reset = string([]byte{27, 91, 48, 109}) +) + +var once sync.Once +var colorMap map[string]string + +func initColor() { + if runtime.GOOS == "windows" { + green = w32Green + white = w32White + yellow = w32Yellow + red = w32Red + blue = w32Blue + magenta = w32Magenta + cyan = w32Cyan + } + colorMap = map[string]string{ + //by color + "green": green, + "white": white, + "yellow": yellow, + "red": red, + //by method + "GET": blue, + "POST": cyan, + "PUT": yellow, + "DELETE": red, + "PATCH": green, + "HEAD": magenta, + "OPTIONS": white, + } +} + +// ColorByStatus return color by http code +// 2xx return Green +// 3xx return White +// 4xx return Yellow +// 5xx return Red +func ColorByStatus(code int) string { + once.Do(initColor) + switch { + case code >= 200 && code < 300: + return colorMap["green"] + case code >= 300 && code < 400: + return colorMap["white"] + case code >= 400 && code < 500: + return colorMap["yellow"] + default: + return colorMap["red"] + } +} + +// ColorByMethod return color by http code +func ColorByMethod(method string) string { + once.Do(initColor) + if c := colorMap[method]; c != "" { + return c + } + return reset +} + +// ResetColor return reset color +func ResetColor() string { + return reset +} diff --git a/pkg/logs/logger_test.go b/pkg/logs/logger_test.go new file mode 100644 index 00000000..15be500d --- /dev/null +++ b/pkg/logs/logger_test.go @@ -0,0 +1,57 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + "time" +) + +func TestFormatHeader_0(t *testing.T) { + tm := time.Now() + if tm.Year() >= 2100 { + t.FailNow() + } + dur := time.Second + for { + if tm.Year() >= 2100 { + break + } + h, _, _ := formatTimeHeader(tm) + if tm.Format("2006/01/02 15:04:05.000 ") != string(h) { + t.Log(tm) + t.FailNow() + } + tm = tm.Add(dur) + dur *= 2 + } +} + +func TestFormatHeader_1(t *testing.T) { + tm := time.Now() + year := tm.Year() + dur := time.Second + for { + if tm.Year() >= year+1 { + break + } + h, _, _ := formatTimeHeader(tm) + if tm.Format("2006/01/02 15:04:05.000 ") != string(h) { + t.Log(tm) + t.FailNow() + } + tm = tm.Add(dur) + } +} diff --git a/pkg/logs/multifile.go b/pkg/logs/multifile.go new file mode 100644 index 00000000..90168274 --- /dev/null +++ b/pkg/logs/multifile.go @@ -0,0 +1,119 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "encoding/json" + "time" +) + +// A filesLogWriter manages several fileLogWriter +// filesLogWriter will write logs to the file in json configuration and write the same level log to correspond file +// means if the file name in configuration is project.log filesLogWriter will create project.error.log/project.debug.log +// and write the error-level logs to project.error.log and write the debug-level logs to project.debug.log +// the rotate attribute also acts like fileLogWriter +type multiFileLogWriter struct { + writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter + fullLogWriter *fileLogWriter + Separate []string `json:"separate"` +} + +var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"} + +// Init file logger with json config. +// jsonConfig like: +// { +// "filename":"logs/beego.log", +// "maxLines":0, +// "maxsize":0, +// "daily":true, +// "maxDays":15, +// "rotate":true, +// "perm":0600, +// "separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"], +// } + +func (f *multiFileLogWriter) Init(config string) error { + writer := newFileWriter().(*fileLogWriter) + err := writer.Init(config) + if err != nil { + return err + } + f.fullLogWriter = writer + f.writers[LevelDebug+1] = writer + + //unmarshal "separate" field to f.Separate + json.Unmarshal([]byte(config), f) + + jsonMap := map[string]interface{}{} + json.Unmarshal([]byte(config), &jsonMap) + + for i := LevelEmergency; i < LevelDebug+1; i++ { + for _, v := range f.Separate { + if v == levelNames[i] { + jsonMap["filename"] = f.fullLogWriter.fileNameOnly + "." + levelNames[i] + f.fullLogWriter.suffix + jsonMap["level"] = i + bs, _ := json.Marshal(jsonMap) + writer = newFileWriter().(*fileLogWriter) + err := writer.Init(string(bs)) + if err != nil { + return err + } + f.writers[i] = writer + } + } + } + + return nil +} + +func (f *multiFileLogWriter) Destroy() { + for i := 0; i < len(f.writers); i++ { + if f.writers[i] != nil { + f.writers[i].Destroy() + } + } +} + +func (f *multiFileLogWriter) WriteMsg(when time.Time, msg string, level int) error { + if f.fullLogWriter != nil { + f.fullLogWriter.WriteMsg(when, msg, level) + } + for i := 0; i < len(f.writers)-1; i++ { + if f.writers[i] != nil { + if level == f.writers[i].Level { + f.writers[i].WriteMsg(when, msg, level) + } + } + } + return nil +} + +func (f *multiFileLogWriter) Flush() { + for i := 0; i < len(f.writers); i++ { + if f.writers[i] != nil { + f.writers[i].Flush() + } + } +} + +// newFilesWriter create a FileLogWriter returning as LoggerInterface. +func newFilesWriter() Logger { + return &multiFileLogWriter{} +} + +func init() { + Register(AdapterMultiFile, newFilesWriter) +} diff --git a/pkg/logs/multifile_test.go b/pkg/logs/multifile_test.go new file mode 100644 index 00000000..57b96094 --- /dev/null +++ b/pkg/logs/multifile_test.go @@ -0,0 +1,78 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "bufio" + "os" + "strconv" + "strings" + "testing" +) + +func TestFiles_1(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("multifile", `{"filename":"test.log","separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"]}`) + log.Debug("debug") + log.Informational("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + fns := []string{""} + fns = append(fns, levelNames[0:]...) + name := "test" + suffix := ".log" + for _, fn := range fns { + + file := name + suffix + if fn != "" { + file = name + "." + fn + suffix + } + f, err := os.Open(file) + if err != nil { + t.Fatal(err) + } + b := bufio.NewReader(f) + lineNum := 0 + lastLine := "" + for { + line, _, err := b.ReadLine() + if err != nil { + break + } + if len(line) > 0 { + lastLine = string(line) + lineNum++ + } + } + var expected = 1 + if fn == "" { + expected = LevelDebug + 1 + } + if lineNum != expected { + t.Fatal(file, "has", lineNum, "lines not "+strconv.Itoa(expected)+" lines") + } + if lineNum == 1 { + if !strings.Contains(lastLine, fn) { + t.Fatal(file + " " + lastLine + " not contains the log msg " + fn) + } + } + os.Remove(file) + } + +} diff --git a/pkg/logs/slack.go b/pkg/logs/slack.go new file mode 100644 index 00000000..1cd2e5ae --- /dev/null +++ b/pkg/logs/slack.go @@ -0,0 +1,60 @@ +package logs + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "time" +) + +// SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook +type SLACKWriter struct { + WebhookURL string `json:"webhookurl"` + Level int `json:"level"` +} + +// newSLACKWriter create jiaoliao writer. +func newSLACKWriter() Logger { + return &SLACKWriter{Level: LevelTrace} +} + +// Init SLACKWriter with json config string +func (s *SLACKWriter) Init(jsonconfig string) error { + return json.Unmarshal([]byte(jsonconfig), s) +} + +// WriteMsg write message in smtp writer. +// it will send an email with subject and only this message. +func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > s.Level { + return nil + } + + text := fmt.Sprintf("{\"text\": \"%s %s\"}", when.Format("2006-01-02 15:04:05"), msg) + + form := url.Values{} + form.Add("payload", text) + + resp, err := http.PostForm(s.WebhookURL, form) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) + } + return nil +} + +// Flush implementing method. empty. +func (s *SLACKWriter) Flush() { +} + +// Destroy implementing method. empty. +func (s *SLACKWriter) Destroy() { +} + +func init() { + Register(AdapterSlack, newSLACKWriter) +} diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go new file mode 100644 index 00000000..6208d7b8 --- /dev/null +++ b/pkg/logs/smtp.go @@ -0,0 +1,149 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "crypto/tls" + "encoding/json" + "fmt" + "net" + "net/smtp" + "strings" + "time" +) + +// SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server. +type SMTPWriter struct { + Username string `json:"username"` + Password string `json:"password"` + Host string `json:"host"` + Subject string `json:"subject"` + FromAddress string `json:"fromAddress"` + RecipientAddresses []string `json:"sendTos"` + Level int `json:"level"` +} + +// NewSMTPWriter create smtp writer. +func newSMTPWriter() Logger { + return &SMTPWriter{Level: LevelTrace} +} + +// Init smtp writer with json config. +// config like: +// { +// "username":"example@gmail.com", +// "password:"password", +// "host":"smtp.gmail.com:465", +// "subject":"email title", +// "fromAddress":"from@example.com", +// "sendTos":["email1","email2"], +// "level":LevelError +// } +func (s *SMTPWriter) Init(jsonconfig string) error { + return json.Unmarshal([]byte(jsonconfig), s) +} + +func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { + if len(strings.Trim(s.Username, " ")) == 0 && len(strings.Trim(s.Password, " ")) == 0 { + return nil + } + return smtp.PlainAuth( + "", + s.Username, + s.Password, + host, + ) +} + +func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error { + client, err := smtp.Dial(hostAddressWithPort) + if err != nil { + return err + } + + host, _, _ := net.SplitHostPort(hostAddressWithPort) + tlsConn := &tls.Config{ + InsecureSkipVerify: true, + ServerName: host, + } + if err = client.StartTLS(tlsConn); err != nil { + return err + } + + if auth != nil { + if err = client.Auth(auth); err != nil { + return err + } + } + + if err = client.Mail(fromAddress); err != nil { + return err + } + + for _, rec := range recipients { + if err = client.Rcpt(rec); err != nil { + return err + } + } + + w, err := client.Data() + if err != nil { + return err + } + _, err = w.Write(msgContent) + if err != nil { + return err + } + + err = w.Close() + if err != nil { + return err + } + + return client.Quit() +} + +// WriteMsg write message in smtp writer. +// it will send an email with subject and only this message. +func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > s.Level { + return nil + } + + hp := strings.Split(s.Host, ":") + + // Set up authentication information. + auth := s.getSMTPAuth(hp[0]) + + // Connect to the server, authenticate, set the sender and recipient, + // and send the email all in one step. + contentType := "Content-Type: text/plain" + "; charset=UTF-8" + mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress + + ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", when.Format("2006-01-02 15:04:05")) + msg) + + return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg) +} + +// Flush implementing method. empty. +func (s *SMTPWriter) Flush() { +} + +// Destroy implementing method. empty. +func (s *SMTPWriter) Destroy() { +} + +func init() { + Register(AdapterMail, newSMTPWriter) +} diff --git a/pkg/logs/smtp_test.go b/pkg/logs/smtp_test.go new file mode 100644 index 00000000..28e762d2 --- /dev/null +++ b/pkg/logs/smtp_test.go @@ -0,0 +1,27 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + "time" +) + +func TestSmtp(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`) + log.Critical("sendmail critical") + time.Sleep(time.Second * 30) +} diff --git a/pkg/metric/prometheus.go b/pkg/metric/prometheus.go new file mode 100644 index 00000000..7722240b --- /dev/null +++ b/pkg/metric/prometheus.go @@ -0,0 +1,99 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "net/http" + "reflect" + "strconv" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/logs" +) + +func PrometheusMiddleWare(next http.Handler) http.Handler { + summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "beego", + Subsystem: "http_request", + ConstLabels: map[string]string{ + "server": beego.BConfig.ServerName, + "env": beego.BConfig.RunMode, + "appname": beego.BConfig.AppName, + }, + Help: "The statics info for http request", + }, []string{"pattern", "method", "status", "duration"}) + + prometheus.MustRegister(summaryVec) + + registerBuildInfo() + + return http.HandlerFunc(func(writer http.ResponseWriter, q *http.Request) { + start := time.Now() + next.ServeHTTP(writer, q) + end := time.Now() + go report(end.Sub(start), writer, q, summaryVec) + }) +} + +func registerBuildInfo() { + buildInfo := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "beego", + Subsystem: "build_info", + Help: "The building information", + ConstLabels: map[string]string{ + "appname": beego.BConfig.AppName, + "build_version": beego.BuildVersion, + "build_revision": beego.BuildGitRevision, + "build_status": beego.BuildStatus, + "build_tag": beego.BuildTag, + "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), + "go_version": beego.GoVersion, + "git_branch": beego.GitBranch, + "start_time": time.Now().Format("2006-01-02 15:04:05"), + }, + }, []string{}) + + prometheus.MustRegister(buildInfo) + buildInfo.WithLabelValues().Set(1) +} + +func report(dur time.Duration, writer http.ResponseWriter, q *http.Request, vec *prometheus.SummaryVec) { + ctrl := beego.BeeApp.Handlers + ctx := ctrl.GetContext() + ctx.Reset(writer, q) + defer ctrl.GiveBackContext(ctx) + + // We cannot read the status code from q.Response.StatusCode + // since the http server does not set q.Response. So q.Response is nil + // Thus, we use reflection to read the status from writer whose concrete type is http.response + responseVal := reflect.ValueOf(writer).Elem() + field := responseVal.FieldByName("status") + status := -1 + if field.IsValid() && field.Kind() == reflect.Int { + status = int(field.Int()) + } + ptn := "UNKNOWN" + if rt, found := ctrl.FindRouter(ctx); found { + ptn = rt.GetPattern() + } else { + logs.Warn("we can not find the router info for this request, so request will be recorded as UNKNOWN: " + q.URL.String()) + } + ms := dur / time.Millisecond + vec.WithLabelValues(ptn, q.Method, strconv.Itoa(status), strconv.Itoa(int(ms))).Observe(float64(ms)) +} diff --git a/pkg/metric/prometheus_test.go b/pkg/metric/prometheus_test.go new file mode 100644 index 00000000..d82a6dec --- /dev/null +++ b/pkg/metric/prometheus_test.go @@ -0,0 +1,42 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "net/http" + "net/url" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/astaxie/beego/context" +) + +func TestPrometheusMiddleWare(t *testing.T) { + middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + writer := &context.Response{} + request := &http.Request{ + URL: &url.URL{ + Host: "localhost", + RawPath: "/a/b/c", + }, + Method: "POST", + } + vec := prometheus.NewSummaryVec(prometheus.SummaryOpts{}, []string{"pattern", "method", "status", "duration"}) + + report(time.Second, writer, request, vec) + middleware.ServeHTTP(writer, request) +} diff --git a/pkg/migration/ddl.go b/pkg/migration/ddl.go new file mode 100644 index 00000000..cd2c1c49 --- /dev/null +++ b/pkg/migration/ddl.go @@ -0,0 +1,395 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package migration + +import ( + "fmt" + + "github.com/astaxie/beego/logs" +) + +// Index struct defines the structure of Index Columns +type Index struct { + Name string +} + +// Unique struct defines a single unique key combination +type Unique struct { + Definition string + Columns []*Column +} + +//Column struct defines a single column of a table +type Column struct { + Name string + Inc string + Null string + Default string + Unsign string + DataType string + remove bool + Modify bool +} + +// Foreign struct defines a single foreign relationship +type Foreign struct { + ForeignTable string + ForeignColumn string + OnDelete string + OnUpdate string + Column +} + +// RenameColumn struct allows renaming of columns +type RenameColumn struct { + OldName string + OldNull string + OldDefault string + OldUnsign string + OldDataType string + NewName string + Column +} + +// CreateTable creates the table on system +func (m *Migration) CreateTable(tablename, engine, charset string, p ...func()) { + m.TableName = tablename + m.Engine = engine + m.Charset = charset + m.ModifyType = "create" +} + +// AlterTable set the ModifyType to alter +func (m *Migration) AlterTable(tablename string) { + m.TableName = tablename + m.ModifyType = "alter" +} + +// NewCol creates a new standard column and attaches it to m struct +func (m *Migration) NewCol(name string) *Column { + col := &Column{Name: name} + m.AddColumns(col) + return col +} + +//PriCol creates a new primary column and attaches it to m struct +func (m *Migration) PriCol(name string) *Column { + col := &Column{Name: name} + m.AddColumns(col) + m.AddPrimary(col) + return col +} + +//UniCol creates / appends columns to specified unique key and attaches it to m struct +func (m *Migration) UniCol(uni, name string) *Column { + col := &Column{Name: name} + m.AddColumns(col) + + uniqueOriginal := &Unique{} + + for _, unique := range m.Uniques { + if unique.Definition == uni { + unique.AddColumnsToUnique(col) + uniqueOriginal = unique + } + } + if uniqueOriginal.Definition == "" { + unique := &Unique{Definition: uni} + unique.AddColumnsToUnique(col) + m.AddUnique(unique) + } + + return col +} + +//ForeignCol creates a new foreign column and returns the instance of column +func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { + + foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable} + foreign.Name = colname + m.AddForeign(foreign) + return foreign +} + +//SetOnDelete sets the on delete of foreign +func (foreign *Foreign) SetOnDelete(del string) *Foreign { + foreign.OnDelete = "ON DELETE" + del + return foreign +} + +//SetOnUpdate sets the on update of foreign +func (foreign *Foreign) SetOnUpdate(update string) *Foreign { + foreign.OnUpdate = "ON UPDATE" + update + return foreign +} + +//Remove marks the columns to be removed. +//it allows reverse m to create the column. +func (c *Column) Remove() { + c.remove = true +} + +//SetAuto enables auto_increment of column (can be used once) +func (c *Column) SetAuto(inc bool) *Column { + if inc { + c.Inc = "auto_increment" + } + return c +} + +//SetNullable sets the column to be null +func (c *Column) SetNullable(null bool) *Column { + if null { + c.Null = "" + + } else { + c.Null = "NOT NULL" + } + return c +} + +//SetDefault sets the default value, prepend with "DEFAULT " +func (c *Column) SetDefault(def string) *Column { + c.Default = "DEFAULT " + def + return c +} + +//SetUnsigned sets the column to be unsigned int +func (c *Column) SetUnsigned(unsign bool) *Column { + if unsign { + c.Unsign = "UNSIGNED" + } + return c +} + +//SetDataType sets the dataType of the column +func (c *Column) SetDataType(dataType string) *Column { + c.DataType = dataType + return c +} + +//SetOldNullable allows reverting to previous nullable on reverse ms +func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { + if null { + c.OldNull = "" + + } else { + c.OldNull = "NOT NULL" + } + return c +} + +//SetOldDefault allows reverting to previous default on reverse ms +func (c *RenameColumn) SetOldDefault(def string) *RenameColumn { + c.OldDefault = def + return c +} + +//SetOldUnsigned allows reverting to previous unsgined on reverse ms +func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { + if unsign { + c.OldUnsign = "UNSIGNED" + } + return c +} + +//SetOldDataType allows reverting to previous datatype on reverse ms +func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn { + c.OldDataType = dataType + return c +} + +//SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) +func (c *Column) SetPrimary(m *Migration) *Column { + m.Primary = append(m.Primary, c) + return c +} + +//AddColumnsToUnique adds the columns to Unique Struct +func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { + + unique.Columns = append(unique.Columns, columns...) + + return unique +} + +//AddColumns adds columns to m struct +func (m *Migration) AddColumns(columns ...*Column) *Migration { + + m.Columns = append(m.Columns, columns...) + + return m +} + +//AddPrimary adds the column to primary in m struct +func (m *Migration) AddPrimary(primary *Column) *Migration { + m.Primary = append(m.Primary, primary) + return m +} + +//AddUnique adds the column to unique in m struct +func (m *Migration) AddUnique(unique *Unique) *Migration { + m.Uniques = append(m.Uniques, unique) + return m +} + +//AddForeign adds the column to foreign in m struct +func (m *Migration) AddForeign(foreign *Foreign) *Migration { + m.Foreigns = append(m.Foreigns, foreign) + return m +} + +//AddIndex adds the column to index in m struct +func (m *Migration) AddIndex(index *Index) *Migration { + m.Indexes = append(m.Indexes, index) + return m +} + +//RenameColumn allows renaming of columns +func (m *Migration) RenameColumn(from, to string) *RenameColumn { + rename := &RenameColumn{OldName: from, NewName: to} + m.Renames = append(m.Renames, rename) + return rename +} + +//GetSQL returns the generated sql depending on ModifyType +func (m *Migration) GetSQL() (sql string) { + sql = "" + switch m.ModifyType { + case "create": + { + sql += fmt.Sprintf("CREATE TABLE `%s` (", m.TableName) + for index, column := range m.Columns { + sql += fmt.Sprintf("\n `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) + if len(m.Columns) > index+1 { + sql += "," + } + } + + if len(m.Primary) > 0 { + sql += fmt.Sprintf(",\n PRIMARY KEY( ") + } + for index, column := range m.Primary { + sql += fmt.Sprintf(" `%s`", column.Name) + if len(m.Primary) > index+1 { + sql += "," + } + + } + if len(m.Primary) > 0 { + sql += fmt.Sprintf(")") + } + + for _, unique := range m.Uniques { + sql += fmt.Sprintf(",\n UNIQUE KEY `%s`( ", unique.Definition) + for index, column := range unique.Columns { + sql += fmt.Sprintf(" `%s`", column.Name) + if len(unique.Columns) > index+1 { + sql += "," + } + } + sql += fmt.Sprintf(")") + } + for _, foreign := range m.Foreigns { + sql += fmt.Sprintf(",\n `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default) + sql += fmt.Sprintf(",\n KEY `%s_%s_foreign`(`%s`),", m.TableName, foreign.Column.Name, foreign.Column.Name) + sql += fmt.Sprintf("\n CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate) + + } + sql += fmt.Sprintf(")ENGINE=%s DEFAULT CHARSET=%s;", m.Engine, m.Charset) + break + } + case "alter": + { + sql += fmt.Sprintf("ALTER TABLE `%s` ", m.TableName) + for index, column := range m.Columns { + if !column.remove { + logs.Info("col") + sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) + } else { + sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) + } + + if len(m.Columns) > index+1 { + sql += "," + } + } + for index, column := range m.Renames { + sql += fmt.Sprintf("CHANGE COLUMN `%s` `%s` %s %s %s %s %s", column.OldName, column.NewName, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) + if len(m.Renames) > index+1 { + sql += "," + } + } + + for index, foreign := range m.Foreigns { + sql += fmt.Sprintf("ADD `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default) + sql += fmt.Sprintf(",\n ADD KEY `%s_%s_foreign`(`%s`)", m.TableName, foreign.Column.Name, foreign.Column.Name) + sql += fmt.Sprintf(",\n ADD CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate) + if len(m.Foreigns) > index+1 { + sql += "," + } + } + sql += ";" + + break + } + case "reverse": + { + + sql += fmt.Sprintf("ALTER TABLE `%s`", m.TableName) + for index, column := range m.Columns { + if column.remove { + sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) + } else { + sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) + } + if len(m.Columns) > index+1 { + sql += "," + } + } + + if len(m.Primary) > 0 { + sql += fmt.Sprintf("\n DROP PRIMARY KEY,") + } + + for index, unique := range m.Uniques { + sql += fmt.Sprintf("\n DROP KEY `%s`", unique.Definition) + if len(m.Uniques) > index+1 { + sql += "," + } + + } + for index, column := range m.Renames { + sql += fmt.Sprintf("\n CHANGE COLUMN `%s` `%s` %s %s %s %s", column.NewName, column.OldName, column.OldDataType, column.OldUnsign, column.OldNull, column.OldDefault) + if len(m.Renames) > index+1 { + sql += "," + } + } + + for _, foreign := range m.Foreigns { + sql += fmt.Sprintf("\n DROP KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name) + sql += fmt.Sprintf(",\n DROP FOREIGN KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name) + sql += fmt.Sprintf(",\n DROP COLUMN `%s`", foreign.Name) + } + sql += ";" + } + case "delete": + { + sql += fmt.Sprintf("DROP TABLE IF EXISTS `%s`;", m.TableName) + } + } + + return +} diff --git a/pkg/migration/doc.go b/pkg/migration/doc.go new file mode 100644 index 00000000..0c6564d4 --- /dev/null +++ b/pkg/migration/doc.go @@ -0,0 +1,32 @@ +// Package migration enables you to generate migrations back and forth. It generates both migrations. +// +// //Creates a table +// m.CreateTable("tablename","InnoDB","utf8"); +// +// //Alter a table +// m.AlterTable("tablename") +// +// Standard Column Methods +// * SetDataType +// * SetNullable +// * SetDefault +// * SetUnsigned (use only on integer types unless produces error) +// +// //Sets a primary column, multiple calls allowed, standard column methods available +// m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true) +// +// //UniCol Can be used multiple times, allows standard Column methods. Use same "index" string to add to same index +// m.UniCol("index","column") +// +// //Standard Column Initialisation, can call .Remove() after NewCol("") on alter to remove +// m.NewCol("name").SetDataType("VARCHAR(255) COLLATE utf8_unicode_ci").SetNullable(false) +// m.NewCol("value").SetDataType("DOUBLE(8,2)").SetNullable(false) +// +// //Rename Columns , only use with Alter table, doesn't works with Create, prefix standard column methods with "Old" to +// //create a true reversible migration eg: SetOldDataType("DOUBLE(12,3)") +// m.RenameColumn("from","to")... +// +// //Foreign Columns, single columns are only supported, SetOnDelete & SetOnUpdate are available, call appropriately. +// //Supports standard column methods, automatic reverse. +// m.ForeignCol("local_col","foreign_col","foreign_table") +package migration diff --git a/pkg/migration/migration.go b/pkg/migration/migration.go new file mode 100644 index 00000000..5ddfd972 --- /dev/null +++ b/pkg/migration/migration.go @@ -0,0 +1,330 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package migration is used for migration +// +// The table structure is as follow: +// +// CREATE TABLE `migrations` ( +// `id_migration` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'surrogate key', +// `name` varchar(255) DEFAULT NULL COMMENT 'migration name, unique', +// `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'date migrated or rolled back', +// `statements` longtext COMMENT 'SQL statements for this migration', +// `rollback_statements` longtext, +// `status` enum('update','rollback') DEFAULT NULL COMMENT 'update indicates it is a normal migration while rollback means this migration is rolled back', +// PRIMARY KEY (`id_migration`) +// ) ENGINE=InnoDB DEFAULT CHARSET=utf8; +package migration + +import ( + "errors" + "sort" + "strings" + "time" + + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/orm" +) + +// const the data format for the bee generate migration datatype +const ( + DateFormat = "20060102_150405" + DBDateFormat = "2006-01-02 15:04:05" +) + +// Migrationer is an interface for all Migration struct +type Migrationer interface { + Up() + Down() + Reset() + Exec(name, status string) error + GetCreated() int64 +} + +//Migration defines the migrations by either SQL or DDL +type Migration struct { + sqls []string + Created string + TableName string + Engine string + Charset string + ModifyType string + Columns []*Column + Indexes []*Index + Primary []*Column + Uniques []*Unique + Foreigns []*Foreign + Renames []*RenameColumn + RemoveColumns []*Column + RemoveIndexes []*Index + RemoveUniques []*Unique + RemoveForeigns []*Foreign +} + +var ( + migrationMap map[string]Migrationer +) + +func init() { + migrationMap = make(map[string]Migrationer) +} + +// Up implement in the Inheritance struct for upgrade +func (m *Migration) Up() { + + switch m.ModifyType { + case "reverse": + m.ModifyType = "alter" + case "delete": + m.ModifyType = "create" + } + m.sqls = append(m.sqls, m.GetSQL()) +} + +// Down implement in the Inheritance struct for down +func (m *Migration) Down() { + + switch m.ModifyType { + case "alter": + m.ModifyType = "reverse" + case "create": + m.ModifyType = "delete" + } + m.sqls = append(m.sqls, m.GetSQL()) +} + +//Migrate adds the SQL to the execution list +func (m *Migration) Migrate(migrationType string) { + m.ModifyType = migrationType + m.sqls = append(m.sqls, m.GetSQL()) +} + +// SQL add sql want to execute +func (m *Migration) SQL(sql string) { + m.sqls = append(m.sqls, sql) +} + +// Reset the sqls +func (m *Migration) Reset() { + m.sqls = make([]string, 0) +} + +// Exec execute the sql already add in the sql +func (m *Migration) Exec(name, status string) error { + o := orm.NewOrm() + for _, s := range m.sqls { + logs.Info("exec sql:", s) + r := o.Raw(s) + _, err := r.Exec() + if err != nil { + return err + } + } + return m.addOrUpdateRecord(name, status) +} + +func (m *Migration) addOrUpdateRecord(name, status string) error { + o := orm.NewOrm() + if status == "down" { + status = "rollback" + p, err := o.Raw("update migrations set status = ?, rollback_statements = ?, created_at = ? where name = ?").Prepare() + if err != nil { + return nil + } + _, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(DBDateFormat), name) + return err + } + status = "update" + p, err := o.Raw("insert into migrations(name, created_at, statements, status) values(?,?,?,?)").Prepare() + if err != nil { + return err + } + _, err = p.Exec(name, time.Now().Format(DBDateFormat), strings.Join(m.sqls, "; "), status) + return err +} + +// GetCreated get the unixtime from the Created +func (m *Migration) GetCreated() int64 { + t, err := time.Parse(DateFormat, m.Created) + if err != nil { + return 0 + } + return t.Unix() +} + +// Register register the Migration in the map +func Register(name string, m Migrationer) error { + if _, ok := migrationMap[name]; ok { + return errors.New("already exist name:" + name) + } + migrationMap[name] = m + return nil +} + +// Upgrade upgrade the migration from lasttime +func Upgrade(lasttime int64) error { + sm := sortMap(migrationMap) + i := 0 + migs, _ := getAllMigrations() + for _, v := range sm { + if _, ok := migs[v.name]; !ok { + logs.Info("start upgrade", v.name) + v.m.Reset() + v.m.Up() + err := v.m.Exec(v.name, "up") + if err != nil { + logs.Error("execute error:", err) + time.Sleep(2 * time.Second) + return err + } + logs.Info("end upgrade:", v.name) + i++ + } + } + logs.Info("total success upgrade:", i, " migration") + time.Sleep(2 * time.Second) + return nil +} + +// Rollback rollback the migration by the name +func Rollback(name string) error { + if v, ok := migrationMap[name]; ok { + logs.Info("start rollback") + v.Reset() + v.Down() + err := v.Exec(name, "down") + if err != nil { + logs.Error("execute error:", err) + time.Sleep(2 * time.Second) + return err + } + logs.Info("end rollback") + time.Sleep(2 * time.Second) + return nil + } + logs.Error("not exist the migrationMap name:" + name) + time.Sleep(2 * time.Second) + return errors.New("not exist the migrationMap name:" + name) +} + +// Reset reset all migration +// run all migration's down function +func Reset() error { + sm := sortMap(migrationMap) + i := 0 + for j := len(sm) - 1; j >= 0; j-- { + v := sm[j] + if isRollBack(v.name) { + logs.Info("skip the", v.name) + time.Sleep(1 * time.Second) + continue + } + logs.Info("start reset:", v.name) + v.m.Reset() + v.m.Down() + err := v.m.Exec(v.name, "down") + if err != nil { + logs.Error("execute error:", err) + time.Sleep(2 * time.Second) + return err + } + i++ + logs.Info("end reset:", v.name) + } + logs.Info("total success reset:", i, " migration") + time.Sleep(2 * time.Second) + return nil +} + +// Refresh first Reset, then Upgrade +func Refresh() error { + err := Reset() + if err != nil { + logs.Error("execute error:", err) + time.Sleep(2 * time.Second) + return err + } + err = Upgrade(0) + return err +} + +type dataSlice []data + +type data struct { + created int64 + name string + m Migrationer +} + +// Len is part of sort.Interface. +func (d dataSlice) Len() int { + return len(d) +} + +// Swap is part of sort.Interface. +func (d dataSlice) Swap(i, j int) { + d[i], d[j] = d[j], d[i] +} + +// Less is part of sort.Interface. We use count as the value to sort by +func (d dataSlice) Less(i, j int) bool { + return d[i].created < d[j].created +} + +func sortMap(m map[string]Migrationer) dataSlice { + s := make(dataSlice, 0, len(m)) + for k, v := range m { + d := data{} + d.created = v.GetCreated() + d.name = k + d.m = v + s = append(s, d) + } + sort.Sort(s) + return s +} + +func isRollBack(name string) bool { + o := orm.NewOrm() + var maps []orm.Params + num, err := o.Raw("select * from migrations where `name` = ? order by id_migration desc", name).Values(&maps) + if err != nil { + logs.Info("get name has error", err) + return false + } + if num <= 0 { + return false + } + if maps[0]["status"] == "rollback" { + return true + } + return false +} +func getAllMigrations() (map[string]string, error) { + o := orm.NewOrm() + var maps []orm.Params + migs := make(map[string]string) + num, err := o.Raw("select * from migrations order by id_migration desc").Values(&maps) + if err != nil { + logs.Info("get name has error", err) + return migs, err + } + if num > 0 { + for _, v := range maps { + name := v["name"].(string) + migs[name] = v["status"].(string) + } + } + return migs, nil +} diff --git a/pkg/mime.go b/pkg/mime.go new file mode 100644 index 00000000..ca2878ab --- /dev/null +++ b/pkg/mime.go @@ -0,0 +1,556 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +var mimemaps = map[string]string{ + ".3dm": "x-world/x-3dmf", + ".3dmf": "x-world/x-3dmf", + ".7z": "application/x-7z-compressed", + ".a": "application/octet-stream", + ".aab": "application/x-authorware-bin", + ".aam": "application/x-authorware-map", + ".aas": "application/x-authorware-seg", + ".abc": "text/vndabc", + ".ace": "application/x-ace-compressed", + ".acgi": "text/html", + ".afl": "video/animaflex", + ".ai": "application/postscript", + ".aif": "audio/aiff", + ".aifc": "audio/aiff", + ".aiff": "audio/aiff", + ".aim": "application/x-aim", + ".aip": "text/x-audiosoft-intra", + ".alz": "application/x-alz-compressed", + ".ani": "application/x-navi-animation", + ".aos": "application/x-nokia-9000-communicator-add-on-software", + ".aps": "application/mime", + ".apk": "application/vnd.android.package-archive", + ".arc": "application/x-arc-compressed", + ".arj": "application/arj", + ".art": "image/x-jg", + ".asf": "video/x-ms-asf", + ".asm": "text/x-asm", + ".asp": "text/asp", + ".asx": "application/x-mplayer2", + ".au": "audio/basic", + ".avi": "video/x-msvideo", + ".avs": "video/avs-video", + ".bcpio": "application/x-bcpio", + ".bin": "application/mac-binary", + ".bmp": "image/bmp", + ".boo": "application/book", + ".book": "application/book", + ".boz": "application/x-bzip2", + ".bsh": "application/x-bsh", + ".bz2": "application/x-bzip2", + ".bz": "application/x-bzip", + ".c++": "text/plain", + ".c": "text/x-c", + ".cab": "application/vnd.ms-cab-compressed", + ".cat": "application/vndms-pkiseccat", + ".cc": "text/x-c", + ".ccad": "application/clariscad", + ".cco": "application/x-cocoa", + ".cdf": "application/cdf", + ".cer": "application/pkix-cert", + ".cha": "application/x-chat", + ".chat": "application/x-chat", + ".chrt": "application/vnd.kde.kchart", + ".class": "application/java", + ".com": "text/plain", + ".conf": "text/plain", + ".cpio": "application/x-cpio", + ".cpp": "text/x-c", + ".cpt": "application/mac-compactpro", + ".crl": "application/pkcs-crl", + ".crt": "application/pkix-cert", + ".crx": "application/x-chrome-extension", + ".csh": "text/x-scriptcsh", + ".css": "text/css", + ".csv": "text/csv", + ".cxx": "text/plain", + ".dar": "application/x-dar", + ".dcr": "application/x-director", + ".deb": "application/x-debian-package", + ".deepv": "application/x-deepv", + ".def": "text/plain", + ".der": "application/x-x509-ca-cert", + ".dif": "video/x-dv", + ".dir": "application/x-director", + ".divx": "video/divx", + ".dl": "video/dl", + ".dmg": "application/x-apple-diskimage", + ".doc": "application/msword", + ".dot": "application/msword", + ".dp": "application/commonground", + ".drw": "application/drafting", + ".dump": "application/octet-stream", + ".dv": "video/x-dv", + ".dvi": "application/x-dvi", + ".dwf": "drawing/x-dwf=(old)", + ".dwg": "application/acad", + ".dxf": "application/dxf", + ".dxr": "application/x-director", + ".el": "text/x-scriptelisp", + ".elc": "application/x-bytecodeelisp=(compiled=elisp)", + ".eml": "message/rfc822", + ".env": "application/x-envoy", + ".eps": "application/postscript", + ".es": "application/x-esrehber", + ".etx": "text/x-setext", + ".evy": "application/envoy", + ".exe": "application/octet-stream", + ".f77": "text/x-fortran", + ".f90": "text/x-fortran", + ".f": "text/x-fortran", + ".fdf": "application/vndfdf", + ".fif": "application/fractals", + ".fli": "video/fli", + ".flo": "image/florian", + ".flv": "video/x-flv", + ".flx": "text/vndfmiflexstor", + ".fmf": "video/x-atomic3d-feature", + ".for": "text/x-fortran", + ".fpx": "image/vndfpx", + ".frl": "application/freeloader", + ".funk": "audio/make", + ".g3": "image/g3fax", + ".g": "text/plain", + ".gif": "image/gif", + ".gl": "video/gl", + ".gsd": "audio/x-gsm", + ".gsm": "audio/x-gsm", + ".gsp": "application/x-gsp", + ".gss": "application/x-gss", + ".gtar": "application/x-gtar", + ".gz": "application/x-compressed", + ".gzip": "application/x-gzip", + ".h": "text/x-h", + ".hdf": "application/x-hdf", + ".help": "application/x-helpfile", + ".hgl": "application/vndhp-hpgl", + ".hh": "text/x-h", + ".hlb": "text/x-script", + ".hlp": "application/hlp", + ".hpg": "application/vndhp-hpgl", + ".hpgl": "application/vndhp-hpgl", + ".hqx": "application/binhex", + ".hta": "application/hta", + ".htc": "text/x-component", + ".htm": "text/html", + ".html": "text/html", + ".htmls": "text/html", + ".htt": "text/webviewhtml", + ".htx": "text/html", + ".ice": "x-conference/x-cooltalk", + ".ico": "image/x-icon", + ".ics": "text/calendar", + ".icz": "text/calendar", + ".idc": "text/plain", + ".ief": "image/ief", + ".iefs": "image/ief", + ".iges": "application/iges", + ".igs": "application/iges", + ".ima": "application/x-ima", + ".imap": "application/x-httpd-imap", + ".inf": "application/inf", + ".ins": "application/x-internett-signup", + ".ip": "application/x-ip2", + ".isu": "video/x-isvideo", + ".it": "audio/it", + ".iv": "application/x-inventor", + ".ivr": "i-world/i-vrml", + ".ivy": "application/x-livescreen", + ".jam": "audio/x-jam", + ".jav": "text/x-java-source", + ".java": "text/x-java-source", + ".jcm": "application/x-java-commerce", + ".jfif-tbnl": "image/jpeg", + ".jfif": "image/jpeg", + ".jnlp": "application/x-java-jnlp-file", + ".jpe": "image/jpeg", + ".jpeg": "image/jpeg", + ".jpg": "image/jpeg", + ".jps": "image/x-jps", + ".js": "application/javascript", + ".json": "application/json", + ".jut": "image/jutvision", + ".kar": "audio/midi", + ".karbon": "application/vnd.kde.karbon", + ".kfo": "application/vnd.kde.kformula", + ".flw": "application/vnd.kde.kivio", + ".kml": "application/vnd.google-earth.kml+xml", + ".kmz": "application/vnd.google-earth.kmz", + ".kon": "application/vnd.kde.kontour", + ".kpr": "application/vnd.kde.kpresenter", + ".kpt": "application/vnd.kde.kpresenter", + ".ksp": "application/vnd.kde.kspread", + ".kwd": "application/vnd.kde.kword", + ".kwt": "application/vnd.kde.kword", + ".ksh": "text/x-scriptksh", + ".la": "audio/nspaudio", + ".lam": "audio/x-liveaudio", + ".latex": "application/x-latex", + ".lha": "application/lha", + ".lhx": "application/octet-stream", + ".list": "text/plain", + ".lma": "audio/nspaudio", + ".log": "text/plain", + ".lsp": "text/x-scriptlisp", + ".lst": "text/plain", + ".lsx": "text/x-la-asf", + ".ltx": "application/x-latex", + ".lzh": "application/octet-stream", + ".lzx": "application/lzx", + ".m1v": "video/mpeg", + ".m2a": "audio/mpeg", + ".m2v": "video/mpeg", + ".m3u": "audio/x-mpegurl", + ".m": "text/x-m", + ".man": "application/x-troff-man", + ".manifest": "text/cache-manifest", + ".map": "application/x-navimap", + ".mar": "text/plain", + ".mbd": "application/mbedlet", + ".mc$": "application/x-magic-cap-package-10", + ".mcd": "application/mcad", + ".mcf": "text/mcf", + ".mcp": "application/netmc", + ".me": "application/x-troff-me", + ".mht": "message/rfc822", + ".mhtml": "message/rfc822", + ".mid": "application/x-midi", + ".midi": "application/x-midi", + ".mif": "application/x-frame", + ".mime": "message/rfc822", + ".mjf": "audio/x-vndaudioexplosionmjuicemediafile", + ".mjpg": "video/x-motion-jpeg", + ".mm": "application/base64", + ".mme": "application/base64", + ".mod": "audio/mod", + ".moov": "video/quicktime", + ".mov": "video/quicktime", + ".movie": "video/x-sgi-movie", + ".mp2": "audio/mpeg", + ".mp3": "audio/mpeg3", + ".mp4": "video/mp4", + ".mpa": "audio/mpeg", + ".mpc": "application/x-project", + ".mpe": "video/mpeg", + ".mpeg": "video/mpeg", + ".mpg": "video/mpeg", + ".mpga": "audio/mpeg", + ".mpp": "application/vndms-project", + ".mpt": "application/x-project", + ".mpv": "application/x-project", + ".mpx": "application/x-project", + ".mrc": "application/marc", + ".ms": "application/x-troff-ms", + ".mv": "video/x-sgi-movie", + ".my": "audio/make", + ".mzz": "application/x-vndaudioexplosionmzz", + ".nap": "image/naplps", + ".naplps": "image/naplps", + ".nc": "application/x-netcdf", + ".ncm": "application/vndnokiaconfiguration-message", + ".nif": "image/x-niff", + ".niff": "image/x-niff", + ".nix": "application/x-mix-transfer", + ".nsc": "application/x-conference", + ".nvd": "application/x-navidoc", + ".o": "application/octet-stream", + ".oda": "application/oda", + ".odb": "application/vnd.oasis.opendocument.database", + ".odc": "application/vnd.oasis.opendocument.chart", + ".odf": "application/vnd.oasis.opendocument.formula", + ".odg": "application/vnd.oasis.opendocument.graphics", + ".odi": "application/vnd.oasis.opendocument.image", + ".odm": "application/vnd.oasis.opendocument.text-master", + ".odp": "application/vnd.oasis.opendocument.presentation", + ".ods": "application/vnd.oasis.opendocument.spreadsheet", + ".odt": "application/vnd.oasis.opendocument.text", + ".oga": "audio/ogg", + ".ogg": "audio/ogg", + ".ogv": "video/ogg", + ".omc": "application/x-omc", + ".omcd": "application/x-omcdatamaker", + ".omcr": "application/x-omcregerator", + ".otc": "application/vnd.oasis.opendocument.chart-template", + ".otf": "application/vnd.oasis.opendocument.formula-template", + ".otg": "application/vnd.oasis.opendocument.graphics-template", + ".oth": "application/vnd.oasis.opendocument.text-web", + ".oti": "application/vnd.oasis.opendocument.image-template", + ".otm": "application/vnd.oasis.opendocument.text-master", + ".otp": "application/vnd.oasis.opendocument.presentation-template", + ".ots": "application/vnd.oasis.opendocument.spreadsheet-template", + ".ott": "application/vnd.oasis.opendocument.text-template", + ".p10": "application/pkcs10", + ".p12": "application/pkcs-12", + ".p7a": "application/x-pkcs7-signature", + ".p7c": "application/pkcs7-mime", + ".p7m": "application/pkcs7-mime", + ".p7r": "application/x-pkcs7-certreqresp", + ".p7s": "application/pkcs7-signature", + ".p": "text/x-pascal", + ".part": "application/pro_eng", + ".pas": "text/pascal", + ".pbm": "image/x-portable-bitmap", + ".pcl": "application/vndhp-pcl", + ".pct": "image/x-pict", + ".pcx": "image/x-pcx", + ".pdb": "chemical/x-pdb", + ".pdf": "application/pdf", + ".pfunk": "audio/make", + ".pgm": "image/x-portable-graymap", + ".pic": "image/pict", + ".pict": "image/pict", + ".pkg": "application/x-newton-compatible-pkg", + ".pko": "application/vndms-pkipko", + ".pl": "text/x-scriptperl", + ".plx": "application/x-pixclscript", + ".pm4": "application/x-pagemaker", + ".pm5": "application/x-pagemaker", + ".pm": "text/x-scriptperl-module", + ".png": "image/png", + ".pnm": "application/x-portable-anymap", + ".pot": "application/mspowerpoint", + ".pov": "model/x-pov", + ".ppa": "application/vndms-powerpoint", + ".ppm": "image/x-portable-pixmap", + ".pps": "application/mspowerpoint", + ".ppt": "application/mspowerpoint", + ".ppz": "application/mspowerpoint", + ".pre": "application/x-freelance", + ".prt": "application/pro_eng", + ".ps": "application/postscript", + ".psd": "application/octet-stream", + ".pvu": "paleovu/x-pv", + ".pwz": "application/vndms-powerpoint", + ".py": "text/x-scriptphyton", + ".pyc": "application/x-bytecodepython", + ".qcp": "audio/vndqcelp", + ".qd3": "x-world/x-3dmf", + ".qd3d": "x-world/x-3dmf", + ".qif": "image/x-quicktime", + ".qt": "video/quicktime", + ".qtc": "video/x-qtc", + ".qti": "image/x-quicktime", + ".qtif": "image/x-quicktime", + ".ra": "audio/x-pn-realaudio", + ".ram": "audio/x-pn-realaudio", + ".rar": "application/x-rar-compressed", + ".ras": "application/x-cmu-raster", + ".rast": "image/cmu-raster", + ".rexx": "text/x-scriptrexx", + ".rf": "image/vndrn-realflash", + ".rgb": "image/x-rgb", + ".rm": "application/vndrn-realmedia", + ".rmi": "audio/mid", + ".rmm": "audio/x-pn-realaudio", + ".rmp": "audio/x-pn-realaudio", + ".rng": "application/ringing-tones", + ".rnx": "application/vndrn-realplayer", + ".roff": "application/x-troff", + ".rp": "image/vndrn-realpix", + ".rpm": "audio/x-pn-realaudio-plugin", + ".rt": "text/vndrn-realtext", + ".rtf": "text/richtext", + ".rtx": "text/richtext", + ".rv": "video/vndrn-realvideo", + ".s": "text/x-asm", + ".s3m": "audio/s3m", + ".s7z": "application/x-7z-compressed", + ".saveme": "application/octet-stream", + ".sbk": "application/x-tbook", + ".scm": "text/x-scriptscheme", + ".sdml": "text/plain", + ".sdp": "application/sdp", + ".sdr": "application/sounder", + ".sea": "application/sea", + ".set": "application/set", + ".sgm": "text/x-sgml", + ".sgml": "text/x-sgml", + ".sh": "text/x-scriptsh", + ".shar": "application/x-bsh", + ".shtml": "text/x-server-parsed-html", + ".sid": "audio/x-psid", + ".skd": "application/x-koan", + ".skm": "application/x-koan", + ".skp": "application/x-koan", + ".skt": "application/x-koan", + ".sit": "application/x-stuffit", + ".sitx": "application/x-stuffitx", + ".sl": "application/x-seelogo", + ".smi": "application/smil", + ".smil": "application/smil", + ".snd": "audio/basic", + ".sol": "application/solids", + ".spc": "text/x-speech", + ".spl": "application/futuresplash", + ".spr": "application/x-sprite", + ".sprite": "application/x-sprite", + ".spx": "audio/ogg", + ".src": "application/x-wais-source", + ".ssi": "text/x-server-parsed-html", + ".ssm": "application/streamingmedia", + ".sst": "application/vndms-pkicertstore", + ".step": "application/step", + ".stl": "application/sla", + ".stp": "application/step", + ".sv4cpio": "application/x-sv4cpio", + ".sv4crc": "application/x-sv4crc", + ".svf": "image/vnddwg", + ".svg": "image/svg+xml", + ".svr": "application/x-world", + ".swf": "application/x-shockwave-flash", + ".t": "application/x-troff", + ".talk": "text/x-speech", + ".tar": "application/x-tar", + ".tbk": "application/toolbook", + ".tcl": "text/x-scripttcl", + ".tcsh": "text/x-scripttcsh", + ".tex": "application/x-tex", + ".texi": "application/x-texinfo", + ".texinfo": "application/x-texinfo", + ".text": "text/plain", + ".tgz": "application/gnutar", + ".tif": "image/tiff", + ".tiff": "image/tiff", + ".tr": "application/x-troff", + ".tsi": "audio/tsp-audio", + ".tsp": "application/dsptype", + ".tsv": "text/tab-separated-values", + ".turbot": "image/florian", + ".txt": "text/plain", + ".uil": "text/x-uil", + ".uni": "text/uri-list", + ".unis": "text/uri-list", + ".unv": "application/i-deas", + ".uri": "text/uri-list", + ".uris": "text/uri-list", + ".ustar": "application/x-ustar", + ".uu": "text/x-uuencode", + ".uue": "text/x-uuencode", + ".vcd": "application/x-cdlink", + ".vcf": "text/x-vcard", + ".vcard": "text/x-vcard", + ".vcs": "text/x-vcalendar", + ".vda": "application/vda", + ".vdo": "video/vdo", + ".vew": "application/groupwise", + ".viv": "video/vivo", + ".vivo": "video/vivo", + ".vmd": "application/vocaltec-media-desc", + ".vmf": "application/vocaltec-media-file", + ".voc": "audio/voc", + ".vos": "video/vosaic", + ".vox": "audio/voxware", + ".vqe": "audio/x-twinvq-plugin", + ".vqf": "audio/x-twinvq", + ".vql": "audio/x-twinvq-plugin", + ".vrml": "application/x-vrml", + ".vrt": "x-world/x-vrt", + ".vsd": "application/x-visio", + ".vst": "application/x-visio", + ".vsw": "application/x-visio", + ".w60": "application/wordperfect60", + ".w61": "application/wordperfect61", + ".w6w": "application/msword", + ".wav": "audio/wav", + ".wb1": "application/x-qpro", + ".wbmp": "image/vnd.wap.wbmp", + ".web": "application/vndxara", + ".wiz": "application/msword", + ".wk1": "application/x-123", + ".wmf": "windows/metafile", + ".wml": "text/vnd.wap.wml", + ".wmlc": "application/vnd.wap.wmlc", + ".wmls": "text/vnd.wap.wmlscript", + ".wmlsc": "application/vnd.wap.wmlscriptc", + ".word": "application/msword", + ".wp5": "application/wordperfect", + ".wp6": "application/wordperfect", + ".wp": "application/wordperfect", + ".wpd": "application/wordperfect", + ".wq1": "application/x-lotus", + ".wri": "application/mswrite", + ".wrl": "application/x-world", + ".wrz": "model/vrml", + ".wsc": "text/scriplet", + ".wsrc": "application/x-wais-source", + ".wtk": "application/x-wintalk", + ".x-png": "image/png", + ".xbm": "image/x-xbitmap", + ".xdr": "video/x-amt-demorun", + ".xgz": "xgl/drawing", + ".xif": "image/vndxiff", + ".xl": "application/excel", + ".xla": "application/excel", + ".xlb": "application/excel", + ".xlc": "application/excel", + ".xld": "application/excel", + ".xlk": "application/excel", + ".xll": "application/excel", + ".xlm": "application/excel", + ".xls": "application/excel", + ".xlt": "application/excel", + ".xlv": "application/excel", + ".xlw": "application/excel", + ".xm": "audio/xm", + ".xml": "text/xml", + ".xmz": "xgl/movie", + ".xpix": "application/x-vndls-xpix", + ".xpm": "image/x-xpixmap", + ".xsr": "video/x-amt-showrun", + ".xwd": "image/x-xwd", + ".xyz": "chemical/x-pdb", + ".z": "application/x-compress", + ".zip": "application/zip", + ".zoo": "application/octet-stream", + ".zsh": "text/x-scriptzsh", + ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".docm": "application/vnd.ms-word.document.macroEnabled.12", + ".dotx": "application/vnd.openxmlformats-officedocument.wordprocessingml.template", + ".dotm": "application/vnd.ms-word.template.macroEnabled.12", + ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ".xlsm": "application/vnd.ms-excel.sheet.macroEnabled.12", + ".xltx": "application/vnd.openxmlformats-officedocument.spreadsheetml.template", + ".xltm": "application/vnd.ms-excel.template.macroEnabled.12", + ".xlsb": "application/vnd.ms-excel.sheet.binary.macroEnabled.12", + ".xlam": "application/vnd.ms-excel.addin.macroEnabled.12", + ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ".pptm": "application/vnd.ms-powerpoint.presentation.macroEnabled.12", + ".ppsx": "application/vnd.openxmlformats-officedocument.presentationml.slideshow", + ".ppsm": "application/vnd.ms-powerpoint.slideshow.macroEnabled.12", + ".potx": "application/vnd.openxmlformats-officedocument.presentationml.template", + ".potm": "application/vnd.ms-powerpoint.template.macroEnabled.12", + ".ppam": "application/vnd.ms-powerpoint.addin.macroEnabled.12", + ".sldx": "application/vnd.openxmlformats-officedocument.presentationml.slide", + ".sldm": "application/vnd.ms-powerpoint.slide.macroEnabled.12", + ".thmx": "application/vnd.ms-officetheme", + ".onetoc": "application/onenote", + ".onetoc2": "application/onenote", + ".onetmp": "application/onenote", + ".onepkg": "application/onenote", + ".key": "application/x-iwork-keynote-sffkey", + ".kth": "application/x-iwork-keynote-sffkth", + ".nmbtemplate": "application/x-iwork-numbers-sfftemplate", + ".numbers": "application/x-iwork-numbers-sffnumbers", + ".pages": "application/x-iwork-pages-sffpages", + ".template": "application/x-iwork-pages-sfftemplate", + ".xpi": "application/x-xpinstall", + ".oex": "application/x-opera-extension", + ".mustache": "text/html", +} diff --git a/pkg/namespace.go b/pkg/namespace.go new file mode 100644 index 00000000..4952c9d5 --- /dev/null +++ b/pkg/namespace.go @@ -0,0 +1,396 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "strings" + + beecontext "github.com/astaxie/beego/context" +) + +type namespaceCond func(*beecontext.Context) bool + +// LinkNamespace used as link action +type LinkNamespace func(*Namespace) + +// Namespace is store all the info +type Namespace struct { + prefix string + handlers *ControllerRegister +} + +// NewNamespace get new Namespace +func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { + ns := &Namespace{ + prefix: prefix, + handlers: NewControllerRegister(), + } + for _, p := range params { + p(ns) + } + return ns +} + +// Cond set condition function +// if cond return true can run this namespace, else can't +// usage: +// ns.Cond(func (ctx *context.Context) bool{ +// if ctx.Input.Domain() == "api.beego.me" { +// return true +// } +// return false +// }) +// Cond as the first filter +func (n *Namespace) Cond(cond namespaceCond) *Namespace { + fn := func(ctx *beecontext.Context) { + if !cond(ctx) { + exception("405", ctx) + } + } + if v := n.handlers.filters[BeforeRouter]; len(v) > 0 { + mr := new(FilterRouter) + mr.tree = NewTree() + mr.pattern = "*" + mr.filterFunc = fn + mr.tree.AddRouter("*", true) + n.handlers.filters[BeforeRouter] = append([]*FilterRouter{mr}, v...) + } else { + n.handlers.InsertFilter("*", BeforeRouter, fn) + } + return n +} + +// Filter add filter in the Namespace +// action has before & after +// FilterFunc +// usage: +// Filter("before", func (ctx *context.Context){ +// _, ok := ctx.Input.Session("uid").(int) +// if !ok && ctx.Request.RequestURI != "/login" { +// ctx.Redirect(302, "/login") +// } +// }) +func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { + var a int + if action == "before" { + a = BeforeRouter + } else if action == "after" { + a = FinishRouter + } + for _, f := range filter { + n.handlers.InsertFilter("*", a, f) + } + return n +} + +// Router same as beego.Rourer +// refer: https://godoc.org/github.com/astaxie/beego#Router +func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { + n.handlers.Add(rootpath, c, mappingMethods...) + return n +} + +// AutoRouter same as beego.AutoRouter +// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter +func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { + n.handlers.AddAuto(c) + return n +} + +// AutoPrefix same as beego.AutoPrefix +// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix +func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { + n.handlers.AddAutoPrefix(prefix, c) + return n +} + +// Get same as beego.Get +// refer: https://godoc.org/github.com/astaxie/beego#Get +func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { + n.handlers.Get(rootpath, f) + return n +} + +// Post same as beego.Post +// refer: https://godoc.org/github.com/astaxie/beego#Post +func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { + n.handlers.Post(rootpath, f) + return n +} + +// Delete same as beego.Delete +// refer: https://godoc.org/github.com/astaxie/beego#Delete +func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { + n.handlers.Delete(rootpath, f) + return n +} + +// Put same as beego.Put +// refer: https://godoc.org/github.com/astaxie/beego#Put +func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { + n.handlers.Put(rootpath, f) + return n +} + +// Head same as beego.Head +// refer: https://godoc.org/github.com/astaxie/beego#Head +func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { + n.handlers.Head(rootpath, f) + return n +} + +// Options same as beego.Options +// refer: https://godoc.org/github.com/astaxie/beego#Options +func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { + n.handlers.Options(rootpath, f) + return n +} + +// Patch same as beego.Patch +// refer: https://godoc.org/github.com/astaxie/beego#Patch +func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { + n.handlers.Patch(rootpath, f) + return n +} + +// Any same as beego.Any +// refer: https://godoc.org/github.com/astaxie/beego#Any +func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { + n.handlers.Any(rootpath, f) + return n +} + +// Handler same as beego.Handler +// refer: https://godoc.org/github.com/astaxie/beego#Handler +func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { + n.handlers.Handler(rootpath, h) + return n +} + +// Include add include class +// refer: https://godoc.org/github.com/astaxie/beego#Include +func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { + n.handlers.Include(cList...) + return n +} + +// Namespace add nest Namespace +// usage: +//ns := beego.NewNamespace(“/v1”). +//Namespace( +// beego.NewNamespace("/shop"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("shopinfo")) +// }), +// beego.NewNamespace("/order"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("orderinfo")) +// }), +// beego.NewNamespace("/crm"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("crminfo")) +// }), +//) +func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { + for _, ni := range ns { + for k, v := range ni.handlers.routers { + if _, ok := n.handlers.routers[k]; ok { + addPrefix(v, ni.prefix) + n.handlers.routers[k].AddTree(ni.prefix, v) + } else { + t := NewTree() + t.AddTree(ni.prefix, v) + addPrefix(t, ni.prefix) + n.handlers.routers[k] = t + } + } + if ni.handlers.enableFilter { + for pos, filterList := range ni.handlers.filters { + for _, mr := range filterList { + t := NewTree() + t.AddTree(ni.prefix, mr.tree) + mr.tree = t + n.handlers.insertFilterRouter(pos, mr) + } + } + } + } + return n +} + +// AddNamespace register Namespace into beego.Handler +// support multi Namespace +func AddNamespace(nl ...*Namespace) { + for _, n := range nl { + for k, v := range n.handlers.routers { + if _, ok := BeeApp.Handlers.routers[k]; ok { + addPrefix(v, n.prefix) + BeeApp.Handlers.routers[k].AddTree(n.prefix, v) + } else { + t := NewTree() + t.AddTree(n.prefix, v) + addPrefix(t, n.prefix) + BeeApp.Handlers.routers[k] = t + } + } + if n.handlers.enableFilter { + for pos, filterList := range n.handlers.filters { + for _, mr := range filterList { + t := NewTree() + t.AddTree(n.prefix, mr.tree) + mr.tree = t + BeeApp.Handlers.insertFilterRouter(pos, mr) + } + } + } + } +} + +func addPrefix(t *Tree, prefix string) { + for _, v := range t.fixrouters { + addPrefix(v, prefix) + } + if t.wildcard != nil { + addPrefix(t.wildcard, prefix) + } + for _, l := range t.leaves { + if c, ok := l.runObject.(*ControllerInfo); ok { + if !strings.HasPrefix(c.pattern, prefix) { + c.pattern = prefix + c.pattern + } + } + } +} + +// NSCond is Namespace Condition +func NSCond(cond namespaceCond) LinkNamespace { + return func(ns *Namespace) { + ns.Cond(cond) + } +} + +// NSBefore Namespace BeforeRouter filter +func NSBefore(filterList ...FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Filter("before", filterList...) + } +} + +// NSAfter add Namespace FinishRouter filter +func NSAfter(filterList ...FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Filter("after", filterList...) + } +} + +// NSInclude Namespace Include ControllerInterface +func NSInclude(cList ...ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + ns.Include(cList...) + } +} + +// NSRouter call Namespace Router +func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { + return func(ns *Namespace) { + ns.Router(rootpath, c, mappingMethods...) + } +} + +// NSGet call Namespace Get +func NSGet(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Get(rootpath, f) + } +} + +// NSPost call Namespace Post +func NSPost(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Post(rootpath, f) + } +} + +// NSHead call Namespace Head +func NSHead(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Head(rootpath, f) + } +} + +// NSPut call Namespace Put +func NSPut(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Put(rootpath, f) + } +} + +// NSDelete call Namespace Delete +func NSDelete(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Delete(rootpath, f) + } +} + +// NSAny call Namespace Any +func NSAny(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Any(rootpath, f) + } +} + +// NSOptions call Namespace Options +func NSOptions(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Options(rootpath, f) + } +} + +// NSPatch call Namespace Patch +func NSPatch(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Patch(rootpath, f) + } +} + +// NSAutoRouter call Namespace AutoRouter +func NSAutoRouter(c ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + ns.AutoRouter(c) + } +} + +// NSAutoPrefix call Namespace AutoPrefix +func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + ns.AutoPrefix(prefix, c) + } +} + +// NSNamespace add sub Namespace +func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { + return func(ns *Namespace) { + n := NewNamespace(prefix, params...) + ns.Namespace(n) + } +} + +// NSHandler add handler +func NSHandler(rootpath string, h http.Handler) LinkNamespace { + return func(ns *Namespace) { + ns.Handler(rootpath, h) + } +} diff --git a/pkg/namespace_test.go b/pkg/namespace_test.go new file mode 100644 index 00000000..b3f20dff --- /dev/null +++ b/pkg/namespace_test.go @@ -0,0 +1,168 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/astaxie/beego/context" +) + +func TestNamespaceGet(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/user", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Get("/user", func(ctx *context.Context) { + ctx.Output.Body([]byte("v1_user")) + }) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "v1_user" { + t.Errorf("TestNamespaceGet can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespacePost(t *testing.T) { + r, _ := http.NewRequest("POST", "/v1/user/123", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Post("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestNamespacePost can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNest(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/admin/order", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Namespace( + NewNamespace("/admin"). + Get("/order", func(ctx *context.Context) { + ctx.Output.Body([]byte("order")) + }), + ) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "order" { + t.Errorf("TestNamespaceNest can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNestParam(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/admin/order/123", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Namespace( + NewNamespace("/admin"). + Get("/order/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }), + ) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestNamespaceNestParam can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouter(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/api/list", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Router("/api/list", &TestController{}, "*:List") + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("TestNamespaceRouter can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceAutoFunc(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/test/list", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.AutoRouter(&TestController{}) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("user define func can't run") + } +} + +func TestNamespaceFilter(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/user/123", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Filter("before", func(ctx *context.Context) { + ctx.Output.Body([]byte("this is Filter")) + }). + Get("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "this is Filter" { + t.Errorf("TestNamespaceFilter can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCond(t *testing.T) { + r, _ := http.NewRequest("GET", "/v2/test/list", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v2") + ns.Cond(func(ctx *context.Context) bool { + return ctx.Input.Domain() == "beego.me" + }). + AutoRouter(&TestController{}) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Code != 405 { + t.Errorf("TestNamespaceCond can't run get the result " + strconv.Itoa(w.Code)) + } +} + +func TestNamespaceInside(t *testing.T) { + r, _ := http.NewRequest("GET", "/v3/shop/order/123", nil) + w := httptest.NewRecorder() + ns := NewNamespace("/v3", + NSAutoRouter(&TestController{}), + NSNamespace("/shop", + NSGet("/order/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }), + ), + ) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String()) + } +} diff --git a/pkg/parser.go b/pkg/parser.go new file mode 100644 index 00000000..3a311894 --- /dev/null +++ b/pkg/parser.go @@ -0,0 +1,591 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "encoding/json" + "errors" + "fmt" + "go/ast" + "go/parser" + "go/token" + "io/ioutil" + "os" + "path/filepath" + "regexp" + "sort" + "strconv" + "strings" + "unicode" + + "github.com/astaxie/beego/context/param" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/utils" +) + +var globalRouterTemplate = `package {{.routersDir}} + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/context/param"{{.globalimport}} +) + +func init() { +{{.globalinfo}} +} +` + +var ( + lastupdateFilename = "lastupdate.tmp" + commentFilename string + pkgLastupdate map[string]int64 + genInfoList map[string][]ControllerComments + + routerHooks = map[string]int{ + "beego.BeforeStatic": BeforeStatic, + "beego.BeforeRouter": BeforeRouter, + "beego.BeforeExec": BeforeExec, + "beego.AfterExec": AfterExec, + "beego.FinishRouter": FinishRouter, + } + + routerHooksMapping = map[int]string{ + BeforeStatic: "beego.BeforeStatic", + BeforeRouter: "beego.BeforeRouter", + BeforeExec: "beego.BeforeExec", + AfterExec: "beego.AfterExec", + FinishRouter: "beego.FinishRouter", + } +) + +const commentPrefix = "commentsRouter_" + +func init() { + pkgLastupdate = make(map[string]int64) +} + +func parserPkg(pkgRealpath, pkgpath string) error { + rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_") + commentFilename, _ = filepath.Rel(AppPath, pkgRealpath) + commentFilename = commentPrefix + rep.Replace(commentFilename) + ".go" + if !compareFile(pkgRealpath) { + logs.Info(pkgRealpath + " no changed") + return nil + } + genInfoList = make(map[string][]ControllerComments) + fileSet := token.NewFileSet() + astPkgs, err := parser.ParseDir(fileSet, pkgRealpath, func(info os.FileInfo) bool { + name := info.Name() + return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") + }, parser.ParseComments) + + if err != nil { + return err + } + for _, pkg := range astPkgs { + for _, fl := range pkg.Files { + for _, d := range fl.Decls { + switch specDecl := d.(type) { + case *ast.FuncDecl: + if specDecl.Recv != nil { + exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser + if ok { + parserComments(specDecl, fmt.Sprint(exp.X), pkgpath) + } + } + } + } + } + } + genRouterCode(pkgRealpath) + savetoFile(pkgRealpath) + return nil +} + +type parsedComment struct { + routerPath string + methods []string + params map[string]parsedParam + filters []parsedFilter + imports []parsedImport +} + +type parsedImport struct { + importPath string + importAlias string +} + +type parsedFilter struct { + pattern string + pos int + filter string + params []bool +} + +type parsedParam struct { + name string + datatype string + location string + defValue string + required bool +} + +func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { + if f.Doc != nil { + parsedComments, err := parseComment(f.Doc.List) + if err != nil { + return err + } + for _, parsedComment := range parsedComments { + if parsedComment.routerPath != "" { + key := pkgpath + ":" + controllerName + cc := ControllerComments{} + cc.Method = f.Name.String() + cc.Router = parsedComment.routerPath + cc.AllowHTTPMethods = parsedComment.methods + cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) + cc.FilterComments = buildFilters(parsedComment.filters) + cc.ImportComments = buildImports(parsedComment.imports) + genInfoList[key] = append(genInfoList[key], cc) + } + } + } + return nil +} + +func buildImports(pis []parsedImport) []*ControllerImportComments { + var importComments []*ControllerImportComments + + for _, pi := range pis { + importComments = append(importComments, &ControllerImportComments{ + ImportPath: pi.importPath, + ImportAlias: pi.importAlias, + }) + } + + return importComments +} + +func buildFilters(pfs []parsedFilter) []*ControllerFilterComments { + var filterComments []*ControllerFilterComments + + for _, pf := range pfs { + var ( + returnOnOutput bool + resetParams bool + ) + + if len(pf.params) >= 1 { + returnOnOutput = pf.params[0] + } + + if len(pf.params) >= 2 { + resetParams = pf.params[1] + } + + filterComments = append(filterComments, &ControllerFilterComments{ + Filter: pf.filter, + Pattern: pf.pattern, + Pos: pf.pos, + ReturnOnOutput: returnOnOutput, + ResetParams: resetParams, + }) + } + + return filterComments +} + +func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam { + result := make([]*param.MethodParam, 0, len(funcParams)) + for _, fparam := range funcParams { + for _, pName := range fparam.Names { + methodParam := buildMethodParam(fparam, pName.Name, pc) + result = append(result, methodParam) + } + } + return result +} + +func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam { + options := []param.MethodParamOption{} + if cparam, ok := pc.params[name]; ok { + //Build param from comment info + name = cparam.name + if cparam.required { + options = append(options, param.IsRequired) + } + switch cparam.location { + case "body": + options = append(options, param.InBody) + case "header": + options = append(options, param.InHeader) + case "path": + options = append(options, param.InPath) + } + if cparam.defValue != "" { + options = append(options, param.Default(cparam.defValue)) + } + } else { + if paramInPath(name, pc.routerPath) { + options = append(options, param.InPath) + } + } + return param.New(name, options...) +} + +func paramInPath(name, route string) bool { + return strings.HasSuffix(route, ":"+name) || + strings.Contains(route, ":"+name+"/") +} + +var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) + +func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) { + pcs = []*parsedComment{} + params := map[string]parsedParam{} + filters := []parsedFilter{} + imports := []parsedImport{} + + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@Param") { + pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param"))) + if len(pv) < 4 { + logs.Error("Invalid @Param format. Needs at least 4 parameters") + } + p := parsedParam{} + names := strings.SplitN(pv[0], "=>", 2) + p.name = names[0] + funcParamName := p.name + if len(names) > 1 { + funcParamName = names[1] + } + p.location = pv[1] + p.datatype = pv[2] + switch len(pv) { + case 5: + p.required, _ = strconv.ParseBool(pv[3]) + case 6: + p.defValue = pv[3] + p.required, _ = strconv.ParseBool(pv[4]) + } + params[funcParamName] = p + } + } + + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@Import") { + iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import"))) + if len(iv) == 0 || len(iv) > 2 { + logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters") + continue + } + + p := parsedImport{} + p.importPath = iv[0] + + if len(iv) == 2 { + p.importAlias = iv[1] + } + + imports = append(imports, p) + } + } + +filterLoop: + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@Filter") { + fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter"))) + if len(fv) < 3 { + logs.Error("Invalid @Filter format. Needs at least 3 parameters") + continue filterLoop + } + + p := parsedFilter{} + p.pattern = fv[0] + posName := fv[1] + if pos, exists := routerHooks[posName]; exists { + p.pos = pos + } else { + logs.Error("Invalid @Filter pos: ", posName) + continue filterLoop + } + + p.filter = fv[2] + fvParams := fv[3:] + for _, fvParam := range fvParams { + switch fvParam { + case "true": + p.params = append(p.params, true) + case "false": + p.params = append(p.params, false) + default: + logs.Error("Invalid @Filter param: ", fvParam) + continue filterLoop + } + } + + filters = append(filters, p) + } + } + + for _, c := range lines { + var pc = &parsedComment{} + pc.params = params + pc.filters = filters + pc.imports = imports + + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@router") { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + matches := routeRegex.FindStringSubmatch(t) + if len(matches) == 3 { + pc.routerPath = matches[1] + methods := matches[2] + if methods == "" { + pc.methods = []string{"get"} + //pc.hasGet = true + } else { + pc.methods = strings.Split(methods, ",") + //pc.hasGet = strings.Contains(methods, "get") + } + pcs = append(pcs, pc) + } else { + return nil, errors.New("Router information is missing") + } + } + } + return +} + +// direct copy from bee\g_docs.go +// analysis params return []string +// @Param query form string true "The email for login" +// [query form string true "The email for login"] +func getparams(str string) []string { + var s []rune + var j int + var start bool + var r []string + var quoted int8 + for _, c := range str { + if unicode.IsSpace(c) && quoted == 0 { + if !start { + continue + } else { + start = false + j++ + r = append(r, string(s)) + s = make([]rune, 0) + continue + } + } + + start = true + if c == '"' { + quoted ^= 1 + continue + } + s = append(s, c) + } + if len(s) > 0 { + r = append(r, string(s)) + } + return r +} + +func genRouterCode(pkgRealpath string) { + os.Mkdir(getRouterDir(pkgRealpath), 0755) + logs.Info("generate router from comments") + var ( + globalinfo string + globalimport string + sortKey []string + ) + for k := range genInfoList { + sortKey = append(sortKey, k) + } + sort.Strings(sortKey) + for _, k := range sortKey { + cList := genInfoList[k] + sort.Sort(ControllerCommentsSlice(cList)) + for _, c := range cList { + allmethod := "nil" + if len(c.AllowHTTPMethods) > 0 { + allmethod = "[]string{" + for _, m := range c.AllowHTTPMethods { + allmethod += `"` + m + `",` + } + allmethod = strings.TrimRight(allmethod, ",") + "}" + } + + params := "nil" + if len(c.Params) > 0 { + params = "[]map[string]string{" + for _, p := range c.Params { + for k, v := range p { + params = params + `map[string]string{` + k + `:"` + v + `"},` + } + } + params = strings.TrimRight(params, ",") + "}" + } + + methodParams := "param.Make(" + if len(c.MethodParams) > 0 { + lines := make([]string, 0, len(c.MethodParams)) + for _, m := range c.MethodParams { + lines = append(lines, fmt.Sprint(m)) + } + methodParams += "\n " + + strings.Join(lines, ",\n ") + + ",\n " + } + methodParams += ")" + + imports := "" + if len(c.ImportComments) > 0 { + for _, i := range c.ImportComments { + var s string + if i.ImportAlias != "" { + s = fmt.Sprintf(` + %s "%s"`, i.ImportAlias, i.ImportPath) + } else { + s = fmt.Sprintf(` + "%s"`, i.ImportPath) + } + if !strings.Contains(globalimport, s) { + imports += s + } + } + } + + filters := "" + if len(c.FilterComments) > 0 { + for _, f := range c.FilterComments { + filters += fmt.Sprintf(` &beego.ControllerFilter{ + Pattern: "%s", + Pos: %s, + Filter: %s, + ReturnOnOutput: %v, + ResetParams: %v, + },`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams) + } + } + + if filters == "" { + filters = "nil" + } else { + filters = fmt.Sprintf(`[]*beego.ControllerFilter{ +%s + }`, filters) + } + + globalimport += imports + + globalinfo = globalinfo + ` + beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], + beego.ControllerComments{ + Method: "` + strings.TrimSpace(c.Method) + `", + ` + `Router: "` + c.Router + `"` + `, + AllowHTTPMethods: ` + allmethod + `, + MethodParams: ` + methodParams + `, + Filters: ` + filters + `, + Params: ` + params + `}) +` + } + } + + if globalinfo != "" { + f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) + if err != nil { + panic(err) + } + defer f.Close() + + routersDir := AppConfig.DefaultString("routersdir", "routers") + content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) + content = strings.Replace(content, "{{.routersDir}}", routersDir, -1) + content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) + f.WriteString(content) + } +} + +func compareFile(pkgRealpath string) bool { + if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) { + return true + } + if utils.FileExists(lastupdateFilename) { + content, err := ioutil.ReadFile(lastupdateFilename) + if err != nil { + return true + } + json.Unmarshal(content, &pkgLastupdate) + lastupdate, err := getpathTime(pkgRealpath) + if err != nil { + return true + } + if v, ok := pkgLastupdate[pkgRealpath]; ok { + if lastupdate <= v { + return false + } + } + } + return true +} + +func savetoFile(pkgRealpath string) { + lastupdate, err := getpathTime(pkgRealpath) + if err != nil { + return + } + pkgLastupdate[pkgRealpath] = lastupdate + d, err := json.Marshal(pkgLastupdate) + if err != nil { + return + } + ioutil.WriteFile(lastupdateFilename, d, os.ModePerm) +} + +func getpathTime(pkgRealpath string) (lastupdate int64, err error) { + fl, err := ioutil.ReadDir(pkgRealpath) + if err != nil { + return lastupdate, err + } + for _, f := range fl { + if lastupdate < f.ModTime().UnixNano() { + lastupdate = f.ModTime().UnixNano() + } + } + return lastupdate, nil +} + +func getRouterDir(pkgRealpath string) string { + dir := filepath.Dir(pkgRealpath) + for { + routersDir := AppConfig.DefaultString("routersdir", "routers") + d := filepath.Join(dir, routersDir) + if utils.FileExists(d) { + return d + } + + if r, _ := filepath.Rel(dir, AppPath); r == "." { + return d + } + // Parent dir. + dir = filepath.Dir(dir) + } +} diff --git a/pkg/plugins/apiauth/apiauth.go b/pkg/plugins/apiauth/apiauth.go new file mode 100644 index 00000000..10e25f3f --- /dev/null +++ b/pkg/plugins/apiauth/apiauth.go @@ -0,0 +1,165 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package apiauth provides handlers to enable apiauth support. +// +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/apiauth" +// ) +// +// func main(){ +// // apiauth every request +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBaiscAuth("appid","appkey")) +// beego.Run() +// } +// +// Advanced Usage: +// +// func getAppSecret(appid string) string { +// // get appsecret by appid +// // maybe store in configure, maybe in database +// } +// +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360)) +// +// Information: +// +// In the request user should include these params in the query +// +// 1. appid +// +// appid is assigned to the application +// +// 2. signature +// +// get the signature use apiauth.Signature() +// +// when you send to server remember use url.QueryEscape() +// +// 3. timestamp: +// +// send the request time, the format is yyyy-mm-dd HH:ii:ss +// +package apiauth + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/url" + "sort" + "time" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" +) + +// AppIDToAppSecret is used to get appsecret throw appid +type AppIDToAppSecret func(string) string + +// APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret +func APIBasicAuth(appid, appkey string) beego.FilterFunc { + ft := func(aid string) string { + if aid == appid { + return appkey + } + return "" + } + return APISecretAuth(ft, 300) +} + +// APIBaiscAuth calls APIBasicAuth for previous callers +func APIBaiscAuth(appid, appkey string) beego.FilterFunc { + return APIBasicAuth(appid, appkey) +} + +// APISecretAuth use AppIdToAppSecret verify and +func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { + return func(ctx *context.Context) { + if ctx.Input.Query("appid") == "" { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("miss query param: appid") + return + } + appsecret := f(ctx.Input.Query("appid")) + if appsecret == "" { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("not exist this appid") + return + } + if ctx.Input.Query("signature") == "" { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("miss query param: signature") + return + } + if ctx.Input.Query("timestamp") == "" { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("miss query param: timestamp") + return + } + u, err := time.Parse("2006-01-02 15:04:05", ctx.Input.Query("timestamp")) + if err != nil { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("timestamp format is error, should 2006-01-02 15:04:05") + return + } + t := time.Now() + if t.Sub(u).Seconds() > float64(timeout) { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("timeout! the request time is long ago, please try again") + return + } + if ctx.Input.Query("signature") != + Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.URL()) { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("auth failed") + } + } +} + +// Signature used to generate signature with the appsecret/method/params/RequestURI +func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) { + var b bytes.Buffer + keys := make([]string, len(params)) + pa := make(map[string]string) + for k, v := range params { + pa[k] = v[0] + keys = append(keys, k) + } + + sort.Strings(keys) + + for _, key := range keys { + if key == "signature" { + continue + } + + val := pa[key] + if key != "" && val != "" { + b.WriteString(key) + b.WriteString(val) + } + } + + stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, b.String(), RequestURL) + + sha256 := sha256.New + hash := hmac.New(sha256, []byte(appsecret)) + hash.Write([]byte(stringToSign)) + return base64.StdEncoding.EncodeToString(hash.Sum(nil)) +} diff --git a/pkg/plugins/apiauth/apiauth_test.go b/pkg/plugins/apiauth/apiauth_test.go new file mode 100644 index 00000000..1f56cb0f --- /dev/null +++ b/pkg/plugins/apiauth/apiauth_test.go @@ -0,0 +1,20 @@ +package apiauth + +import ( + "net/url" + "testing" +) + +func TestSignature(t *testing.T) { + appsecret := "beego secret" + method := "GET" + RequestURL := "http://localhost/test/url" + params := make(url.Values) + params.Add("arg1", "hello") + params.Add("arg2", "beego") + + signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58=" + if Signature(appsecret, method, params, RequestURL) != signature { + t.Error("Signature error") + } +} diff --git a/pkg/plugins/auth/basic.go b/pkg/plugins/auth/basic.go new file mode 100644 index 00000000..c478044a --- /dev/null +++ b/pkg/plugins/auth/basic.go @@ -0,0 +1,107 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package auth provides handlers to enable basic auth support. +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/auth" +// ) +// +// func main(){ +// // authenticate every request +// beego.InsertFilter("*", beego.BeforeRouter,auth.Basic("username","secretpassword")) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func SecretAuth(username, password string) bool { +// return username == "astaxie" && password == "helloBeego" +// } +// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "Authorization Required") +// beego.InsertFilter("*", beego.BeforeRouter,authPlugin) +package auth + +import ( + "encoding/base64" + "net/http" + "strings" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" +) + +var defaultRealm = "Authorization Required" + +// Basic is the http basic auth +func Basic(username string, password string) beego.FilterFunc { + secrets := func(user, pass string) bool { + return user == username && pass == password + } + return NewBasicAuthenticator(secrets, defaultRealm) +} + +// NewBasicAuthenticator return the BasicAuth +func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc { + return func(ctx *context.Context) { + a := &BasicAuth{Secrets: secrets, Realm: Realm} + if username := a.CheckAuth(ctx.Request); username == "" { + a.RequireAuth(ctx.ResponseWriter, ctx.Request) + } + } +} + +// SecretProvider is the SecretProvider function +type SecretProvider func(user, pass string) bool + +// BasicAuth store the SecretProvider and Realm +type BasicAuth struct { + Secrets SecretProvider + Realm string +} + +// CheckAuth Checks the username/password combination from the request. Returns +// either an empty string (authentication failed) or the name of the +// authenticated user. +// Supports MD5 and SHA1 password entries +func (a *BasicAuth) CheckAuth(r *http.Request) string { + s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) + if len(s) != 2 || s[0] != "Basic" { + return "" + } + + b, err := base64.StdEncoding.DecodeString(s[1]) + if err != nil { + return "" + } + pair := strings.SplitN(string(b), ":", 2) + if len(pair) != 2 { + return "" + } + + if a.Secrets(pair[0], pair[1]) { + return pair[0] + } + return "" +} + +// RequireAuth http.Handler for BasicAuth which initiates the authentication process +// (or requires reauthentication). +func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) { + w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`) + w.WriteHeader(401) + w.Write([]byte("401 Unauthorized\n")) +} diff --git a/pkg/plugins/authz/authz.go b/pkg/plugins/authz/authz.go new file mode 100644 index 00000000..9dc0db76 --- /dev/null +++ b/pkg/plugins/authz/authz.go @@ -0,0 +1,86 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support. +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/authz" +// "github.com/casbin/casbin" +// ) +// +// func main(){ +// // mediate the access for every request +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func main(){ +// e := casbin.NewEnforcer("authz_model.conf", "") +// e.AddRoleForUser("alice", "admin") +// e.AddPolicy(...) +// +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e)) +// beego.Run() +// } +package authz + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" + "github.com/casbin/casbin" + "net/http" +) + +// NewAuthorizer returns the authorizer. +// Use a casbin enforcer as input +func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc { + return func(ctx *context.Context) { + a := &BasicAuthorizer{enforcer: e} + + if !a.CheckPermission(ctx.Request) { + a.RequirePermission(ctx.ResponseWriter) + } + } +} + +// BasicAuthorizer stores the casbin handler +type BasicAuthorizer struct { + enforcer *casbin.Enforcer +} + +// GetUserName gets the user name from the request. +// Currently, only HTTP basic authentication is supported +func (a *BasicAuthorizer) GetUserName(r *http.Request) string { + username, _, _ := r.BasicAuth() + return username +} + +// CheckPermission checks the user/method/path combination from the request. +// Returns true (permission granted) or false (permission forbidden) +func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool { + user := a.GetUserName(r) + method := r.Method + path := r.URL.Path + return a.enforcer.Enforce(user, path, method) +} + +// RequirePermission returns the 403 Forbidden to the client +func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) { + w.WriteHeader(403) + w.Write([]byte("403 Forbidden\n")) +} diff --git a/pkg/plugins/authz/authz_model.conf b/pkg/plugins/authz/authz_model.conf new file mode 100644 index 00000000..d1b3dbd7 --- /dev/null +++ b/pkg/plugins/authz/authz_model.conf @@ -0,0 +1,14 @@ +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*") \ No newline at end of file diff --git a/pkg/plugins/authz/authz_policy.csv b/pkg/plugins/authz/authz_policy.csv new file mode 100644 index 00000000..c062dd3e --- /dev/null +++ b/pkg/plugins/authz/authz_policy.csv @@ -0,0 +1,7 @@ +p, alice, /dataset1/*, GET +p, alice, /dataset1/resource1, POST +p, bob, /dataset2/resource1, * +p, bob, /dataset2/resource2, GET +p, bob, /dataset2/folder1/*, POST +p, dataset1_admin, /dataset1/*, * +g, cathy, dataset1_admin \ No newline at end of file diff --git a/pkg/plugins/authz/authz_test.go b/pkg/plugins/authz/authz_test.go new file mode 100644 index 00000000..49aed84c --- /dev/null +++ b/pkg/plugins/authz/authz_test.go @@ -0,0 +1,107 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authz + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/plugins/auth" + "github.com/casbin/casbin" + "net/http" + "net/http/httptest" + "testing" +) + +func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + r.SetBasicAuth(user, "123") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code) + } +} + +func TestBasic(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403) +} + +func TestPathWildcard(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403) + testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403) + + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403) +} + +func TestRBAC(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) + e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role. + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) + + // delete all roles on user cathy, so cathy cannot access any resources now. + e.DeleteRolesForUser("cathy") + + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) +} diff --git a/pkg/plugins/cors/cors.go b/pkg/plugins/cors/cors.go new file mode 100644 index 00000000..45c327ab --- /dev/null +++ b/pkg/plugins/cors/cors.go @@ -0,0 +1,228 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cors provides handlers to enable CORS support. +// Usage +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/cors" +// ) +// +// func main() { +// // CORS for https://foo.* origins, allowing: +// // - PUT and PATCH methods +// // - Origin header +// // - Credentials share +// beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{ +// AllowOrigins: []string{"https://*.foo.com"}, +// AllowMethods: []string{"PUT", "PATCH"}, +// AllowHeaders: []string{"Origin"}, +// ExposeHeaders: []string{"Content-Length"}, +// AllowCredentials: true, +// })) +// beego.Run() +// } +package cors + +import ( + "net/http" + "regexp" + "strconv" + "strings" + "time" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" +) + +const ( + headerAllowOrigin = "Access-Control-Allow-Origin" + headerAllowCredentials = "Access-Control-Allow-Credentials" + headerAllowHeaders = "Access-Control-Allow-Headers" + headerAllowMethods = "Access-Control-Allow-Methods" + headerExposeHeaders = "Access-Control-Expose-Headers" + headerMaxAge = "Access-Control-Max-Age" + + headerOrigin = "Origin" + headerRequestMethod = "Access-Control-Request-Method" + headerRequestHeaders = "Access-Control-Request-Headers" +) + +var ( + defaultAllowHeaders = []string{"Origin", "Accept", "Content-Type", "Authorization"} + // Regex patterns are generated from AllowOrigins. These are used and generated internally. + allowOriginPatterns = []string{} +) + +// Options represents Access Control options. +type Options struct { + // If set, all origins are allowed. + AllowAllOrigins bool + // A list of allowed origins. Wild cards and FQDNs are supported. + AllowOrigins []string + // If set, allows to share auth credentials such as cookies. + AllowCredentials bool + // A list of allowed HTTP methods. + AllowMethods []string + // A list of allowed HTTP headers. + AllowHeaders []string + // A list of exposed HTTP headers. + ExposeHeaders []string + // Max age of the CORS headers. + MaxAge time.Duration +} + +// Header converts options into CORS headers. +func (o *Options) Header(origin string) (headers map[string]string) { + headers = make(map[string]string) + // if origin is not allowed, don't extend the headers + // with CORS headers. + if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) { + return + } + + // add allow origin + if o.AllowAllOrigins { + headers[headerAllowOrigin] = "*" + } else { + headers[headerAllowOrigin] = origin + } + + // add allow credentials + headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials) + + // add allow methods + if len(o.AllowMethods) > 0 { + headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",") + } + + // add allow headers + if len(o.AllowHeaders) > 0 { + headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",") + } + + // add exposed header + if len(o.ExposeHeaders) > 0 { + headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",") + } + // add a max age header + if o.MaxAge > time.Duration(0) { + headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) + } + return +} + +// PreflightHeader converts options into CORS headers for a preflight response. +func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) { + headers = make(map[string]string) + if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) { + return + } + // verify if requested method is allowed + for _, method := range o.AllowMethods { + if method == rMethod { + headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",") + break + } + } + + // verify if requested headers are allowed + var allowed []string + for _, rHeader := range strings.Split(rHeaders, ",") { + rHeader = strings.TrimSpace(rHeader) + lookupLoop: + for _, allowedHeader := range o.AllowHeaders { + if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) { + allowed = append(allowed, rHeader) + break lookupLoop + } + } + } + + headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials) + // add allow origin + if o.AllowAllOrigins { + headers[headerAllowOrigin] = "*" + } else { + headers[headerAllowOrigin] = origin + } + + // add allowed headers + if len(allowed) > 0 { + headers[headerAllowHeaders] = strings.Join(allowed, ",") + } + + // add exposed headers + if len(o.ExposeHeaders) > 0 { + headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",") + } + // add a max age header + if o.MaxAge > time.Duration(0) { + headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) + } + return +} + +// IsOriginAllowed looks up if the origin matches one of the patterns +// generated from Options.AllowOrigins patterns. +func (o *Options) IsOriginAllowed(origin string) (allowed bool) { + for _, pattern := range allowOriginPatterns { + allowed, _ = regexp.MatchString(pattern, origin) + if allowed { + return + } + } + return +} + +// Allow enables CORS for requests those match the provided options. +func Allow(opts *Options) beego.FilterFunc { + // Allow default headers if nothing is specified. + if len(opts.AllowHeaders) == 0 { + opts.AllowHeaders = defaultAllowHeaders + } + + for _, origin := range opts.AllowOrigins { + pattern := regexp.QuoteMeta(origin) + pattern = strings.Replace(pattern, "\\*", ".*", -1) + pattern = strings.Replace(pattern, "\\?", ".", -1) + allowOriginPatterns = append(allowOriginPatterns, "^"+pattern+"$") + } + + return func(ctx *context.Context) { + var ( + origin = ctx.Input.Header(headerOrigin) + requestedMethod = ctx.Input.Header(headerRequestMethod) + requestedHeaders = ctx.Input.Header(headerRequestHeaders) + // additional headers to be added + // to the response. + headers map[string]string + ) + + if ctx.Input.Method() == "OPTIONS" && + (requestedMethod != "" || requestedHeaders != "") { + headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders) + for key, value := range headers { + ctx.Output.Header(key, value) + } + ctx.ResponseWriter.WriteHeader(http.StatusOK) + return + } + headers = opts.Header(origin) + + for key, value := range headers { + ctx.Output.Header(key, value) + } + } +} diff --git a/pkg/plugins/cors/cors_test.go b/pkg/plugins/cors/cors_test.go new file mode 100644 index 00000000..34039143 --- /dev/null +++ b/pkg/plugins/cors/cors_test.go @@ -0,0 +1,253 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cors + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" +) + +// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header +type HTTPHeaderGuardRecorder struct { + *httptest.ResponseRecorder + savedHeaderMap http.Header +} + +// NewRecorder return HttpHeaderGuardRecorder +func NewRecorder() *HTTPHeaderGuardRecorder { + return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil} +} + +func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) { + gr.ResponseRecorder.WriteHeader(code) + gr.savedHeaderMap = gr.ResponseRecorder.Header() +} + +func (gr *HTTPHeaderGuardRecorder) Header() http.Header { + if gr.savedHeaderMap != nil { + // headers were written. clone so we don't get updates + clone := make(http.Header) + for k, v := range gr.savedHeaderMap { + clone[k] = v + } + return clone + } + return gr.ResponseRecorder.Header() +} + +func Test_AllowAll(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + if recorder.HeaderMap.Get(headerAllowOrigin) != "*" { + t.Errorf("Allow-Origin header should be *") + } +} + +func Test_AllowRegexMatch(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"}, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + origin := "https://bar.foo.com" + r, _ := http.NewRequest("PUT", "/foo", nil) + r.Header.Add("Origin", origin) + handler.ServeHTTP(recorder, r) + + headerValue := recorder.HeaderMap.Get(headerAllowOrigin) + if headerValue != origin { + t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue) + } +} + +func Test_AllowRegexNoMatch(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowOrigins: []string{"https://*.foo.com"}, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + origin := "https://ww.foo.com.evil.com" + r, _ := http.NewRequest("PUT", "/foo", nil) + r.Header.Add("Origin", origin) + handler.ServeHTTP(recorder, r) + + headerValue := recorder.HeaderMap.Get(headerAllowOrigin) + if headerValue != "" { + t.Errorf("Allow-Origin header should not exist, found %v", headerValue) + } +} + +func Test_OtherHeaders(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowCredentials: true, + AllowMethods: []string{"PATCH", "GET"}, + AllowHeaders: []string{"Origin", "X-whatever"}, + ExposeHeaders: []string{"Content-Length", "Hello"}, + MaxAge: 5 * time.Minute, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials) + methodsVal := recorder.HeaderMap.Get(headerAllowMethods) + headersVal := recorder.HeaderMap.Get(headerAllowHeaders) + exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders) + maxAgeVal := recorder.HeaderMap.Get(headerMaxAge) + + if credentialsVal != "true" { + t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal) + } + + if methodsVal != "PATCH,GET" { + t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal) + } + + if headersVal != "Origin,X-whatever" { + t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal) + } + + if exposedHeadersVal != "Content-Length,Hello" { + t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal) + } + + if maxAgeVal != "300" { + t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal) + } +} + +func Test_DefaultAllowHeaders(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + headersVal := recorder.HeaderMap.Get(headerAllowHeaders) + if headersVal != "Origin,Accept,Content-Type,Authorization" { + t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal) + } +} + +func Test_Preflight(t *testing.T) { + recorder := NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowMethods: []string{"PUT", "PATCH"}, + AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"}, + })) + + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + r, _ := http.NewRequest("OPTIONS", "/foo", nil) + r.Header.Add(headerRequestMethod, "PUT") + r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive") + handler.ServeHTTP(recorder, r) + + headers := recorder.Header() + methodsVal := headers.Get(headerAllowMethods) + headersVal := headers.Get(headerAllowHeaders) + originVal := headers.Get(headerAllowOrigin) + + if methodsVal != "PUT,PATCH" { + t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal) + } + + if !strings.Contains(headersVal, "X-whatever") { + t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal) + } + + if !strings.Contains(headersVal, "x-casesensitive") { + t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal) + } + + if originVal != "*" { + t.Errorf("Allow-Origin is expected to be *, found %v", originVal) + } + + if recorder.Code != http.StatusOK { + t.Errorf("Status code is expected to be 200, found %d", recorder.Code) + } +} + +func Benchmark_WithoutCORS(b *testing.B) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + beego.BConfig.RunMode = beego.PROD + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + for i := 0; i < b.N; i++ { + handler.ServeHTTP(recorder, r) + } +} + +func Benchmark_WithCORS(b *testing.B) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + beego.BConfig.RunMode = beego.PROD + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowCredentials: true, + AllowMethods: []string{"PATCH", "GET"}, + AllowHeaders: []string{"Origin", "X-whatever"}, + MaxAge: 5 * time.Minute, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + for i := 0; i < b.N; i++ { + handler.ServeHTTP(recorder, r) + } +} diff --git a/pkg/policy.go b/pkg/policy.go new file mode 100644 index 00000000..ab23f927 --- /dev/null +++ b/pkg/policy.go @@ -0,0 +1,97 @@ +// Copyright 2016 beego authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "strings" + + "github.com/astaxie/beego/context" +) + +// PolicyFunc defines a policy function which is invoked before the controller handler is executed. +type PolicyFunc func(*context.Context) + +// FindPolicy Find Router info for URL +func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { + var urlPath = cont.Input.URL() + if !BConfig.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + httpMethod := cont.Input.Method() + isWildcard := false + // Find policy for current method + t, ok := p.policies[httpMethod] + // If not found - find policy for whole controller + if !ok { + t, ok = p.policies["*"] + isWildcard = true + } + if ok { + runObjects := t.Match(urlPath, cont) + if r, ok := runObjects.([]PolicyFunc); ok { + return r + } else if !isWildcard { + // If no policies found and we checked not for "*" method - try to find it + t, ok = p.policies["*"] + if ok { + runObjects = t.Match(urlPath, cont) + if r, ok = runObjects.([]PolicyFunc); ok { + return r + } + } + } + } + return nil +} + +func (p *ControllerRegister) addToPolicy(method, pattern string, r ...PolicyFunc) { + method = strings.ToUpper(method) + p.enablePolicy = true + if !BConfig.RouterCaseSensitive { + pattern = strings.ToLower(pattern) + } + if t, ok := p.policies[method]; ok { + t.AddRouter(pattern, r) + } else { + t := NewTree() + t.AddRouter(pattern, r) + p.policies[method] = t + } +} + +// Policy Register new policy in beego +func Policy(pattern, method string, policy ...PolicyFunc) { + BeeApp.Handlers.addToPolicy(method, pattern, policy...) +} + +// Find policies and execute if were found +func (p *ControllerRegister) execPolicy(cont *context.Context, urlPath string) (started bool) { + if !p.enablePolicy { + return false + } + // Find Policy for method + policyList := p.FindPolicy(cont) + if len(policyList) > 0 { + // Run policies + for _, runPolicy := range policyList { + runPolicy(cont) + if cont.ResponseWriter.Started { + return true + } + } + return false + } + return false +} diff --git a/pkg/router.go b/pkg/router.go new file mode 100644 index 00000000..6a8ac6f7 --- /dev/null +++ b/pkg/router.go @@ -0,0 +1,1052 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "errors" + "fmt" + "net/http" + "os" + "path" + "path/filepath" + "reflect" + "strconv" + "strings" + "sync" + "time" + + beecontext "github.com/astaxie/beego/context" + "github.com/astaxie/beego/context/param" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/toolbox" + "github.com/astaxie/beego/utils" +) + +// default filter execution points +const ( + BeforeStatic = iota + BeforeRouter + BeforeExec + AfterExec + FinishRouter +) + +const ( + routerTypeBeego = iota + routerTypeRESTFul + routerTypeHandler +) + +var ( + // HTTPMETHOD list the supported http methods. + HTTPMETHOD = map[string]bool{ + "GET": true, + "POST": true, + "PUT": true, + "DELETE": true, + "PATCH": true, + "OPTIONS": true, + "HEAD": true, + "TRACE": true, + "CONNECT": true, + "MKCOL": true, + "COPY": true, + "MOVE": true, + "PROPFIND": true, + "PROPPATCH": true, + "LOCK": true, + "UNLOCK": true, + } + // these beego.Controller's methods shouldn't reflect to AutoRouter + exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString", + "RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJSON", "ServeJSONP", + "ServeYAML", "ServeXML", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool", + "GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession", + "DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie", + "SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml", + "GetControllerAndAction", "ServeFormatted"} + + urlPlaceholder = "{{placeholder}}" + // DefaultAccessLogFilter will skip the accesslog if return true + DefaultAccessLogFilter FilterHandler = &logFilter{} +) + +// FilterHandler is an interface for +type FilterHandler interface { + Filter(*beecontext.Context) bool +} + +// default log filter static file will not show +type logFilter struct { +} + +func (l *logFilter) Filter(ctx *beecontext.Context) bool { + requestPath := path.Clean(ctx.Request.URL.Path) + if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { + return true + } + for prefix := range BConfig.WebConfig.StaticDir { + if strings.HasPrefix(requestPath, prefix) { + return true + } + } + return false +} + +// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter +func ExceptMethodAppend(action string) { + exceptMethod = append(exceptMethod, action) +} + +// ControllerInfo holds information about the controller. +type ControllerInfo struct { + pattern string + controllerType reflect.Type + methods map[string]string + handler http.Handler + runFunction FilterFunc + routerType int + initialize func() ControllerInterface + methodParams []*param.MethodParam +} + +func (c *ControllerInfo) GetPattern() string { + return c.pattern +} + +// ControllerRegister containers registered router rules, controller handlers and filters. +type ControllerRegister struct { + routers map[string]*Tree + enablePolicy bool + policies map[string]*Tree + enableFilter bool + filters [FinishRouter + 1][]*FilterRouter + pool sync.Pool +} + +// NewControllerRegister returns a new ControllerRegister. +func NewControllerRegister() *ControllerRegister { + return &ControllerRegister{ + routers: make(map[string]*Tree), + policies: make(map[string]*Tree), + pool: sync.Pool{ + New: func() interface{} { + return beecontext.NewContext() + }, + }, + } +} + +// Add controller handler and pattern rules to ControllerRegister. +// usage: +// default methods is the same name as method +// Add("/user",&UserController{}) +// Add("/api/list",&RestController{},"*:ListFood") +// Add("/api/create",&RestController{},"post:CreateFood") +// Add("/api/update",&RestController{},"put:UpdateFood") +// Add("/api/delete",&RestController{},"delete:DeleteFood") +// Add("/api",&RestController{},"get,post:ApiFunc" +// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { + p.addWithMethodParams(pattern, c, nil, mappingMethods...) +} + +func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + methods := make(map[string]string) + if len(mappingMethods) > 0 { + semi := strings.Split(mappingMethods[0], ";") + for _, v := range semi { + colon := strings.Split(v, ":") + if len(colon) != 2 { + panic("method mapping format is invalid") + } + comma := strings.Split(colon[0], ",") + for _, m := range comma { + if m == "*" || HTTPMETHOD[strings.ToUpper(m)] { + if val := reflectVal.MethodByName(colon[1]); val.IsValid() { + methods[strings.ToUpper(m)] = colon[1] + } else { + panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) + } + } else { + panic(v + " is an invalid method mapping. Method doesn't exist " + m) + } + } + } + } + + route := &ControllerInfo{} + route.pattern = pattern + route.methods = methods + route.routerType = routerTypeBeego + route.controllerType = t + route.initialize = func() ControllerInterface { + vc := reflect.New(route.controllerType) + execController, ok := vc.Interface().(ControllerInterface) + if !ok { + panic("controller is not ControllerInterface") + } + + elemVal := reflect.ValueOf(c).Elem() + elemType := reflect.TypeOf(c).Elem() + execElem := reflect.ValueOf(execController).Elem() + + numOfFields := elemVal.NumField() + for i := 0; i < numOfFields; i++ { + fieldType := elemType.Field(i) + elemField := execElem.FieldByName(fieldType.Name) + if elemField.CanSet() { + fieldVal := elemVal.Field(i) + elemField.Set(fieldVal) + } + } + + return execController + } + + route.methodParams = methodParams + if len(methods) == 0 { + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } + } else { + for k := range methods { + if k == "*" { + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } + } else { + p.addToRouter(k, pattern, route) + } + } + } +} + +func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) { + if !BConfig.RouterCaseSensitive { + pattern = strings.ToLower(pattern) + } + if t, ok := p.routers[method]; ok { + t.AddRouter(pattern, r) + } else { + t := NewTree() + t.AddRouter(pattern, r) + p.routers[method] = t + } +} + +// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller +// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +func (p *ControllerRegister) Include(cList ...ControllerInterface) { + if BConfig.RunMode == DEV { + skip := make(map[string]bool, 10) + wgopath := utils.GetGOPATHs() + go111module := os.Getenv(`GO111MODULE`) + for _, c := range cList { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + // for go modules + if go111module == `on` { + pkgpath := filepath.Join(WorkPath, "..", t.PkgPath()) + if utils.FileExists(pkgpath) { + if pkgpath != "" { + if _, ok := skip[pkgpath]; !ok { + skip[pkgpath] = true + parserPkg(pkgpath, t.PkgPath()) + } + } + } + } else { + if len(wgopath) == 0 { + panic("you are in dev mode. So please set gopath") + } + pkgpath := "" + for _, wg := range wgopath { + wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath())) + if utils.FileExists(wg) { + pkgpath = wg + break + } + } + if pkgpath != "" { + if _, ok := skip[pkgpath]; !ok { + skip[pkgpath] = true + parserPkg(pkgpath, t.PkgPath()) + } + } + } + } + } + for _, c := range cList { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + key := t.PkgPath() + ":" + t.Name() + if comm, ok := GlobalControllerRouter[key]; ok { + for _, a := range comm { + for _, f := range a.Filters { + p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams) + } + + p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) + } + } + } +} + +// GetContext returns a context from pool, so usually you should remember to call Reset function to clean the context +// And don't forget to give back context to pool +// example: +// ctx := p.GetContext() +// ctx.Reset(w, q) +// defer p.GiveBackContext(ctx) +func (p *ControllerRegister) GetContext() *beecontext.Context { + return p.pool.Get().(*beecontext.Context) +} + +// GiveBackContext put the ctx into pool so that it could be reuse +func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { + p.pool.Put(ctx) +} + +// Get add get method +// usage: +// Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Get(pattern string, f FilterFunc) { + p.AddMethod("get", pattern, f) +} + +// Post add post method +// usage: +// Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Post(pattern string, f FilterFunc) { + p.AddMethod("post", pattern, f) +} + +// Put add put method +// usage: +// Put("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Put(pattern string, f FilterFunc) { + p.AddMethod("put", pattern, f) +} + +// Delete add delete method +// usage: +// Delete("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { + p.AddMethod("delete", pattern, f) +} + +// Head add head method +// usage: +// Head("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Head(pattern string, f FilterFunc) { + p.AddMethod("head", pattern, f) +} + +// Patch add patch method +// usage: +// Patch("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { + p.AddMethod("patch", pattern, f) +} + +// Options add options method +// usage: +// Options("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Options(pattern string, f FilterFunc) { + p.AddMethod("options", pattern, f) +} + +// Any add all method +// usage: +// Any("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Any(pattern string, f FilterFunc) { + p.AddMethod("*", pattern, f) +} + +// AddMethod add http method router +// usage: +// AddMethod("get","/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { + method = strings.ToUpper(method) + if method != "*" && !HTTPMETHOD[method] { + panic("not support http method: " + method) + } + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeRESTFul + route.runFunction = f + methods := make(map[string]string) + if method == "*" { + for val := range HTTPMETHOD { + methods[val] = val + } + } else { + methods[method] = method + } + route.methods = methods + for k := range methods { + if k == "*" { + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } + } else { + p.addToRouter(k, pattern, route) + } + } +} + +// Handler add user defined Handler +func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeHandler + route.handler = h + if len(options) > 0 { + if _, ok := options[0].(bool); ok { + pattern = path.Join(pattern, "?:all(.*)") + } + } + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } +} + +// AddAuto router to ControllerRegister. +// example beego.AddAuto(&MainContorlller{}), +// MainController has method List and Page. +// visit the url /main/list to execute List function +// /main/page to execute Page function. +func (p *ControllerRegister) AddAuto(c ControllerInterface) { + p.AddAutoPrefix("/", c) +} + +// AddAutoPrefix Add auto router to ControllerRegister with prefix. +// example beego.AddAutoPrefix("/admin",&MainContorlller{}), +// MainController has method List and Page. +// visit the url /admin/main/list to execute List function +// /admin/main/page to execute Page function. +func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) { + reflectVal := reflect.ValueOf(c) + rt := reflectVal.Type() + ct := reflect.Indirect(reflectVal).Type() + controllerName := strings.TrimSuffix(ct.Name(), "Controller") + for i := 0; i < rt.NumMethod(); i++ { + if !utils.InSlice(rt.Method(i).Name, exceptMethod) { + route := &ControllerInfo{} + route.routerType = routerTypeBeego + route.methods = map[string]string{"*": rt.Method(i).Name} + route.controllerType = ct + pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") + patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*") + patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) + patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name) + route.pattern = pattern + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + p.addToRouter(m, patternInit, route) + p.addToRouter(m, patternFix, route) + p.addToRouter(m, patternFixInit, route) + } + } + } +} + +// InsertFilter Add a FilterFunc with pattern rule and action constant. +// params is for: +// 1. setting the returnOnOutput value (false allows multiple filters to execute) +// 2. determining whether or not params need to be reset. +func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { + mr := &FilterRouter{ + tree: NewTree(), + pattern: pattern, + filterFunc: filter, + returnOnOutput: true, + } + if !BConfig.RouterCaseSensitive { + mr.pattern = strings.ToLower(pattern) + } + + paramsLen := len(params) + if paramsLen > 0 { + mr.returnOnOutput = params[0] + } + if paramsLen > 1 { + mr.resetParams = params[1] + } + mr.tree.AddRouter(pattern, true) + return p.insertFilterRouter(pos, mr) +} + +// add Filter into +func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { + if pos < BeforeStatic || pos > FinishRouter { + return errors.New("can not find your filter position") + } + p.enableFilter = true + p.filters[pos] = append(p.filters[pos], mr) + return nil +} + +// URLFor does another controller handler in this request function. +// it can access any controller method. +func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { + paths := strings.Split(endpoint, ".") + if len(paths) <= 1 { + logs.Warn("urlfor endpoint must like path.controller.method") + return "" + } + if len(values)%2 != 0 { + logs.Warn("urlfor params must key-value pair") + return "" + } + params := make(map[string]string) + if len(values) > 0 { + key := "" + for k, v := range values { + if k%2 == 0 { + key = fmt.Sprint(v) + } else { + params[key] = fmt.Sprint(v) + } + } + } + controllerName := strings.Join(paths[:len(paths)-1], "/") + methodName := paths[len(paths)-1] + for m, t := range p.routers { + ok, url := p.getURL(t, "/", controllerName, methodName, params, m) + if ok { + return url + } + } + return "" +} + +func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName string, params map[string]string, httpMethod string) (bool, string) { + for _, subtree := range t.fixrouters { + u := path.Join(url, subtree.prefix) + ok, u := p.getURL(subtree, u, controllerName, methodName, params, httpMethod) + if ok { + return ok, u + } + } + if t.wildcard != nil { + u := path.Join(url, urlPlaceholder) + ok, u := p.getURL(t.wildcard, u, controllerName, methodName, params, httpMethod) + if ok { + return ok, u + } + } + for _, l := range t.leaves { + if c, ok := l.runObject.(*ControllerInfo); ok { + if c.routerType == routerTypeBeego && + strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) { + find := false + if HTTPMETHOD[strings.ToUpper(methodName)] { + if len(c.methods) == 0 { + find = true + } else if m, ok := c.methods[strings.ToUpper(methodName)]; ok && m == strings.ToUpper(methodName) { + find = true + } else if m, ok = c.methods["*"]; ok && m == methodName { + find = true + } + } + if !find { + for m, md := range c.methods { + if (m == "*" || m == httpMethod) && md == methodName { + find = true + } + } + } + if find { + if l.regexps == nil { + if len(l.wildcards) == 0 { + return true, strings.Replace(url, "/"+urlPlaceholder, "", 1) + toURL(params) + } + if len(l.wildcards) == 1 { + if v, ok := params[l.wildcards[0]]; ok { + delete(params, l.wildcards[0]) + return true, strings.Replace(url, urlPlaceholder, v, 1) + toURL(params) + } + return false, "" + } + if len(l.wildcards) == 3 && l.wildcards[0] == "." { + if p, ok := params[":path"]; ok { + if e, isok := params[":ext"]; isok { + delete(params, ":path") + delete(params, ":ext") + return true, strings.Replace(url, urlPlaceholder, p+"."+e, -1) + toURL(params) + } + } + } + canSkip := false + for _, v := range l.wildcards { + if v == ":" { + canSkip = true + continue + } + if u, ok := params[v]; ok { + delete(params, v) + url = strings.Replace(url, urlPlaceholder, u, 1) + } else { + if canSkip { + canSkip = false + continue + } + return false, "" + } + } + return true, url + toURL(params) + } + var i int + var startReg bool + regURL := "" + for _, v := range strings.Trim(l.regexps.String(), "^$") { + if v == '(' { + startReg = true + continue + } else if v == ')' { + startReg = false + if v, ok := params[l.wildcards[i]]; ok { + delete(params, l.wildcards[i]) + regURL = regURL + v + i++ + } else { + break + } + } else if !startReg { + regURL = string(append([]rune(regURL), v)) + } + } + if l.regexps.MatchString(regURL) { + ps := strings.Split(regURL, "/") + for _, p := range ps { + url = strings.Replace(url, urlPlaceholder, p, 1) + } + return true, url + toURL(params) + } + } + } + } + } + + return false, "" +} + +func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) { + var preFilterParams map[string]string + for _, filterR := range p.filters[pos] { + if filterR.returnOnOutput && context.ResponseWriter.Started { + return true + } + if filterR.resetParams { + preFilterParams = context.Input.Params() + } + if ok := filterR.ValidRouter(urlPath, context); ok { + filterR.filterFunc(context) + if filterR.resetParams { + context.Input.ResetParams() + for k, v := range preFilterParams { + context.Input.SetParam(k, v) + } + } + } + if filterR.returnOnOutput && context.ResponseWriter.Started { + return true + } + } + return false +} + +// Implement http.Handler interface. +func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + startTime := time.Now() + var ( + runRouter reflect.Type + findRouter bool + runMethod string + methodParams []*param.MethodParam + routerInfo *ControllerInfo + isRunnable bool + ) + context := p.GetContext() + + context.Reset(rw, r) + + defer p.GiveBackContext(context) + if BConfig.RecoverFunc != nil { + defer BConfig.RecoverFunc(context) + } + + context.Output.EnableGzip = BConfig.EnableGzip + + if BConfig.RunMode == DEV { + context.Output.Header("Server", BConfig.ServerName) + } + + var urlPath = r.URL.Path + + if !BConfig.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + + // filter wrong http method + if !HTTPMETHOD[r.Method] { + exception("405", context) + goto Admin + } + + // filter for static file + if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) { + goto Admin + } + + serverStaticRouter(context) + + if context.ResponseWriter.Started { + findRouter = true + goto Admin + } + + if r.Method != http.MethodGet && r.Method != http.MethodHead { + if BConfig.CopyRequestBody && !context.Input.IsUpload() { + // connection will close if the incoming data are larger (RFC 7231, 6.5.11) + if r.ContentLength > BConfig.MaxMemory { + logs.Error(errors.New("payload too large")) + exception("413", context) + goto Admin + } + context.Input.CopyBody(BConfig.MaxMemory) + } + context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) + } + + // session init + if BConfig.WebConfig.Session.SessionOn { + var err error + context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) + if err != nil { + logs.Error(err) + exception("503", context) + goto Admin + } + defer func() { + if context.Input.CruSession != nil { + context.Input.CruSession.SessionRelease(rw) + } + }() + } + if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) { + goto Admin + } + // User can define RunController and RunMethod in filter + if context.Input.RunController != nil && context.Input.RunMethod != "" { + findRouter = true + runMethod = context.Input.RunMethod + runRouter = context.Input.RunController + } else { + routerInfo, findRouter = p.FindRouter(context) + } + + // if no matches to url, throw a not found exception + if !findRouter { + exception("404", context) + goto Admin + } + if splat := context.Input.Param(":splat"); splat != "" { + for k, v := range strings.Split(splat, "/") { + context.Input.SetParam(strconv.Itoa(k), v) + } + } + + if routerInfo != nil { + // store router pattern into context + context.Input.SetData("RouterPattern", routerInfo.pattern) + } + + // execute middleware filters + if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) { + goto Admin + } + + // check policies + if p.execPolicy(context, urlPath) { + goto Admin + } + + if routerInfo != nil { + if routerInfo.routerType == routerTypeRESTFul { + if _, ok := routerInfo.methods[r.Method]; ok { + isRunnable = true + routerInfo.runFunction(context) + } else { + exception("405", context) + goto Admin + } + } else if routerInfo.routerType == routerTypeHandler { + isRunnable = true + routerInfo.handler.ServeHTTP(context.ResponseWriter, context.Request) + } else { + runRouter = routerInfo.controllerType + methodParams = routerInfo.methodParams + method := r.Method + if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut { + method = http.MethodPut + } + if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete { + method = http.MethodDelete + } + if m, ok := routerInfo.methods[method]; ok { + runMethod = m + } else if m, ok = routerInfo.methods["*"]; ok { + runMethod = m + } else { + runMethod = method + } + } + } + + // also defined runRouter & runMethod from filter + if !isRunnable { + // Invoke the request handler + var execController ControllerInterface + if routerInfo != nil && routerInfo.initialize != nil { + execController = routerInfo.initialize() + } else { + vc := reflect.New(runRouter) + var ok bool + execController, ok = vc.Interface().(ControllerInterface) + if !ok { + panic("controller is not ControllerInterface") + } + } + + // call the controller init function + execController.Init(context, runRouter.Name(), runMethod, execController) + + // call prepare function + execController.Prepare() + + // if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf + if BConfig.WebConfig.EnableXSRF { + execController.XSRFToken() + if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut || + (r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) { + execController.CheckXSRFCookie() + } + } + + execController.URLMapping() + + if !context.ResponseWriter.Started { + // exec main logic + switch runMethod { + case http.MethodGet: + execController.Get() + case http.MethodPost: + execController.Post() + case http.MethodDelete: + execController.Delete() + case http.MethodPut: + execController.Put() + case http.MethodHead: + execController.Head() + case http.MethodPatch: + execController.Patch() + case http.MethodOptions: + execController.Options() + case http.MethodTrace: + execController.Trace() + default: + if !execController.HandlerFunc(runMethod) { + vc := reflect.ValueOf(execController) + method := vc.MethodByName(runMethod) + in := param.ConvertParams(methodParams, method.Type(), context) + out := method.Call(in) + + // For backward compatibility we only handle response if we had incoming methodParams + if methodParams != nil { + p.handleParamResponse(context, execController, out) + } + } + } + + // render template + if !context.ResponseWriter.Started && context.Output.Status == 0 { + if BConfig.WebConfig.AutoRender { + if err := execController.Render(); err != nil { + logs.Error(err) + } + } + } + } + + // finish all runRouter. release resource + execController.Finish() + } + + // execute middleware filters + if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) { + goto Admin + } + + if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) { + goto Admin + } + +Admin: + // admin module record QPS + + statusCode := context.ResponseWriter.Status + if statusCode == 0 { + statusCode = 200 + } + + LogAccess(context, &startTime, statusCode) + + timeDur := time.Since(startTime) + context.ResponseWriter.Elapsed = timeDur + if BConfig.Listen.EnableAdmin { + pattern := "" + if routerInfo != nil { + pattern = routerInfo.pattern + } + + if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) { + routerName := "" + if runRouter != nil { + routerName = runRouter.Name() + } + go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, routerName, timeDur) + } + } + + if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs { + match := map[bool]string{true: "match", false: "nomatch"} + devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", + context.Input.IP(), + logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(), + timeDur.String(), + match[findRouter], + logs.ColorByMethod(r.Method), r.Method, logs.ResetColor(), + r.URL.Path) + if routerInfo != nil { + devInfo += fmt.Sprintf(" r:%s", routerInfo.pattern) + } + + logs.Debug(devInfo) + } + // Call WriteHeader if status code has been set changed + if context.Output.Status != 0 { + context.ResponseWriter.WriteHeader(context.Output.Status) + } +} + +func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) { + // looping in reverse order for the case when both error and value are returned and error sets the response status code + for i := len(results) - 1; i >= 0; i-- { + result := results[i] + if result.Kind() != reflect.Interface || !result.IsNil() { + resultValue := result.Interface() + context.RenderMethodResult(resultValue) + } + } + if !context.ResponseWriter.Started && len(results) > 0 && context.Output.Status == 0 { + context.Output.SetStatus(200) + } +} + +// FindRouter Find Router info for URL +func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { + var urlPath = context.Input.URL() + if !BConfig.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + httpMethod := context.Input.Method() + if t, ok := p.routers[httpMethod]; ok { + runObject := t.Match(urlPath, context) + if r, ok := runObject.(*ControllerInfo); ok { + return r, true + } + } + return +} + +func toURL(params map[string]string) string { + if len(params) == 0 { + return "" + } + u := "?" + for k, v := range params { + u += k + "=" + v + "&" + } + return strings.TrimRight(u, "&") +} + +// LogAccess logging info HTTP Access +func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { + // Skip logging if AccessLogs config is false + if !BConfig.Log.AccessLogs { + return + } + // Skip logging static requests unless EnableStaticLogs config is true + if !BConfig.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) { + return + } + var ( + requestTime time.Time + elapsedTime time.Duration + r = ctx.Request + ) + if startTime != nil { + requestTime = *startTime + elapsedTime = time.Since(*startTime) + } + record := &logs.AccessLogRecord{ + RemoteAddr: ctx.Input.IP(), + RequestTime: requestTime, + RequestMethod: r.Method, + Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto), + ServerProtocol: r.Proto, + Host: r.Host, + Status: statusCode, + ElapsedTime: elapsedTime, + HTTPReferrer: r.Header.Get("Referer"), + HTTPUserAgent: r.Header.Get("User-Agent"), + RemoteUser: r.Header.Get("Remote-User"), + BodyBytesSent: r.ContentLength, + } + logs.AccessLog(record, BConfig.Log.AccessLogsFormat) +} diff --git a/pkg/router_test.go b/pkg/router_test.go new file mode 100644 index 00000000..8ec7927a --- /dev/null +++ b/pkg/router_test.go @@ -0,0 +1,732 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" +) + +type TestController struct { + Controller +} + +func (tc *TestController) Get() { + tc.Data["Username"] = "astaxie" + tc.Ctx.Output.Body([]byte("ok")) +} + +func (tc *TestController) Post() { + tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Query(":name"))) +} + +func (tc *TestController) Param() { + tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Query(":name"))) +} + +func (tc *TestController) List() { + tc.Ctx.Output.Body([]byte("i am list")) +} + +func (tc *TestController) Params() { + tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param("0") + tc.Ctx.Input.Param("1") + tc.Ctx.Input.Param("2"))) +} + +func (tc *TestController) Myext() { + tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param(":ext"))) +} + +func (tc *TestController) GetURL() { + tc.Ctx.Output.Body([]byte(tc.URLFor(".Myext"))) +} + +func (tc *TestController) GetParams() { + tc.Ctx.WriteString(tc.Ctx.Input.Query(":last") + "+" + + tc.Ctx.Input.Query(":first") + "+" + tc.Ctx.Input.Query("learn")) +} + +func (tc *TestController) GetManyRouter() { + tc.Ctx.WriteString(tc.Ctx.Input.Query(":id") + tc.Ctx.Input.Query(":page")) +} + +func (tc *TestController) GetEmptyBody() { + var res []byte + tc.Ctx.Output.Body(res) +} + +type JSONController struct { + Controller +} + +func (jc *JSONController) Prepare() { + jc.Data["json"] = "prepare" + jc.ServeJSON(true) +} + +func (jc *JSONController) Get() { + jc.Data["Username"] = "astaxie" + jc.Ctx.Output.Body([]byte("ok")) +} + +func TestUrlFor(t *testing.T) { + handler := NewControllerRegister() + handler.Add("/api/list", &TestController{}, "*:List") + handler.Add("/person/:last/:first", &TestController{}, "*:Param") + if a := handler.URLFor("TestController.List"); a != "/api/list" { + logs.Info(a) + t.Errorf("TestController.List must equal to /api/list") + } + if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" { + t.Errorf("TestController.Param must equal to /person/xie/asta, but get " + a) + } +} + +func TestUrlFor3(t *testing.T) { + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + if a := handler.URLFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" { + t.Errorf("TestController.Myext must equal to /test/myext, but get " + a) + } + if a := handler.URLFor("TestController.GetURL"); a != "/test/geturl" && a != "/Test/GetURL" { + t.Errorf("TestController.GetURL must equal to /test/geturl, but get " + a) + } +} + +func TestUrlFor2(t *testing.T) { + handler := NewControllerRegister() + handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List") + handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL") + handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") + handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) + if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { + logs.Info(handler.URLFor("TestController.GetURL")) + t.Errorf("TestController.List must equal to /v1/astaxie/edit") + } + + if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") != + "/v1/za/cms_12_123.html" { + logs.Info(handler.URLFor("TestController.List")) + t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html") + } + if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") != + "/v1/za_cms/ttt_12_123.html" { + logs.Info(handler.URLFor("TestController.Param")) + t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html") + } + if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11", + ":title", "aaaa", ":entid", "aaaa") != + "/1111/11/aaaa/aaaa" { + logs.Info(handler.URLFor("TestController.Get")) + t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa") + } +} + +func TestUserFunc(t *testing.T) { + r, _ := http.NewRequest("GET", "/api/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/api/list", &TestController{}, "*:List") + handler.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("user define func can't run") + } +} + +func TestPostFunc(t *testing.T) { + r, _ := http.NewRequest("POST", "/astaxie", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/:name", &TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "astaxie" { + t.Errorf("post func should astaxie") + } +} + +func TestAutoFunc(t *testing.T) { + r, _ := http.NewRequest("GET", "/test/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("user define func can't run") + } +} + +func TestAutoFunc2(t *testing.T) { + r, _ := http.NewRequest("GET", "/Test/List", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("user define func can't run") + } +} + +func TestAutoFuncParams(t *testing.T) { + r, _ := http.NewRequest("GET", "/test/params/2009/11/12", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "20091112" { + t.Errorf("user define func can't run") + } +} + +func TestAutoExtFunc(t *testing.T) { + r, _ := http.NewRequest("GET", "/test/myext.json", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "json" { + t.Errorf("user define func can't run") + } +} + +func TestRouteOk(t *testing.T) { + + r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/person/:last/:first", &TestController{}, "get:GetParams") + handler.ServeHTTP(w, r) + body := w.Body.String() + if body != "anderson+thomas+kungfu" { + t.Errorf("url param set to [%s];", body) + } +} + +func TestManyRoute(t *testing.T) { + + r, _ := http.NewRequest("GET", "/beego32-12.html", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter") + handler.ServeHTTP(w, r) + + body := w.Body.String() + + if body != "3212" { + t.Errorf("url param set to [%s];", body) + } +} + +// Test for issue #1669 +func TestEmptyResponse(t *testing.T) { + + r, _ := http.NewRequest("GET", "/beego-empty.html", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody") + handler.ServeHTTP(w, r) + + if body := w.Body.String(); body != "" { + t.Error("want empty body") + } +} + +func TestNotFound(t *testing.T) { + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.ServeHTTP(w, r) + + if w.Code != http.StatusNotFound { + t.Errorf("Code set to [%v]; want [%v]", w.Code, http.StatusNotFound) + } +} + +// TestStatic tests the ability to serve static +// content from the filesystem +func TestStatic(t *testing.T) { + r, _ := http.NewRequest("GET", "/static/js/jquery.js", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.ServeHTTP(w, r) + + if w.Code != 404 { + t.Errorf("handler.Static failed to serve file") + } +} + +func TestPrepare(t *testing.T) { + r, _ := http.NewRequest("GET", "/json/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/json/list", &JSONController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != `"prepare"` { + t.Errorf(w.Body.String() + "user define func can't run") + } +} + +func TestAutoPrefix(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/test/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAutoPrefix("/admin", &TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("TestAutoPrefix can't run") + } +} + +func TestRouterGet(t *testing.T) { + r, _ := http.NewRequest("GET", "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Get("/user", func(ctx *context.Context) { + ctx.Output.Body([]byte("Get userlist")) + }) + handler.ServeHTTP(w, r) + if w.Body.String() != "Get userlist" { + t.Errorf("TestRouterGet can't run") + } +} + +func TestRouterPost(t *testing.T) { + r, _ := http.NewRequest("POST", "/user/123", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Post("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + handler.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestRouterPost can't run") + } +} + +func sayhello(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("sayhello")) +} + +func TestRouterHandler(t *testing.T) { + r, _ := http.NewRequest("POST", "/sayhi", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Handler("/sayhi", http.HandlerFunc(sayhello)) + handler.ServeHTTP(w, r) + if w.Body.String() != "sayhello" { + t.Errorf("TestRouterHandler can't run") + } +} + +func TestRouterHandlerAll(t *testing.T) { + r, _ := http.NewRequest("POST", "/sayhi/a/b/c", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Handler("/sayhi", http.HandlerFunc(sayhello), true) + handler.ServeHTTP(w, r) + if w.Body.String() != "sayhello" { + t.Errorf("TestRouterHandler can't run") + } +} + +// +// Benchmarks NewApp: +// + +func beegoFilterFunc(ctx *context.Context) { + ctx.WriteString("hello") +} + +type AdminController struct { + Controller +} + +func (a *AdminController) Get() { + a.Ctx.WriteString("hello") +} + +func TestRouterFunc(t *testing.T) { + mux := NewControllerRegister() + mux.Get("/action", beegoFilterFunc) + mux.Post("/action", beegoFilterFunc) + rw, r := testRequest("GET", "/action") + mux.ServeHTTP(rw, r) + if rw.Body.String() != "hello" { + t.Errorf("TestRouterFunc can't run") + } +} + +func BenchmarkFunc(b *testing.B) { + mux := NewControllerRegister() + mux.Get("/action", beegoFilterFunc) + rw, r := testRequest("GET", "/action") + b.ResetTimer() + for i := 0; i < b.N; i++ { + mux.ServeHTTP(rw, r) + } +} + +func BenchmarkController(b *testing.B) { + mux := NewControllerRegister() + mux.Add("/action", &AdminController{}) + rw, r := testRequest("GET", "/action") + b.ResetTimer() + for i := 0; i < b.N; i++ { + mux.ServeHTTP(rw, r) + } +} + +func testRequest(method, path string) (*httptest.ResponseRecorder, *http.Request) { + request, _ := http.NewRequest(method, path, nil) + recorder := httptest.NewRecorder() + + return recorder, request +} + +// Expectation: A Filter with the correct configuration should be created given +// specific parameters. +func TestInsertFilter(t *testing.T) { + testName := "TestInsertFilter" + + mux := NewControllerRegister() + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}) + if !mux.filters[BeforeRouter][0].returnOnOutput { + t.Errorf( + "%s: passing no variadic params should set returnOnOutput to true", + testName) + } + if mux.filters[BeforeRouter][0].resetParams { + t.Errorf( + "%s: passing no variadic params should set resetParams to false", + testName) + } + + mux = NewControllerRegister() + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, false) + if mux.filters[BeforeRouter][0].returnOnOutput { + t.Errorf( + "%s: passing false as 1st variadic param should set returnOnOutput to false", + testName) + } + + mux = NewControllerRegister() + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, true, true) + if !mux.filters[BeforeRouter][0].resetParams { + t.Errorf( + "%s: passing true as 2nd variadic param should set resetParams to true", + testName) + } +} + +// Expectation: the second variadic arg should cause the execution of the filter +// to preserve the parameters from before its execution. +func TestParamResetFilter(t *testing.T) { + testName := "TestParamResetFilter" + route := "/beego/*" // splat + path := "/beego/routes/routes" + + mux := NewControllerRegister() + + mux.InsertFilter("*", BeforeExec, beegoResetParams, true, true) + + mux.Get(route, beegoHandleResetParams) + + rw, r := testRequest("GET", path) + mux.ServeHTTP(rw, r) + + // The two functions, `beegoResetParams` and `beegoHandleResetParams` add + // a response header of `Splat`. The expectation here is that that Header + // value should match what the _request's_ router set, not the filter's. + + headers := rw.Result().Header + if len(headers["Splat"]) != 1 { + t.Errorf( + "%s: There was an error in the test. Splat param not set in Header", + testName) + } + if headers["Splat"][0] != "routes/routes" { + t.Errorf( + "%s: expected `:splat` param to be [routes/routes] but it was [%s]", + testName, headers["Splat"][0]) + } +} + +// Execution point: BeforeRouter +// expectation: only BeforeRouter function is executed, notmatch output as router doesn't handle +func TestFilterBeforeRouter(t *testing.T) { + testName := "TestFilterBeforeRouter" + url := "/beforeRouter" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoBeforeRouter1) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "BeforeRouter1") { + t.Errorf(testName + " BeforeRouter did not run") + } + if strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " BeforeRouter did not return properly") + } +} + +// Execution point: BeforeExec +// expectation: only BeforeExec function is executed, match as router determines route only +func TestFilterBeforeExec(t *testing.T) { + testName := "TestFilterBeforeExec" + url := "/beforeExec" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) + mux.InsertFilter(url, BeforeExec, beegoBeforeExec1) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "BeforeExec1") { + t.Errorf(testName + " BeforeExec did not run") + } + if strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " BeforeExec did not return properly") + } + if strings.Contains(rw.Body.String(), "BeforeRouter") { + t.Errorf(testName + " BeforeRouter ran in error") + } +} + +// Execution point: AfterExec +// expectation: only AfterExec function is executed, match as router handles +func TestFilterAfterExec(t *testing.T) { + testName := "TestFilterAfterExec" + url := "/afterExec" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) + mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) + mux.InsertFilter(url, AfterExec, beegoAfterExec1, false) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "AfterExec1") { + t.Errorf(testName + " AfterExec did not run") + } + if !strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " handler did not run properly") + } + if strings.Contains(rw.Body.String(), "BeforeRouter") { + t.Errorf(testName + " BeforeRouter ran in error") + } + if strings.Contains(rw.Body.String(), "BeforeExec") { + t.Errorf(testName + " BeforeExec ran in error") + } +} + +// Execution point: FinishRouter +// expectation: only FinishRouter function is executed, match as router handles +func TestFilterFinishRouter(t *testing.T) { + testName := "TestFilterFinishRouter" + url := "/finishRouter" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) + mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) + mux.InsertFilter(url, AfterExec, beegoFilterNoOutput) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if strings.Contains(rw.Body.String(), "FinishRouter1") { + t.Errorf(testName + " FinishRouter did not run") + } + if !strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " handler did not run properly") + } + if strings.Contains(rw.Body.String(), "AfterExec1") { + t.Errorf(testName + " AfterExec ran in error") + } + if strings.Contains(rw.Body.String(), "BeforeRouter") { + t.Errorf(testName + " BeforeRouter ran in error") + } + if strings.Contains(rw.Body.String(), "BeforeExec") { + t.Errorf(testName + " BeforeExec ran in error") + } +} + +// Execution point: FinishRouter +// expectation: only first FinishRouter function is executed, match as router handles +func TestFilterFinishRouterMultiFirstOnly(t *testing.T) { + testName := "TestFilterFinishRouterMultiFirstOnly" + url := "/finishRouterMultiFirstOnly" + + mux := NewControllerRegister() + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter2) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "FinishRouter1") { + t.Errorf(testName + " FinishRouter1 did not run") + } + if !strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " handler did not run properly") + } + // not expected in body + if strings.Contains(rw.Body.String(), "FinishRouter2") { + t.Errorf(testName + " FinishRouter2 did run") + } +} + +// Execution point: FinishRouter +// expectation: both FinishRouter functions execute, match as router handles +func TestFilterFinishRouterMulti(t *testing.T) { + testName := "TestFilterFinishRouterMulti" + url := "/finishRouterMulti" + + mux := NewControllerRegister() + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, false) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "FinishRouter1") { + t.Errorf(testName + " FinishRouter1 did not run") + } + if !strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " handler did not run properly") + } + if !strings.Contains(rw.Body.String(), "FinishRouter2") { + t.Errorf(testName + " FinishRouter2 did not run properly") + } +} + +func beegoFilterNoOutput(ctx *context.Context) { +} + +func beegoBeforeRouter1(ctx *context.Context) { + ctx.WriteString("|BeforeRouter1") +} + +func beegoBeforeExec1(ctx *context.Context) { + ctx.WriteString("|BeforeExec1") +} + +func beegoAfterExec1(ctx *context.Context) { + ctx.WriteString("|AfterExec1") +} + +func beegoFinishRouter1(ctx *context.Context) { + ctx.WriteString("|FinishRouter1") +} + +func beegoFinishRouter2(ctx *context.Context) { + ctx.WriteString("|FinishRouter2") +} + +func beegoResetParams(ctx *context.Context) { + ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat")) +} + +func beegoHandleResetParams(ctx *context.Context) { + ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat")) +} + +// YAML +type YAMLController struct { + Controller +} + +func (jc *YAMLController) Prepare() { + jc.Data["yaml"] = "prepare" + jc.ServeYAML() +} + +func (jc *YAMLController) Get() { + jc.Data["Username"] = "astaxie" + jc.Ctx.Output.Body([]byte("ok")) +} + +func TestYAMLPrepare(t *testing.T) { + r, _ := http.NewRequest("GET", "/yaml/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/yaml/list", &YAMLController{}) + handler.ServeHTTP(w, r) + if strings.TrimSpace(w.Body.String()) != "prepare" { + t.Errorf(w.Body.String()) + } +} + +func TestRouterEntityTooLargeCopyBody(t *testing.T) { + _MaxMemory := BConfig.MaxMemory + _CopyRequestBody := BConfig.CopyRequestBody + BConfig.CopyRequestBody = true + BConfig.MaxMemory = 20 + + b := bytes.NewBuffer([]byte("barbarbarbarbarbarbarbarbarbar")) + r, _ := http.NewRequest("POST", "/user/123", b) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Post("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + handler.ServeHTTP(w, r) + + BConfig.CopyRequestBody = _CopyRequestBody + BConfig.MaxMemory = _MaxMemory + + if w.Code != http.StatusRequestEntityTooLarge { + t.Errorf("TestRouterRequestEntityTooLarge can't run") + } +} diff --git a/pkg/session/README.md b/pkg/session/README.md new file mode 100644 index 00000000..6d0a297e --- /dev/null +++ b/pkg/session/README.md @@ -0,0 +1,114 @@ +session +============== + +session is a Go session manager. It can use many session providers. Just like the `database/sql` and `database/sql/driver`. + +## How to install? + + go get github.com/astaxie/beego/session + + +## What providers are supported? + +As of now this session manager support memory, file, Redis and MySQL. + + +## How to use it? + +First you must import it + + import ( + "github.com/astaxie/beego/session" + ) + +Then in you web app init the global session manager + + var globalSessions *session.Manager + +* Use **memory** as provider: + + func init() { + globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`) + go globalSessions.GC() + } + +* Use **file** as provider, the last param is the path where you want file to be stored: + + func init() { + globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"./tmp"}`) + go globalSessions.GC() + } + +* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password: + + func init() { + globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:6379,100,astaxie"}`) + go globalSessions.GC() + } + +* Use **MySQL** as provider, the last param is the DSN, learn more from [mysql](https://github.com/go-sql-driver/mysql#dsn-data-source-name): + + func init() { + globalSessions, _ = session.NewManager( + "mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"username:password@protocol(address)/dbname?param=value"}`) + go globalSessions.GC() + } + +* Use **Cookie** as provider: + + func init() { + globalSessions, _ = session.NewManager( + "cookie", `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`) + go globalSessions.GC() + } + + +Finally in the handlerfunc you can use it like this + + func login(w http.ResponseWriter, r *http.Request) { + sess := globalSessions.SessionStart(w, r) + defer sess.SessionRelease(w) + username := sess.Get("username") + fmt.Println(username) + if r.Method == "GET" { + t, _ := template.ParseFiles("login.gtpl") + t.Execute(w, nil) + } else { + fmt.Println("username:", r.Form["username"]) + sess.Set("username", r.Form["username"]) + fmt.Println("password:", r.Form["password"]) + } + } + + +## How to write own provider? + +When you develop a web app, maybe you want to write own provider because you must meet the requirements. + +Writing a provider is easy. You only need to define two struct types +(Session and Provider), which satisfy the interface definition. +Maybe you will find the **memory** provider is a good example. + + type SessionStore interface { + Set(key, value interface{}) error //set session value + Get(key interface{}) interface{} //get session value + Delete(key interface{}) error //delete session value + SessionID() string //back current sessionID + SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush() error //delete all data + } + + type Provider interface { + SessionInit(gclifetime int64, config string) error + SessionRead(sid string) (SessionStore, error) + SessionExist(sid string) bool + SessionRegenerate(oldsid, sid string) (SessionStore, error) + SessionDestroy(sid string) error + SessionAll() int //get all active session + SessionGC() + } + + +## LICENSE + +BSD License http://creativecommons.org/licenses/BSD/ diff --git a/pkg/session/couchbase/sess_couchbase.go b/pkg/session/couchbase/sess_couchbase.go new file mode 100644 index 00000000..707d042c --- /dev/null +++ b/pkg/session/couchbase/sess_couchbase.go @@ -0,0 +1,247 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package couchbase for session provider +// +// depend on github.com/couchbaselabs/go-couchbasee +// +// go install github.com/couchbaselabs/go-couchbase +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/couchbase" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("couchbase", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"http://host:port/, Pool, Bucket"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package couchbase + +import ( + "net/http" + "strings" + "sync" + + couchbase "github.com/couchbase/go-couchbase" + + "github.com/astaxie/beego/session" +) + +var couchbpder = &Provider{} + +// SessionStore store each session +type SessionStore struct { + b *couchbase.Bucket + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Provider couchabse provided +type Provider struct { + maxlifetime int64 + savePath string + pool string + bucket string + b *couchbase.Bucket +} + +// Set value to couchabse session +func (cs *SessionStore) Set(key, value interface{}) error { + cs.lock.Lock() + defer cs.lock.Unlock() + cs.values[key] = value + return nil +} + +// Get value from couchabse session +func (cs *SessionStore) Get(key interface{}) interface{} { + cs.lock.RLock() + defer cs.lock.RUnlock() + if v, ok := cs.values[key]; ok { + return v + } + return nil +} + +// Delete value in couchbase session by given key +func (cs *SessionStore) Delete(key interface{}) error { + cs.lock.Lock() + defer cs.lock.Unlock() + delete(cs.values, key) + return nil +} + +// Flush Clean all values in couchbase session +func (cs *SessionStore) Flush() error { + cs.lock.Lock() + defer cs.lock.Unlock() + cs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID Get couchbase session store id +func (cs *SessionStore) SessionID() string { + return cs.sid +} + +// SessionRelease Write couchbase session with Gob string +func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { + defer cs.b.Close() + + bo, err := session.EncodeGob(cs.values) + if err != nil { + return + } + + cs.b.Set(cs.sid, int(cs.maxlifetime), bo) +} + +func (cp *Provider) getBucket() *couchbase.Bucket { + c, err := couchbase.Connect(cp.savePath) + if err != nil { + return nil + } + + pool, err := c.GetPool(cp.pool) + if err != nil { + return nil + } + + bucket, err := pool.GetBucket(cp.bucket) + if err != nil { + return nil + } + + return bucket +} + +// SessionInit init couchbase session +// savepath like couchbase server REST/JSON URL +// e.g. http://host:port/, Pool, Bucket +func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { + cp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + cp.savePath = configs[0] + } + if len(configs) > 1 { + cp.pool = configs[1] + } + if len(configs) > 2 { + cp.bucket = configs[2] + } + + return nil +} + +// SessionRead read couchbase session by sid +func (cp *Provider) SessionRead(sid string) (session.Store, error) { + cp.b = cp.getBucket() + + var ( + kv map[interface{}]interface{} + err error + doc []byte + ) + + err = cp.b.Get(sid, &doc) + if err != nil { + return nil, err + } else if doc == nil { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(doc) + if err != nil { + return nil, err + } + } + + cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} + return cs, nil +} + +// SessionExist Check couchbase session exist. +// it checkes sid exist or not. +func (cp *Provider) SessionExist(sid string) bool { + cp.b = cp.getBucket() + defer cp.b.Close() + + var doc []byte + + if err := cp.b.Get(sid, &doc); err != nil || doc == nil { + return false + } + return true +} + +// SessionRegenerate remove oldsid and use sid to generate new session +func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + cp.b = cp.getBucket() + + var doc []byte + if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil { + cp.b.Set(sid, int(cp.maxlifetime), "") + } else { + err := cp.b.Delete(oldsid) + if err != nil { + return nil, err + } + _, _ = cp.b.Add(sid, int(cp.maxlifetime), doc) + } + + err := cp.b.Get(sid, &doc) + if err != nil { + return nil, err + } + var kv map[interface{}]interface{} + if doc == nil { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(doc) + if err != nil { + return nil, err + } + } + + cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} + return cs, nil +} + +// SessionDestroy Remove bucket in this couchbase +func (cp *Provider) SessionDestroy(sid string) error { + cp.b = cp.getBucket() + defer cp.b.Close() + + cp.b.Delete(sid) + return nil +} + +// SessionGC Recycle +func (cp *Provider) SessionGC() { +} + +// SessionAll return all active session +func (cp *Provider) SessionAll() int { + return 0 +} + +func init() { + session.Register("couchbase", couchbpder) +} diff --git a/pkg/session/ledis/ledis_session.go b/pkg/session/ledis/ledis_session.go new file mode 100644 index 00000000..ee81df67 --- /dev/null +++ b/pkg/session/ledis/ledis_session.go @@ -0,0 +1,173 @@ +// Package ledis provide session Provider +package ledis + +import ( + "net/http" + "strconv" + "strings" + "sync" + + "github.com/ledisdb/ledisdb/config" + "github.com/ledisdb/ledisdb/ledis" + + "github.com/astaxie/beego/session" +) + +var ( + ledispder = &Provider{} + c *ledis.DB +) + +// SessionStore ledis session store +type SessionStore struct { + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in ledis session +func (ls *SessionStore) Set(key, value interface{}) error { + ls.lock.Lock() + defer ls.lock.Unlock() + ls.values[key] = value + return nil +} + +// Get value in ledis session +func (ls *SessionStore) Get(key interface{}) interface{} { + ls.lock.RLock() + defer ls.lock.RUnlock() + if v, ok := ls.values[key]; ok { + return v + } + return nil +} + +// Delete value in ledis session +func (ls *SessionStore) Delete(key interface{}) error { + ls.lock.Lock() + defer ls.lock.Unlock() + delete(ls.values, key) + return nil +} + +// Flush clear all values in ledis session +func (ls *SessionStore) Flush() error { + ls.lock.Lock() + defer ls.lock.Unlock() + ls.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get ledis session id +func (ls *SessionStore) SessionID() string { + return ls.sid +} + +// SessionRelease save session values to ledis +func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(ls.values) + if err != nil { + return + } + c.Set([]byte(ls.sid), b) + c.Expire([]byte(ls.sid), ls.maxlifetime) +} + +// Provider ledis session provider +type Provider struct { + maxlifetime int64 + savePath string + db int +} + +// SessionInit init ledis session +// savepath like ledis server saveDataPath,pool size +// e.g. 127.0.0.1:6379,100,astaxie +func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { + var err error + lp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) == 1 { + lp.savePath = configs[0] + } else if len(configs) == 2 { + lp.savePath = configs[0] + lp.db, err = strconv.Atoi(configs[1]) + if err != nil { + return err + } + } + cfg := new(config.Config) + cfg.DataDir = lp.savePath + + var ledisInstance *ledis.Ledis + ledisInstance, err = ledis.Open(cfg) + if err != nil { + return err + } + c, err = ledisInstance.Select(lp.db) + return err +} + +// SessionRead read ledis session by sid +func (lp *Provider) SessionRead(sid string) (session.Store, error) { + var ( + kv map[interface{}]interface{} + err error + ) + + kvs, _ := c.Get([]byte(sid)) + + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob(kvs); err != nil { + return nil, err + } + } + + ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} + return ls, nil +} + +// SessionExist check ledis session exist by sid +func (lp *Provider) SessionExist(sid string) bool { + count, _ := c.Exists([]byte(sid)) + return count != 0 +} + +// SessionRegenerate generate new sid for ledis session +func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + count, _ := c.Exists([]byte(sid)) + if count == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Set([]byte(sid), []byte("")) + c.Expire([]byte(sid), lp.maxlifetime) + } else { + data, _ := c.Get([]byte(oldsid)) + c.Set([]byte(sid), data) + c.Expire([]byte(sid), lp.maxlifetime) + } + return lp.SessionRead(sid) +} + +// SessionDestroy delete ledis session by id +func (lp *Provider) SessionDestroy(sid string) error { + c.Del([]byte(sid)) + return nil +} + +// SessionGC Impelment method, no used. +func (lp *Provider) SessionGC() { +} + +// SessionAll return all active session +func (lp *Provider) SessionAll() int { + return 0 +} +func init() { + session.Register("ledis", ledispder) +} diff --git a/pkg/session/memcache/sess_memcache.go b/pkg/session/memcache/sess_memcache.go new file mode 100644 index 00000000..85a2d815 --- /dev/null +++ b/pkg/session/memcache/sess_memcache.go @@ -0,0 +1,230 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memcache for session provider +// +// depend on github.com/bradfitz/gomemcache/memcache +// +// go install github.com/bradfitz/gomemcache/memcache +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/memcache" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("memcache", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:11211"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package memcache + +import ( + "net/http" + "strings" + "sync" + + "github.com/astaxie/beego/session" + + "github.com/bradfitz/gomemcache/memcache" +) + +var mempder = &MemProvider{} +var client *memcache.Client + +// SessionStore memcache session store +type SessionStore struct { + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in memcache session +func (rs *SessionStore) Set(key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in memcache session +func (rs *SessionStore) Get(key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in memcache session +func (rs *SessionStore) Delete(key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in memcache session +func (rs *SessionStore) Flush() error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get memcache session id +func (rs *SessionStore) SessionID() string { + return rs.sid +} + +// SessionRelease save session values to memcache +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(rs.values) + if err != nil { + return + } + item := memcache.Item{Key: rs.sid, Value: b, Expiration: int32(rs.maxlifetime)} + client.Set(&item) +} + +// MemProvider memcache session provider +type MemProvider struct { + maxlifetime int64 + conninfo []string + poolsize int + password string +} + +// SessionInit init memcache session +// savepath like +// e.g. 127.0.0.1:9090 +func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + rp.conninfo = strings.Split(savePath, ";") + client = memcache.New(rp.conninfo...) + return nil +} + +// SessionRead read memcache session by sid +func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { + if client == nil { + if err := rp.connectInit(); err != nil { + return nil, err + } + } + item, err := client.Get(sid) + if err != nil { + if err == memcache.ErrCacheMiss { + rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime} + return rs, nil + } + return nil, err + } + var kv map[interface{}]interface{} + if len(item.Value) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(item.Value) + if err != nil { + return nil, err + } + } + rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check memcache session exist by sid +func (rp *MemProvider) SessionExist(sid string) bool { + if client == nil { + if err := rp.connectInit(); err != nil { + return false + } + } + if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { + return false + } + return true +} + +// SessionRegenerate generate new sid for memcache session +func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + if client == nil { + if err := rp.connectInit(); err != nil { + return nil, err + } + } + var contain []byte + if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + item.Key = sid + item.Value = []byte("") + item.Expiration = int32(rp.maxlifetime) + client.Set(item) + } else { + client.Delete(oldsid) + item.Key = sid + item.Expiration = int32(rp.maxlifetime) + client.Set(item) + contain = item.Value + } + + var kv map[interface{}]interface{} + if len(contain) == 0 { + kv = make(map[interface{}]interface{}) + } else { + var err error + kv, err = session.DecodeGob(contain) + if err != nil { + return nil, err + } + } + + rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionDestroy delete memcache session by id +func (rp *MemProvider) SessionDestroy(sid string) error { + if client == nil { + if err := rp.connectInit(); err != nil { + return err + } + } + + return client.Delete(sid) +} + +func (rp *MemProvider) connectInit() error { + client = memcache.New(rp.conninfo...) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *MemProvider) SessionGC() { +} + +// SessionAll return all activeSession +func (rp *MemProvider) SessionAll() int { + return 0 +} + +func init() { + session.Register("memcache", mempder) +} diff --git a/pkg/session/mysql/sess_mysql.go b/pkg/session/mysql/sess_mysql.go new file mode 100644 index 00000000..301353ab --- /dev/null +++ b/pkg/session/mysql/sess_mysql.go @@ -0,0 +1,228 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mysql for session provider +// +// depends on github.com/go-sql-driver/mysql: +// +// go install github.com/go-sql-driver/mysql +// +// mysql session support need create table as sql: +// CREATE TABLE `session` ( +// `session_key` char(64) NOT NULL, +// `session_data` blob, +// `session_expiry` int(11) unsigned NOT NULL, +// PRIMARY KEY (`session_key`) +// ) ENGINE=MyISAM DEFAULT CHARSET=utf8; +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/mysql" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("mysql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package mysql + +import ( + "database/sql" + "net/http" + "sync" + "time" + + "github.com/astaxie/beego/session" + // import mysql driver + _ "github.com/go-sql-driver/mysql" +) + +var ( + // TableName store the session in MySQL + TableName = "session" + mysqlpder = &Provider{} +) + +// SessionStore mysql session store +type SessionStore struct { + c *sql.DB + sid string + lock sync.RWMutex + values map[interface{}]interface{} +} + +// Set value in mysql session. +// it is temp value in map. +func (st *SessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.values[key] = value + return nil +} + +// Get value from mysql session +func (st *SessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.values[key]; ok { + return v + } + return nil +} + +// Delete value in mysql session +func (st *SessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.values, key) + return nil +} + +// Flush clear all values in mysql session +func (st *SessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get session id of this mysql session store +func (st *SessionStore) SessionID() string { + return st.sid +} + +// SessionRelease save mysql session values to database. +// must call this method to save values to database. +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { + defer st.c.Close() + b, err := session.EncodeGob(st.values) + if err != nil { + return + } + st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?", + b, time.Now().Unix(), st.sid) +} + +// Provider mysql session provider +type Provider struct { + maxlifetime int64 + savePath string +} + +// connect to mysql +func (mp *Provider) connectInit() *sql.DB { + db, e := sql.Open("mysql", mp.savePath) + if e != nil { + return nil + } + return db +} + +// SessionInit init mysql session. +// savepath is the connection string of mysql. +func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { + mp.maxlifetime = maxlifetime + mp.savePath = savePath + return nil +} + +// SessionRead get mysql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { + c := mp.connectInit() + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", + sid, "", time.Now().Unix()) + } + var kv map[interface{}]interface{} + if len(sessiondata) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(sessiondata) + if err != nil { + return nil, err + } + } + rs := &SessionStore{c: c, sid: sid, values: kv} + return rs, nil +} + +// SessionExist check mysql session exist +func (mp *Provider) SessionExist(sid string) bool { + c := mp.connectInit() + defer c.Close() + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) + var sessiondata []byte + err := row.Scan(&sessiondata) + return err != sql.ErrNoRows +} + +// SessionRegenerate generate new sid for mysql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := mp.connectInit() + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix()) + } + c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid) + var kv map[interface{}]interface{} + if len(sessiondata) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(sessiondata) + if err != nil { + return nil, err + } + } + rs := &SessionStore{c: c, sid: sid, values: kv} + return rs, nil +} + +// SessionDestroy delete mysql session by sid +func (mp *Provider) SessionDestroy(sid string) error { + c := mp.connectInit() + c.Exec("DELETE FROM "+TableName+" where session_key=?", sid) + c.Close() + return nil +} + +// SessionGC delete expired values in mysql session +func (mp *Provider) SessionGC() { + c := mp.connectInit() + c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) + c.Close() +} + +// SessionAll count values in mysql session +func (mp *Provider) SessionAll() int { + c := mp.connectInit() + defer c.Close() + var total int + err := c.QueryRow("SELECT count(*) as num from " + TableName).Scan(&total) + if err != nil { + return 0 + } + return total +} + +func init() { + session.Register("mysql", mysqlpder) +} diff --git a/pkg/session/postgres/sess_postgresql.go b/pkg/session/postgres/sess_postgresql.go new file mode 100644 index 00000000..0b8b9645 --- /dev/null +++ b/pkg/session/postgres/sess_postgresql.go @@ -0,0 +1,243 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres for session provider +// +// depends on github.com/lib/pq: +// +// go install github.com/lib/pq +// +// +// needs this table in your database: +// +// CREATE TABLE session ( +// session_key char(64) NOT NULL, +// session_data bytea, +// session_expiry timestamp NOT NULL, +// CONSTRAINT session_key PRIMARY KEY(session_key) +// ); +// +// will be activated with these settings in app.conf: +// +// SessionOn = true +// SessionProvider = postgresql +// SessionSavePath = "user=a password=b dbname=c sslmode=disable" +// SessionName = session +// +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/postgresql" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("postgresql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"user=pqgotest dbname=pqgotest sslmode=verify-full"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package postgres + +import ( + "database/sql" + "net/http" + "sync" + "time" + + "github.com/astaxie/beego/session" + // import postgresql Driver + _ "github.com/lib/pq" +) + +var postgresqlpder = &Provider{} + +// SessionStore postgresql session store +type SessionStore struct { + c *sql.DB + sid string + lock sync.RWMutex + values map[interface{}]interface{} +} + +// Set value in postgresql session. +// it is temp value in map. +func (st *SessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.values[key] = value + return nil +} + +// Get value from postgresql session +func (st *SessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.values[key]; ok { + return v + } + return nil +} + +// Delete value in postgresql session +func (st *SessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.values, key) + return nil +} + +// Flush clear all values in postgresql session +func (st *SessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get session id of this postgresql session store +func (st *SessionStore) SessionID() string { + return st.sid +} + +// SessionRelease save postgresql session values to database. +// must call this method to save values to database. +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { + defer st.c.Close() + b, err := session.EncodeGob(st.values) + if err != nil { + return + } + st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3", + b, time.Now().Format(time.RFC3339), st.sid) + +} + +// Provider postgresql session provider +type Provider struct { + maxlifetime int64 + savePath string +} + +// connect to postgresql +func (mp *Provider) connectInit() *sql.DB { + db, e := sql.Open("postgres", mp.savePath) + if e != nil { + return nil + } + return db +} + +// SessionInit init postgresql session. +// savepath is the connection string of postgresql. +func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { + mp.maxlifetime = maxlifetime + mp.savePath = savePath + return nil +} + +// SessionRead get postgresql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { + c := mp.connectInit() + row := c.QueryRow("select session_data from session where session_key=$1", sid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + _, err = c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)", + sid, "", time.Now().Format(time.RFC3339)) + + if err != nil { + return nil, err + } + } else if err != nil { + return nil, err + } + + var kv map[interface{}]interface{} + if len(sessiondata) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(sessiondata) + if err != nil { + return nil, err + } + } + rs := &SessionStore{c: c, sid: sid, values: kv} + return rs, nil +} + +// SessionExist check postgresql session exist +func (mp *Provider) SessionExist(sid string) bool { + c := mp.connectInit() + defer c.Close() + row := c.QueryRow("select session_data from session where session_key=$1", sid) + var sessiondata []byte + err := row.Scan(&sessiondata) + return err != sql.ErrNoRows +} + +// SessionRegenerate generate new sid for postgresql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := mp.connectInit() + row := c.QueryRow("select session_data from session where session_key=$1", oldsid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)", + oldsid, "", time.Now().Format(time.RFC3339)) + } + c.Exec("update session set session_key=$1 where session_key=$2", sid, oldsid) + var kv map[interface{}]interface{} + if len(sessiondata) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(sessiondata) + if err != nil { + return nil, err + } + } + rs := &SessionStore{c: c, sid: sid, values: kv} + return rs, nil +} + +// SessionDestroy delete postgresql session by sid +func (mp *Provider) SessionDestroy(sid string) error { + c := mp.connectInit() + c.Exec("DELETE FROM session where session_key=$1", sid) + c.Close() + return nil +} + +// SessionGC delete expired values in postgresql session +func (mp *Provider) SessionGC() { + c := mp.connectInit() + c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime) + c.Close() +} + +// SessionAll count values in postgresql session +func (mp *Provider) SessionAll() int { + c := mp.connectInit() + defer c.Close() + var total int + err := c.QueryRow("SELECT count(*) as num from session").Scan(&total) + if err != nil { + return 0 + } + return total +} + +func init() { + session.Register("postgresql", postgresqlpder) +} diff --git a/pkg/session/redis/sess_redis.go b/pkg/session/redis/sess_redis.go new file mode 100644 index 00000000..5c382d61 --- /dev/null +++ b/pkg/session/redis/sess_redis.go @@ -0,0 +1,261 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/gomodule/redigo/redis +// +// go install github.com/gomodule/redigo/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis + +import ( + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/astaxie/beego/session" + + "github.com/gomodule/redigo/redis" +) + +var redispder = &Provider{} + +// MaxPoolSize redis max pool size +var MaxPoolSize = 100 + +// SessionStore redis session store +type SessionStore struct { + p *redis.Pool + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in redis session +func (rs *SessionStore) Set(key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in redis session +func (rs *SessionStore) Get(key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in redis session +func (rs *SessionStore) Delete(key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in redis session +func (rs *SessionStore) Flush() error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get redis session id +func (rs *SessionStore) SessionID() string { + return rs.sid +} + +// SessionRelease save session values to redis +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(rs.values) + if err != nil { + return + } + c := rs.p.Get() + defer c.Close() + c.Do("SETEX", rs.sid, rs.maxlifetime, string(b)) +} + +// Provider redis session provider +type Provider struct { + maxlifetime int64 + savePath string + poolsize int + password string + dbNum int + poollist *redis.Pool +} + +// SessionInit init redis session +// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second +// e.g. 127.0.0.1:6379,100,astaxie,0,30 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + rp.savePath = configs[0] + } + if len(configs) > 1 { + poolsize, err := strconv.Atoi(configs[1]) + if err != nil || poolsize < 0 { + rp.poolsize = MaxPoolSize + } else { + rp.poolsize = poolsize + } + } else { + rp.poolsize = MaxPoolSize + } + if len(configs) > 2 { + rp.password = configs[2] + } + if len(configs) > 3 { + dbnum, err := strconv.Atoi(configs[3]) + if err != nil || dbnum < 0 { + rp.dbNum = 0 + } else { + rp.dbNum = dbnum + } + } else { + rp.dbNum = 0 + } + var idleTimeout time.Duration = 0 + if len(configs) > 4 { + timeout, err := strconv.Atoi(configs[4]) + if err == nil && timeout > 0 { + idleTimeout = time.Duration(timeout) * time.Second + } + } + rp.poollist = &redis.Pool{ + Dial: func() (redis.Conn, error) { + c, err := redis.Dial("tcp", rp.savePath) + if err != nil { + return nil, err + } + if rp.password != "" { + if _, err = c.Do("AUTH", rp.password); err != nil { + c.Close() + return nil, err + } + } + // some redis proxy such as twemproxy is not support select command + if rp.dbNum > 0 { + _, err = c.Do("SELECT", rp.dbNum) + if err != nil { + c.Close() + return nil, err + } + } + return c, err + }, + MaxIdle: rp.poolsize, + } + + rp.poollist.IdleTimeout = idleTimeout + + return rp.poollist.Get().Err() +} + +// SessionRead read redis session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + c := rp.poollist.Get() + defer c.Close() + + var kv map[interface{}]interface{} + + kvs, err := redis.String(c.Do("GET", sid)) + if err != nil && err != redis.ErrNil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { + return nil, err + } + } + + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check redis session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + c := rp.poollist.Get() + defer c.Close() + + if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 { + return false + } + return true +} + +// SessionRegenerate generate new sid for redis session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := rp.poollist.Get() + defer c.Close() + + if existed, _ := redis.Int(c.Do("EXISTS", oldsid)); existed == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Do("SET", sid, "", "EX", rp.maxlifetime) + } else { + c.Do("RENAME", oldsid, sid) + c.Do("EXPIRE", sid, rp.maxlifetime) + } + return rp.SessionRead(sid) +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + c := rp.poollist.Get() + defer c.Close() + + c.Do("DEL", sid) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return 0 +} + +func init() { + session.Register("redis", redispder) +} diff --git a/pkg/session/redis_cluster/redis_cluster.go b/pkg/session/redis_cluster/redis_cluster.go new file mode 100644 index 00000000..2fe300df --- /dev/null +++ b/pkg/session/redis_cluster/redis_cluster.go @@ -0,0 +1,220 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_cluster" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis_cluster +import ( + "net/http" + "strconv" + "strings" + "sync" + "github.com/astaxie/beego/session" + rediss "github.com/go-redis/redis" + "time" +) + +var redispder = &Provider{} + +// MaxPoolSize redis_cluster max pool size +var MaxPoolSize = 1000 + +// SessionStore redis_cluster session store +type SessionStore struct { + p *rediss.ClusterClient + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in redis_cluster session +func (rs *SessionStore) Set(key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in redis_cluster session +func (rs *SessionStore) Get(key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in redis_cluster session +func (rs *SessionStore) Delete(key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in redis_cluster session +func (rs *SessionStore) Flush() error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get redis_cluster session id +func (rs *SessionStore) SessionID() string { + return rs.sid +} + +// SessionRelease save session values to redis_cluster +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(rs.values) + if err != nil { + return + } + c := rs.p + c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime) * time.Second) +} + +// Provider redis_cluster session provider +type Provider struct { + maxlifetime int64 + savePath string + poolsize int + password string + dbNum int + poollist *rediss.ClusterClient +} + +// SessionInit init redis_cluster session +// savepath like redis server addr,pool size,password,dbnum +// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + rp.savePath = configs[0] + } + if len(configs) > 1 { + poolsize, err := strconv.Atoi(configs[1]) + if err != nil || poolsize < 0 { + rp.poolsize = MaxPoolSize + } else { + rp.poolsize = poolsize + } + } else { + rp.poolsize = MaxPoolSize + } + if len(configs) > 2 { + rp.password = configs[2] + } + if len(configs) > 3 { + dbnum, err := strconv.Atoi(configs[3]) + if err != nil || dbnum < 0 { + rp.dbNum = 0 + } else { + rp.dbNum = dbnum + } + } else { + rp.dbNum = 0 + } + + rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ + Addrs: strings.Split(rp.savePath, ";"), + Password: rp.password, + PoolSize: rp.poolsize, + }) + return rp.poollist.Ping().Err() +} + +// SessionRead read redis_cluster session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + var kv map[interface{}]interface{} + kvs, err := rp.poollist.Get(sid).Result() + if err != nil && err != rediss.Nil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { + return nil, err + } + } + + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check redis_cluster session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + c := rp.poollist + if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { + return false + } + return true +} + +// SessionRegenerate generate new sid for redis_cluster session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := rp.poollist + + if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Set(sid, "", time.Duration(rp.maxlifetime) * time.Second) + } else { + c.Rename(oldsid, sid) + c.Expire(sid, time.Duration(rp.maxlifetime) * time.Second) + } + return rp.SessionRead(sid) +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + c := rp.poollist + c.Del(sid) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return 0 +} + +func init() { + session.Register("redis_cluster", redispder) +} diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel.go b/pkg/session/redis_sentinel/sess_redis_sentinel.go new file mode 100644 index 00000000..6ecb2977 --- /dev/null +++ b/pkg/session/redis_sentinel/sess_redis_sentinel.go @@ -0,0 +1,234 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_sentinel" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``) +// go globalSessions.GC() +// } +// +// more detail about params: please check the notes on the function SessionInit in this package +package redis_sentinel + +import ( + "github.com/astaxie/beego/session" + "github.com/go-redis/redis" + "net/http" + "strconv" + "strings" + "sync" + "time" +) + +var redispder = &Provider{} + +// DefaultPoolSize redis_sentinel default pool size +var DefaultPoolSize = 100 + +// SessionStore redis_sentinel session store +type SessionStore struct { + p *redis.Client + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in redis_sentinel session +func (rs *SessionStore) Set(key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in redis_sentinel session +func (rs *SessionStore) Get(key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in redis_sentinel session +func (rs *SessionStore) Delete(key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in redis_sentinel session +func (rs *SessionStore) Flush() error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get redis_sentinel session id +func (rs *SessionStore) SessionID() string { + return rs.sid +} + +// SessionRelease save session values to redis_sentinel +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(rs.values) + if err != nil { + return + } + c := rs.p + c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) +} + +// Provider redis_sentinel session provider +type Provider struct { + maxlifetime int64 + savePath string + poolsize int + password string + dbNum int + poollist *redis.Client + masterName string +} + +// SessionInit init redis_sentinel session +// savepath like redis sentinel addr,pool size,password,dbnum,masterName +// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + rp.savePath = configs[0] + } + if len(configs) > 1 { + poolsize, err := strconv.Atoi(configs[1]) + if err != nil || poolsize < 0 { + rp.poolsize = DefaultPoolSize + } else { + rp.poolsize = poolsize + } + } else { + rp.poolsize = DefaultPoolSize + } + if len(configs) > 2 { + rp.password = configs[2] + } + if len(configs) > 3 { + dbnum, err := strconv.Atoi(configs[3]) + if err != nil || dbnum < 0 { + rp.dbNum = 0 + } else { + rp.dbNum = dbnum + } + } else { + rp.dbNum = 0 + } + if len(configs) > 4 { + if configs[4] != "" { + rp.masterName = configs[4] + } else { + rp.masterName = "mymaster" + } + } else { + rp.masterName = "mymaster" + } + + rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ + SentinelAddrs: strings.Split(rp.savePath, ";"), + Password: rp.password, + PoolSize: rp.poolsize, + DB: rp.dbNum, + MasterName: rp.masterName, + }) + + return rp.poollist.Ping().Err() +} + +// SessionRead read redis_sentinel session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + var kv map[interface{}]interface{} + kvs, err := rp.poollist.Get(sid).Result() + if err != nil && err != redis.Nil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { + return nil, err + } + } + + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check redis_sentinel session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + c := rp.poollist + if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { + return false + } + return true +} + +// SessionRegenerate generate new sid for redis_sentinel session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := rp.poollist + + if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second) + } else { + c.Rename(oldsid, sid) + c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) + } + return rp.SessionRead(sid) +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + c := rp.poollist + c.Del(sid) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return 0 +} + +func init() { + session.Register("redis_sentinel", redispder) +} diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel_test.go b/pkg/session/redis_sentinel/sess_redis_sentinel_test.go new file mode 100644 index 00000000..fd4155c6 --- /dev/null +++ b/pkg/session/redis_sentinel/sess_redis_sentinel_test.go @@ -0,0 +1,90 @@ +package redis_sentinel + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/astaxie/beego/session" +) + +func TestRedisSentinel(t *testing.T) { + sessionConfig := &session.ManagerConfig{ + CookieName: "gosessionid", + EnableSetCookie: true, + Gclifetime: 3600, + Maxlifetime: 3600, + Secure: false, + CookieLifeTime: 3600, + ProviderConfig: "127.0.0.1:6379,100,,0,master", + } + globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) + if e != nil { + t.Log(e) + return + } + //todo test if e==nil + go globalSessions.GC() + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start failed:", err) + } + defer sess.SessionRelease(w) + + // SET AND GET + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set username failed:", err) + } + username := sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + + // DELETE + err = sess.Delete("username") + if err != nil { + t.Fatal("delete username failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("delete username failed") + } + + // FLUSH + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set failed:", err) + } + err = sess.Set("password", "1qaz2wsx") + if err != nil { + t.Fatal("set failed:", err) + } + username = sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + password := sess.Get("password") + if password != "1qaz2wsx" { + t.Fatal("get password failed") + } + err = sess.Flush() + if err != nil { + t.Fatal("flush failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("flush failed") + } + password = sess.Get("password") + if password != nil { + t.Fatal("flush failed") + } + + sess.SessionRelease(w) + +} diff --git a/pkg/session/sess_cookie.go b/pkg/session/sess_cookie.go new file mode 100644 index 00000000..6ad5debc --- /dev/null +++ b/pkg/session/sess_cookie.go @@ -0,0 +1,180 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/json" + "net/http" + "net/url" + "sync" +) + +var cookiepder = &CookieProvider{} + +// CookieSessionStore Cookie SessionStore +type CookieSessionStore struct { + sid string + values map[interface{}]interface{} // session data + lock sync.RWMutex +} + +// Set value to cookie session. +// the value are encoded as gob with hash block string. +func (st *CookieSessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.values[key] = value + return nil +} + +// Get value from cookie session +func (st *CookieSessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.values[key]; ok { + return v + } + return nil +} + +// Delete value in cookie session +func (st *CookieSessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.values, key) + return nil +} + +// Flush Clean all values in cookie session +func (st *CookieSessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID Return id of this cookie session +func (st *CookieSessionStore) SessionID() string { + return st.sid +} + +// SessionRelease Write cookie session to http response cookie +func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { + st.lock.Lock() + encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values) + st.lock.Unlock() + if err == nil { + cookie := &http.Cookie{Name: cookiepder.config.CookieName, + Value: url.QueryEscape(encodedCookie), + Path: "/", + HttpOnly: true, + Secure: cookiepder.config.Secure, + MaxAge: cookiepder.config.Maxage} + http.SetCookie(w, cookie) + } +} + +type cookieConfig struct { + SecurityKey string `json:"securityKey"` + BlockKey string `json:"blockKey"` + SecurityName string `json:"securityName"` + CookieName string `json:"cookieName"` + Secure bool `json:"secure"` + Maxage int `json:"maxage"` +} + +// CookieProvider Cookie session provider +type CookieProvider struct { + maxlifetime int64 + config *cookieConfig + block cipher.Block +} + +// SessionInit Init cookie session provider with max lifetime and config json. +// maxlifetime is ignored. +// json config: +// securityKey - hash string +// blockKey - gob encode hash string. it's saved as aes crypto. +// securityName - recognized name in encoded cookie string +// cookieName - cookie name +// maxage - cookie max life time. +func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { + pder.config = &cookieConfig{} + err := json.Unmarshal([]byte(config), pder.config) + if err != nil { + return err + } + if pder.config.BlockKey == "" { + pder.config.BlockKey = string(generateRandomKey(16)) + } + if pder.config.SecurityName == "" { + pder.config.SecurityName = string(generateRandomKey(20)) + } + pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey)) + if err != nil { + return err + } + pder.maxlifetime = maxlifetime + return nil +} + +// SessionRead Get SessionStore in cooke. +// decode cooke string to map and put into SessionStore with sid. +func (pder *CookieProvider) SessionRead(sid string) (Store, error) { + maps, _ := decodeCookie(pder.block, + pder.config.SecurityKey, + pder.config.SecurityName, + sid, pder.maxlifetime) + if maps == nil { + maps = make(map[interface{}]interface{}) + } + rs := &CookieSessionStore{sid: sid, values: maps} + return rs, nil +} + +// SessionExist Cookie session is always existed +func (pder *CookieProvider) SessionExist(sid string) bool { + return true +} + +// SessionRegenerate Implement method, no used. +func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + return nil, nil +} + +// SessionDestroy Implement method, no used. +func (pder *CookieProvider) SessionDestroy(sid string) error { + return nil +} + +// SessionGC Implement method, no used. +func (pder *CookieProvider) SessionGC() { +} + +// SessionAll Implement method, return 0. +func (pder *CookieProvider) SessionAll() int { + return 0 +} + +// SessionUpdate Implement method, no used. +func (pder *CookieProvider) SessionUpdate(sid string) error { + return nil +} + +func init() { + Register("cookie", cookiepder) +} diff --git a/pkg/session/sess_cookie_test.go b/pkg/session/sess_cookie_test.go new file mode 100644 index 00000000..b6726005 --- /dev/null +++ b/pkg/session/sess_cookie_test.go @@ -0,0 +1,105 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, err := NewManager("cookie", conf) + if err != nil { + t.Fatal("init cookie session err", err) + } + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + sess.SessionRelease(w) + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} + +func TestDestorySessionCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, err := NewManager("cookie", conf) + if err != nil { + t.Fatal("init cookie session err", err) + } + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + session, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start err,", err) + } + + // request again ,will get same sesssion id . + r1, _ := http.NewRequest("GET", "/", nil) + r1.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + w = httptest.NewRecorder() + newSession, err := globalSessions.SessionStart(w, r1) + if err != nil { + t.Fatal("session start err,", err) + } + if newSession.SessionID() != session.SessionID() { + t.Fatal("get cookie session id is not the same again.") + } + + // After destroy session , will get a new session id . + globalSessions.SessionDestroy(w, r1) + r2, _ := http.NewRequest("GET", "/", nil) + r2.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + + w = httptest.NewRecorder() + newSession, err = globalSessions.SessionStart(w, r2) + if err != nil { + t.Fatal("session start error") + } + if newSession.SessionID() == session.SessionID() { + t.Fatal("after destroy session and reqeust again ,get cookie session id is same.") + } +} diff --git a/pkg/session/sess_file.go b/pkg/session/sess_file.go new file mode 100644 index 00000000..47ad54a7 --- /dev/null +++ b/pkg/session/sess_file.go @@ -0,0 +1,315 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "errors" + "fmt" + "io/ioutil" + "net/http" + "os" + "path" + "path/filepath" + "strings" + "sync" + "time" +) + +var ( + filepder = &FileProvider{} + gcmaxlifetime int64 +) + +// FileSessionStore File session store +type FileSessionStore struct { + sid string + lock sync.RWMutex + values map[interface{}]interface{} +} + +// Set value to file session +func (fs *FileSessionStore) Set(key, value interface{}) error { + fs.lock.Lock() + defer fs.lock.Unlock() + fs.values[key] = value + return nil +} + +// Get value from file session +func (fs *FileSessionStore) Get(key interface{}) interface{} { + fs.lock.RLock() + defer fs.lock.RUnlock() + if v, ok := fs.values[key]; ok { + return v + } + return nil +} + +// Delete value in file session by given key +func (fs *FileSessionStore) Delete(key interface{}) error { + fs.lock.Lock() + defer fs.lock.Unlock() + delete(fs.values, key) + return nil +} + +// Flush Clean all values in file session +func (fs *FileSessionStore) Flush() error { + fs.lock.Lock() + defer fs.lock.Unlock() + fs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID Get file session store id +func (fs *FileSessionStore) SessionID() string { + return fs.sid +} + +// SessionRelease Write file session to local file with Gob string +func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { + filepder.lock.Lock() + defer filepder.lock.Unlock() + b, err := EncodeGob(fs.values) + if err != nil { + SLogger.Println(err) + return + } + _, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) + var f *os.File + if err == nil { + f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777) + if err != nil { + SLogger.Println(err) + return + } + } else if os.IsNotExist(err) { + f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) + if err != nil { + SLogger.Println(err) + return + } + } else { + return + } + f.Truncate(0) + f.Seek(0, 0) + f.Write(b) + f.Close() +} + +// FileProvider File session provider +type FileProvider struct { + lock sync.RWMutex + maxlifetime int64 + savePath string +} + +// SessionInit Init file session provider. +// savePath sets the session files path. +func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { + fp.maxlifetime = maxlifetime + fp.savePath = savePath + return nil +} + +// SessionRead Read file session by sid. +// if file is not exist, create it. +// the file path is generated from sid string. +func (fp *FileProvider) SessionRead(sid string) (Store, error) { + invalidChars := "./" + if strings.ContainsAny(sid, invalidChars) { + return nil, errors.New("the sid shouldn't have following characters: " + invalidChars) + } + if len(sid) < 2 { + return nil, errors.New("length of the sid is less than 2") + } + filepder.lock.Lock() + defer filepder.lock.Unlock() + + err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0755) + if err != nil { + SLogger.Println(err.Error()) + } + _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + var f *os.File + if err == nil { + f, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_RDWR, 0777) + } else if os.IsNotExist(err) { + f, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + } else { + return nil, err + } + + defer f.Close() + + os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now()) + var kv map[interface{}]interface{} + b, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + if len(b) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = DecodeGob(b) + if err != nil { + return nil, err + } + } + + ss := &FileSessionStore{sid: sid, values: kv} + return ss, nil +} + +// SessionExist Check file session exist. +// it checks the file named from sid exist or not. +func (fp *FileProvider) SessionExist(sid string) bool { + filepder.lock.Lock() + defer filepder.lock.Unlock() + + if len(sid) < 2 { + SLogger.Println("min length of session id is 2", sid) + return false + } + + _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + return err == nil +} + +// SessionDestroy Remove all files in this save path +func (fp *FileProvider) SessionDestroy(sid string) error { + filepder.lock.Lock() + defer filepder.lock.Unlock() + os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + return nil +} + +// SessionGC Recycle files in save path +func (fp *FileProvider) SessionGC() { + filepder.lock.Lock() + defer filepder.lock.Unlock() + + gcmaxlifetime = fp.maxlifetime + filepath.Walk(fp.savePath, gcpath) +} + +// SessionAll Get active file session number. +// it walks save path to count files. +func (fp *FileProvider) SessionAll() int { + a := &activeSession{} + err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error { + return a.visit(path, f, err) + }) + if err != nil { + SLogger.Printf("filepath.Walk() returned %v\n", err) + return 0 + } + return a.total +} + +// SessionRegenerate Generate new sid for file session. +// it delete old file and create new file named from new sid. +func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + filepder.lock.Lock() + defer filepder.lock.Unlock() + + oldPath := path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])) + oldSidFile := path.Join(oldPath, oldsid) + newPath := path.Join(fp.savePath, string(sid[0]), string(sid[1])) + newSidFile := path.Join(newPath, sid) + + // new sid file is exist + _, err := os.Stat(newSidFile) + if err == nil { + return nil, fmt.Errorf("newsid %s exist", newSidFile) + } + + err = os.MkdirAll(newPath, 0755) + if err != nil { + SLogger.Println(err.Error()) + } + + // if old sid file exist + // 1.read and parse file content + // 2.write content to new sid file + // 3.remove old sid file, change new sid file atime and ctime + // 4.return FileSessionStore + _, err = os.Stat(oldSidFile) + if err == nil { + b, err := ioutil.ReadFile(oldSidFile) + if err != nil { + return nil, err + } + + var kv map[interface{}]interface{} + if len(b) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = DecodeGob(b) + if err != nil { + return nil, err + } + } + + ioutil.WriteFile(newSidFile, b, 0777) + os.Remove(oldSidFile) + os.Chtimes(newSidFile, time.Now(), time.Now()) + ss := &FileSessionStore{sid: sid, values: kv} + return ss, nil + } + + // if old sid file not exist, just create new sid file and return + newf, err := os.Create(newSidFile) + if err != nil { + return nil, err + } + newf.Close() + ss := &FileSessionStore{sid: sid, values: make(map[interface{}]interface{})} + return ss, nil +} + +// remove file in save path if expired +func gcpath(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + if (info.ModTime().Unix() + gcmaxlifetime) < time.Now().Unix() { + os.Remove(path) + } + return nil +} + +type activeSession struct { + total int +} + +func (as *activeSession) visit(paths string, f os.FileInfo, err error) error { + if err != nil { + return err + } + if f.IsDir() { + return nil + } + as.total = as.total + 1 + return nil +} + +func init() { + Register("file", filepder) +} diff --git a/pkg/session/sess_file_test.go b/pkg/session/sess_file_test.go new file mode 100644 index 00000000..0cf021db --- /dev/null +++ b/pkg/session/sess_file_test.go @@ -0,0 +1,387 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "fmt" + "os" + "sync" + "testing" + "time" +) + +const sid = "Session_id" +const sidNew = "Session_id_new" +const sessionPath = "./_session_runtime" + +var ( + mutex sync.Mutex +) + +func TestFileProvider_SessionInit(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + if fp.maxlifetime != 180 { + t.Error() + } + + if fp.savePath != sessionPath { + t.Error() + } +} + +func TestFileProvider_SessionExist(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionExist2(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + if fp.SessionExist("") { + t.Error() + } + + if fp.SessionExist("1") { + t.Error() + } +} + +func TestFileProvider_SessionRead(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + _ = s.Set("sessionValue", 18975) + v := s.Get("sessionValue") + + if v.(int) != 18975 { + t.Error() + } +} + +func TestFileProvider_SessionRead1(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead("") + if err == nil { + t.Error(err) + } + + _, err = fp.SessionRead("1") + if err == nil { + t.Error(err) + } +} + +func TestFileProvider_SessionAll(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 546 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + if fp.SessionAll() != sessionCount { + t.Error() + } +} + +func TestFileProvider_SessionRegenerate(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + _, err = fp.SessionRegenerate(sid, sidNew) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } + + if !fp.SessionExist(sidNew) { + t.Error() + } +} + +func TestFileProvider_SessionDestroy(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + err = fp.SessionDestroy(sid) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionGC(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(1, sessionPath) + + sessionCount := 412 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + time.Sleep(2 * time.Second) + + fp.SessionGC() + if fp.SessionAll() != 0 { + t.Error() + } +} + +func TestFileSessionStore_Set(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + err := s.Set(i, i) + if err != nil { + t.Error(err) + } + } +} + +func TestFileSessionStore_Get(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + + v := s.Get(i) + if v.(int) != i { + t.Error() + } + } +} + +func TestFileSessionStore_Delete(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, _ := fp.SessionRead(sid) + s.Set("1", 1) + + if s.Get("1") == nil { + t.Error() + } + + s.Delete("1") + + if s.Get("1") != nil { + t.Error() + } +} + +func TestFileSessionStore_Flush(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + } + + _ = s.Flush() + + for i := 1; i <= sessionCount; i++ { + if s.Get(i) != nil { + t.Error() + } + } +} + +func TestFileSessionStore_SessionID(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) { + t.Error(err) + } + } +} + +func TestFileSessionStore_SessionRelease(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + filepder.savePath = sessionPath + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + + + s.Set(i,i) + s.SessionRelease(nil) + } + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + + if s.Get(i).(int) != i { + t.Error() + } + } +} \ No newline at end of file diff --git a/pkg/session/sess_mem.go b/pkg/session/sess_mem.go new file mode 100644 index 00000000..64d8b056 --- /dev/null +++ b/pkg/session/sess_mem.go @@ -0,0 +1,196 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "container/list" + "net/http" + "sync" + "time" +) + +var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} + +// MemSessionStore memory session store. +// it saved sessions in a map in memory. +type MemSessionStore struct { + sid string //session id + timeAccessed time.Time //last access time + value map[interface{}]interface{} //session store + lock sync.RWMutex +} + +// Set value to memory session +func (st *MemSessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.value[key] = value + return nil +} + +// Get value from memory session by key +func (st *MemSessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.value[key]; ok { + return v + } + return nil +} + +// Delete in memory session by key +func (st *MemSessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.value, key) + return nil +} + +// Flush clear all values in memory session +func (st *MemSessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.value = make(map[interface{}]interface{}) + return nil +} + +// SessionID get this id of memory session store +func (st *MemSessionStore) SessionID() string { + return st.sid +} + +// SessionRelease Implement method, no used. +func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { +} + +// MemProvider Implement the provider interface +type MemProvider struct { + lock sync.RWMutex // locker + sessions map[string]*list.Element // map in memory + list *list.List // for gc + maxlifetime int64 + savePath string +} + +// SessionInit init memory session +func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + pder.maxlifetime = maxlifetime + pder.savePath = savePath + return nil +} + +// SessionRead get memory session store by sid +func (pder *MemProvider) SessionRead(sid string) (Store, error) { + pder.lock.RLock() + if element, ok := pder.sessions[sid]; ok { + go pder.SessionUpdate(sid) + pder.lock.RUnlock() + return element.Value.(*MemSessionStore), nil + } + pder.lock.RUnlock() + pder.lock.Lock() + newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} + element := pder.list.PushFront(newsess) + pder.sessions[sid] = element + pder.lock.Unlock() + return newsess, nil +} + +// SessionExist check session store exist in memory session by sid +func (pder *MemProvider) SessionExist(sid string) bool { + pder.lock.RLock() + defer pder.lock.RUnlock() + if _, ok := pder.sessions[sid]; ok { + return true + } + return false +} + +// SessionRegenerate generate new sid for session store in memory session +func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + pder.lock.RLock() + if element, ok := pder.sessions[oldsid]; ok { + go pder.SessionUpdate(oldsid) + pder.lock.RUnlock() + pder.lock.Lock() + element.Value.(*MemSessionStore).sid = sid + pder.sessions[sid] = element + delete(pder.sessions, oldsid) + pder.lock.Unlock() + return element.Value.(*MemSessionStore), nil + } + pder.lock.RUnlock() + pder.lock.Lock() + newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} + element := pder.list.PushFront(newsess) + pder.sessions[sid] = element + pder.lock.Unlock() + return newsess, nil +} + +// SessionDestroy delete session store in memory session by id +func (pder *MemProvider) SessionDestroy(sid string) error { + pder.lock.Lock() + defer pder.lock.Unlock() + if element, ok := pder.sessions[sid]; ok { + delete(pder.sessions, sid) + pder.list.Remove(element) + return nil + } + return nil +} + +// SessionGC clean expired session stores in memory session +func (pder *MemProvider) SessionGC() { + pder.lock.RLock() + for { + element := pder.list.Back() + if element == nil { + break + } + if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() { + pder.lock.RUnlock() + pder.lock.Lock() + pder.list.Remove(element) + delete(pder.sessions, element.Value.(*MemSessionStore).sid) + pder.lock.Unlock() + pder.lock.RLock() + } else { + break + } + } + pder.lock.RUnlock() +} + +// SessionAll get count number of memory session +func (pder *MemProvider) SessionAll() int { + return pder.list.Len() +} + +// SessionUpdate expand time of session store by id in memory session +func (pder *MemProvider) SessionUpdate(sid string) error { + pder.lock.Lock() + defer pder.lock.Unlock() + if element, ok := pder.sessions[sid]; ok { + element.Value.(*MemSessionStore).timeAccessed = time.Now() + pder.list.MoveToFront(element) + return nil + } + return nil +} + +func init() { + Register("memory", mempder) +} diff --git a/pkg/session/sess_mem_test.go b/pkg/session/sess_mem_test.go new file mode 100644 index 00000000..2e8934b8 --- /dev/null +++ b/pkg/session/sess_mem_test.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestMem(t *testing.T) { + config := `{"cookieName":"gosessionid","gclifetime":10, "enableSetCookie":true}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, _ := NewManager("memory", conf) + go globalSessions.GC() + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } + defer sess.SessionRelease(w) + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} diff --git a/pkg/session/sess_test.go b/pkg/session/sess_test.go new file mode 100644 index 00000000..906abec2 --- /dev/null +++ b/pkg/session/sess_test.go @@ -0,0 +1,131 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "crypto/aes" + "encoding/json" + "testing" +) + +func Test_gob(t *testing.T) { + a := make(map[interface{}]interface{}) + a["username"] = "astaxie" + a[12] = 234 + a["user"] = User{"asta", "xie"} + b, err := EncodeGob(a) + if err != nil { + t.Error(err) + } + c, err := DecodeGob(b) + if err != nil { + t.Error(err) + } + if len(c) == 0 { + t.Error("decodeGob empty") + } + if c["username"] != "astaxie" { + t.Error("decode string error") + } + if c[12] != 234 { + t.Error("decode int error") + } + if c["user"].(User).Username != "asta" { + t.Error("decode struct error") + } +} + +type User struct { + Username string + NickName string +} + +func TestGenerate(t *testing.T) { + str := generateRandomKey(20) + if len(str) != 20 { + t.Fatal("generate length is not equal to 20") + } +} + +func TestCookieEncodeDecode(t *testing.T) { + hashKey := "testhashKey" + blockkey := generateRandomKey(16) + block, err := aes.NewCipher(blockkey) + if err != nil { + t.Fatal("NewCipher:", err) + } + securityName := string(generateRandomKey(20)) + val := make(map[interface{}]interface{}) + val["name"] = "astaxie" + val["gender"] = "male" + str, err := encodeCookie(block, hashKey, securityName, val) + if err != nil { + t.Fatal("encodeCookie:", err) + } + dst, err := decodeCookie(block, hashKey, securityName, str, 3600) + if err != nil { + t.Fatal("decodeCookie", err) + } + if dst["name"] != "astaxie" { + t.Fatal("dst get map error") + } + if dst["gender"] != "male" { + t.Fatal("dst get map error") + } +} + +func TestParseConfig(t *testing.T) { + s := `{"cookieName":"gosessionid","gclifetime":3600}` + cf := new(ManagerConfig) + cf.EnableSetCookie = true + err := json.Unmarshal([]byte(s), cf) + if err != nil { + t.Fatal("parse json error,", err) + } + if cf.CookieName != "gosessionid" { + t.Fatal("parseconfig get cookiename error") + } + if cf.Gclifetime != 3600 { + t.Fatal("parseconfig get gclifetime error") + } + + cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + cf2 := new(ManagerConfig) + cf2.EnableSetCookie = true + err = json.Unmarshal([]byte(cc), cf2) + if err != nil { + t.Fatal("parse json error,", err) + } + if cf2.CookieName != "gosessionid" { + t.Fatal("parseconfig get cookiename error") + } + if cf2.Gclifetime != 3600 { + t.Fatal("parseconfig get gclifetime error") + } + if cf2.EnableSetCookie { + t.Fatal("parseconfig get enableSetCookie error") + } + cconfig := new(cookieConfig) + err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig) + if err != nil { + t.Fatal("parse ProviderConfig err,", err) + } + if cconfig.CookieName != "gosessionid" { + t.Fatal("ProviderConfig get cookieName error") + } + if cconfig.SecurityKey != "beegocookiehashkey" { + t.Fatal("ProviderConfig get securityKey error") + } +} diff --git a/pkg/session/sess_utils.go b/pkg/session/sess_utils.go new file mode 100644 index 00000000..20915bb6 --- /dev/null +++ b/pkg/session/sess_utils.go @@ -0,0 +1,207 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "bytes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "encoding/gob" + "errors" + "fmt" + "io" + "strconv" + "time" + + "github.com/astaxie/beego/utils" +) + +func init() { + gob.Register([]interface{}{}) + gob.Register(map[int]interface{}{}) + gob.Register(map[string]interface{}{}) + gob.Register(map[interface{}]interface{}{}) + gob.Register(map[string]string{}) + gob.Register(map[int]string{}) + gob.Register(map[int]int{}) + gob.Register(map[int]int64{}) +} + +// EncodeGob encode the obj to gob +func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) { + for _, v := range obj { + gob.Register(v) + } + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + err := enc.Encode(obj) + if err != nil { + return []byte(""), err + } + return buf.Bytes(), nil +} + +// DecodeGob decode data to map +func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) { + buf := bytes.NewBuffer(encoded) + dec := gob.NewDecoder(buf) + var out map[interface{}]interface{} + err := dec.Decode(&out) + if err != nil { + return nil, err + } + return out, nil +} + +// generateRandomKey creates a random key with the given strength. +func generateRandomKey(strength int) []byte { + k := make([]byte, strength) + if n, err := io.ReadFull(rand.Reader, k); n != strength || err != nil { + return utils.RandomCreateBytes(strength) + } + return k +} + +// Encryption ----------------------------------------------------------------- + +// encrypt encrypts a value using the given block in counter mode. +// +// A random initialization vector (http://goo.gl/zF67k) with the length of the +// block size is prepended to the resulting ciphertext. +func encrypt(block cipher.Block, value []byte) ([]byte, error) { + iv := generateRandomKey(block.BlockSize()) + if iv == nil { + return nil, errors.New("encrypt: failed to generate random iv") + } + // Encrypt it. + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(value, value) + // Return iv + ciphertext. + return append(iv, value...), nil +} + +// decrypt decrypts a value using the given block in counter mode. +// +// The value to be decrypted must be prepended by a initialization vector +// (http://goo.gl/zF67k) with the length of the block size. +func decrypt(block cipher.Block, value []byte) ([]byte, error) { + size := block.BlockSize() + if len(value) > size { + // Extract iv. + iv := value[:size] + // Extract ciphertext. + value = value[size:] + // Decrypt it. + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(value, value) + return value, nil + } + return nil, errors.New("decrypt: the value could not be decrypted") +} + +func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) { + var err error + var b []byte + // 1. EncodeGob. + if b, err = EncodeGob(value); err != nil { + return "", err + } + // 2. Encrypt (optional). + if b, err = encrypt(block, b); err != nil { + return "", err + } + b = encode(b) + // 3. Create MAC for "name|date|value". Extra pipe to be used later. + b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b)) + h := hmac.New(sha256.New, []byte(hashKey)) + h.Write(b) + sig := h.Sum(nil) + // Append mac, remove name. + b = append(b, sig...)[len(name)+1:] + // 4. Encode to base64. + b = encode(b) + // Done. + return string(b), nil +} + +func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) { + // 1. Decode from base64. + b, err := decode([]byte(value)) + if err != nil { + return nil, err + } + // 2. Verify MAC. Value is "date|value|mac". + parts := bytes.SplitN(b, []byte("|"), 3) + if len(parts) != 3 { + return nil, errors.New("Decode: invalid value format") + } + + b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...) + h := hmac.New(sha256.New, []byte(hashKey)) + h.Write(b) + sig := h.Sum(nil) + if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 { + return nil, errors.New("Decode: the value is not valid") + } + // 3. Verify date ranges. + var t1 int64 + if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil { + return nil, errors.New("Decode: invalid timestamp") + } + t2 := time.Now().UTC().Unix() + if t1 > t2 { + return nil, errors.New("Decode: timestamp is too new") + } + if t1 < t2-gcmaxlifetime { + return nil, errors.New("Decode: expired timestamp") + } + // 4. Decrypt (optional). + b, err = decode(parts[1]) + if err != nil { + return nil, err + } + if b, err = decrypt(block, b); err != nil { + return nil, err + } + // 5. DecodeGob. + dst, err := DecodeGob(b) + if err != nil { + return nil, err + } + return dst, nil +} + +// Encoding ------------------------------------------------------------------- + +// encode encodes a value using base64. +func encode(value []byte) []byte { + encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value))) + base64.URLEncoding.Encode(encoded, value) + return encoded +} + +// decode decodes a cookie using base64. +func decode(value []byte) ([]byte, error) { + decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value))) + b, err := base64.URLEncoding.Decode(decoded, value) + if err != nil { + return nil, err + } + return decoded[:b], nil +} diff --git a/pkg/session/session.go b/pkg/session/session.go new file mode 100644 index 00000000..eb85360a --- /dev/null +++ b/pkg/session/session.go @@ -0,0 +1,377 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package session provider +// +// Usage: +// import( +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package session + +import ( + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/textproto" + "net/url" + "os" + "time" +) + +// Store contains all data for one session process with specific id. +type Store interface { + Set(key, value interface{}) error //set session value + Get(key interface{}) interface{} //get session value + Delete(key interface{}) error //delete session value + SessionID() string //back current sessionID + SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush() error //delete all data +} + +// Provider contains global session methods and saved SessionStores. +// it can operate a SessionStore by its id. +type Provider interface { + SessionInit(gclifetime int64, config string) error + SessionRead(sid string) (Store, error) + SessionExist(sid string) bool + SessionRegenerate(oldsid, sid string) (Store, error) + SessionDestroy(sid string) error + SessionAll() int //get all active session + SessionGC() +} + +var provides = make(map[string]Provider) + +// SLogger a helpful variable to log information about session +var SLogger = NewSessionLog(os.Stderr) + +// Register makes a session provide available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, provide Provider) { + if provide == nil { + panic("session: Register provide is nil") + } + if _, dup := provides[name]; dup { + panic("session: Register called twice for provider " + name) + } + provides[name] = provide +} + +//GetProvider +func GetProvider(name string) (Provider, error) { + provider, ok := provides[name] + if !ok { + return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", name) + } + return provider, nil +} + +// ManagerConfig define the session config +type ManagerConfig struct { + CookieName string `json:"cookieName"` + EnableSetCookie bool `json:"enableSetCookie,omitempty"` + Gclifetime int64 `json:"gclifetime"` + Maxlifetime int64 `json:"maxLifetime"` + DisableHTTPOnly bool `json:"disableHTTPOnly"` + Secure bool `json:"secure"` + CookieLifeTime int `json:"cookieLifeTime"` + ProviderConfig string `json:"providerConfig"` + Domain string `json:"domain"` + SessionIDLength int64 `json:"sessionIDLength"` + EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` + SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` + EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` + SessionIDPrefix string `json:"sessionIDPrefix"` +} + +// Manager contains Provider and its configuration. +type Manager struct { + provider Provider + config *ManagerConfig +} + +// NewManager Create new Manager with provider name and json config string. +// provider name: +// 1. cookie +// 2. file +// 3. memory +// 4. redis +// 5. mysql +// json config: +// 1. is https default false +// 2. hashfunc default sha1 +// 3. hashkey default beegosessionkey +// 4. maxage default is none +func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { + provider, ok := provides[provideName] + if !ok { + return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) + } + + if cf.Maxlifetime == 0 { + cf.Maxlifetime = cf.Gclifetime + } + + if cf.EnableSidInHTTPHeader { + if cf.SessionNameInHTTPHeader == "" { + panic(errors.New("SessionNameInHTTPHeader is empty")) + } + + strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHTTPHeader) + if cf.SessionNameInHTTPHeader != strMimeHeader { + strErrMsg := "SessionNameInHTTPHeader (" + cf.SessionNameInHTTPHeader + ") has the wrong format, it should be like this : " + strMimeHeader + panic(errors.New(strErrMsg)) + } + } + + err := provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) + if err != nil { + return nil, err + } + + if cf.SessionIDLength == 0 { + cf.SessionIDLength = 16 + } + + return &Manager{ + provider, + cf, + }, nil +} + +// GetProvider return current manager's provider +func (manager *Manager) GetProvider() Provider { + return manager.provider +} + +// getSid retrieves session identifier from HTTP Request. +// First try to retrieve id by reading from cookie, session cookie name is configurable, +// if not exist, then retrieve id from querying parameters. +// +// error is not nil when there is anything wrong. +// sid is empty when need to generate a new session id +// otherwise return an valid session id. +func (manager *Manager) getSid(r *http.Request) (string, error) { + cookie, errs := r.Cookie(manager.config.CookieName) + if errs != nil || cookie.Value == "" { + var sid string + if manager.config.EnableSidInURLQuery { + errs := r.ParseForm() + if errs != nil { + return "", errs + } + + sid = r.FormValue(manager.config.CookieName) + } + + // if not found in Cookie / param, then read it from request headers + if manager.config.EnableSidInHTTPHeader && sid == "" { + sids, isFound := r.Header[manager.config.SessionNameInHTTPHeader] + if isFound && len(sids) != 0 { + return sids[0], nil + } + } + + return sid, nil + } + + // HTTP Request contains cookie for sessionid info. + return url.QueryUnescape(cookie.Value) +} + +// SessionStart generate or read the session id from http request. +// if session id exists, return SessionStore with this id. +func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session Store, err error) { + sid, errs := manager.getSid(r) + if errs != nil { + return nil, errs + } + + if sid != "" && manager.provider.SessionExist(sid) { + return manager.provider.SessionRead(sid) + } + + // Generate a new session + sid, errs = manager.sessionID() + if errs != nil { + return nil, errs + } + + session, err = manager.provider.SessionRead(sid) + if err != nil { + return nil, err + } + cookie := &http.Cookie{ + Name: manager.config.CookieName, + Value: url.QueryEscape(sid), + Path: "/", + HttpOnly: !manager.config.DisableHTTPOnly, + Secure: manager.isSecure(r), + Domain: manager.config.Domain, + } + if manager.config.CookieLifeTime > 0 { + cookie.MaxAge = manager.config.CookieLifeTime + cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) + } + if manager.config.EnableSetCookie { + http.SetCookie(w, cookie) + } + r.AddCookie(cookie) + + if manager.config.EnableSidInHTTPHeader { + r.Header.Set(manager.config.SessionNameInHTTPHeader, sid) + w.Header().Set(manager.config.SessionNameInHTTPHeader, sid) + } + + return +} + +// SessionDestroy Destroy session by its id in http request cookie. +func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { + if manager.config.EnableSidInHTTPHeader { + r.Header.Del(manager.config.SessionNameInHTTPHeader) + w.Header().Del(manager.config.SessionNameInHTTPHeader) + } + + cookie, err := r.Cookie(manager.config.CookieName) + if err != nil || cookie.Value == "" { + return + } + + sid, _ := url.QueryUnescape(cookie.Value) + manager.provider.SessionDestroy(sid) + if manager.config.EnableSetCookie { + expiration := time.Now() + cookie = &http.Cookie{Name: manager.config.CookieName, + Path: "/", + HttpOnly: !manager.config.DisableHTTPOnly, + Expires: expiration, + MaxAge: -1, + Domain: manager.config.Domain} + + http.SetCookie(w, cookie) + } +} + +// GetSessionStore Get SessionStore by its id. +func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) { + sessions, err = manager.provider.SessionRead(sid) + return +} + +// GC Start session gc process. +// it can do gc in times after gc lifetime. +func (manager *Manager) GC() { + manager.provider.SessionGC() + time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) +} + +// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. +func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (session Store) { + sid, err := manager.sessionID() + if err != nil { + return + } + cookie, err := r.Cookie(manager.config.CookieName) + if err != nil || cookie.Value == "" { + //delete old cookie + session, _ = manager.provider.SessionRead(sid) + cookie = &http.Cookie{Name: manager.config.CookieName, + Value: url.QueryEscape(sid), + Path: "/", + HttpOnly: !manager.config.DisableHTTPOnly, + Secure: manager.isSecure(r), + Domain: manager.config.Domain, + } + } else { + oldsid, _ := url.QueryUnescape(cookie.Value) + session, _ = manager.provider.SessionRegenerate(oldsid, sid) + cookie.Value = url.QueryEscape(sid) + cookie.HttpOnly = true + cookie.Path = "/" + } + if manager.config.CookieLifeTime > 0 { + cookie.MaxAge = manager.config.CookieLifeTime + cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) + } + if manager.config.EnableSetCookie { + http.SetCookie(w, cookie) + } + r.AddCookie(cookie) + + if manager.config.EnableSidInHTTPHeader { + r.Header.Set(manager.config.SessionNameInHTTPHeader, sid) + w.Header().Set(manager.config.SessionNameInHTTPHeader, sid) + } + + return +} + +// GetActiveSession Get all active sessions count number. +func (manager *Manager) GetActiveSession() int { + return manager.provider.SessionAll() +} + +// SetSecure Set cookie with https. +func (manager *Manager) SetSecure(secure bool) { + manager.config.Secure = secure +} + +func (manager *Manager) sessionID() (string, error) { + b := make([]byte, manager.config.SessionIDLength) + n, err := rand.Read(b) + if n != len(b) || err != nil { + return "", fmt.Errorf("Could not successfully read from the system CSPRNG") + } + return manager.config.SessionIDPrefix + hex.EncodeToString(b), nil +} + +// Set cookie with https. +func (manager *Manager) isSecure(req *http.Request) bool { + if !manager.config.Secure { + return false + } + if req.URL.Scheme != "" { + return req.URL.Scheme == "https" + } + if req.TLS == nil { + return false + } + return true +} + +// Log implement the log.Logger +type Log struct { + *log.Logger +} + +// NewSessionLog set io.Writer to create a Logger for session. +func NewSessionLog(out io.Writer) *Log { + sl := new(Log) + sl.Logger = log.New(out, "[SESSION]", 1e9) + return sl +} diff --git a/pkg/session/ssdb/sess_ssdb.go b/pkg/session/ssdb/sess_ssdb.go new file mode 100644 index 00000000..de0c6360 --- /dev/null +++ b/pkg/session/ssdb/sess_ssdb.go @@ -0,0 +1,199 @@ +package ssdb + +import ( + "errors" + "net/http" + "strconv" + "strings" + "sync" + + "github.com/astaxie/beego/session" + "github.com/ssdb/gossdb/ssdb" +) + +var ssdbProvider = &Provider{} + +// Provider holds ssdb client and configs +type Provider struct { + client *ssdb.Client + host string + port int + maxLifetime int64 +} + +func (p *Provider) connectInit() error { + var err error + if p.host == "" || p.port == 0 { + return errors.New("SessionInit First") + } + p.client, err = ssdb.Connect(p.host, p.port) + return err +} + +// SessionInit init the ssdb with the config +func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { + p.maxLifetime = maxLifetime + address := strings.Split(savePath, ":") + p.host = address[0] + + var err error + if p.port, err = strconv.Atoi(address[1]); err != nil { + return err + } + return p.connectInit() +} + +// SessionRead return a ssdb client session Store +func (p *Provider) SessionRead(sid string) (session.Store, error) { + if p.client == nil { + if err := p.connectInit(); err != nil { + return nil, err + } + } + var kv map[interface{}]interface{} + value, err := p.client.Get(sid) + if err != nil { + return nil, err + } + if value == nil || len(value.(string)) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob([]byte(value.(string))) + if err != nil { + return nil, err + } + } + rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client} + return rs, nil +} + +// SessionExist judged whether sid is exist in session +func (p *Provider) SessionExist(sid string) bool { + if p.client == nil { + if err := p.connectInit(); err != nil { + panic(err) + } + } + value, err := p.client.Get(sid) + if err != nil { + panic(err) + } + if value == nil || len(value.(string)) == 0 { + return false + } + return true +} + +// SessionRegenerate regenerate session with new sid and delete oldsid +func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + //conn.Do("setx", key, v, ttl) + if p.client == nil { + if err := p.connectInit(); err != nil { + return nil, err + } + } + value, err := p.client.Get(oldsid) + if err != nil { + return nil, err + } + var kv map[interface{}]interface{} + if value == nil || len(value.(string)) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob([]byte(value.(string))) + if err != nil { + return nil, err + } + _, err = p.client.Del(oldsid) + if err != nil { + return nil, err + } + } + _, e := p.client.Do("setx", sid, value, p.maxLifetime) + if e != nil { + return nil, e + } + rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client} + return rs, nil +} + +// SessionDestroy destroy the sid +func (p *Provider) SessionDestroy(sid string) error { + if p.client == nil { + if err := p.connectInit(); err != nil { + return err + } + } + _, err := p.client.Del(sid) + return err +} + +// SessionGC not implemented +func (p *Provider) SessionGC() { +} + +// SessionAll not implemented +func (p *Provider) SessionAll() int { + return 0 +} + +// SessionStore holds the session information which stored in ssdb +type SessionStore struct { + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxLifetime int64 + client *ssdb.Client +} + +// Set the key and value +func (s *SessionStore) Set(key, value interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + s.values[key] = value + return nil +} + +// Get return the value by the key +func (s *SessionStore) Get(key interface{}) interface{} { + s.lock.Lock() + defer s.lock.Unlock() + if value, ok := s.values[key]; ok { + return value + } + return nil +} + +// Delete the key in session store +func (s *SessionStore) Delete(key interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + delete(s.values, key) + return nil +} + +// Flush delete all keys and values +func (s *SessionStore) Flush() error { + s.lock.Lock() + defer s.lock.Unlock() + s.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID return the sessionID +func (s *SessionStore) SessionID() string { + return s.sid +} + +// SessionRelease Store the keyvalues into ssdb +func (s *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(s.values) + if err != nil { + return + } + s.client.Do("setx", s.sid, string(b), s.maxLifetime) +} + +func init() { + session.Register("ssdb", ssdbProvider) +} diff --git a/pkg/staticfile.go b/pkg/staticfile.go new file mode 100644 index 00000000..84e9aa7b --- /dev/null +++ b/pkg/staticfile.go @@ -0,0 +1,234 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "errors" + "net/http" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "sync" + "time" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" + "github.com/hashicorp/golang-lru" +) + +var errNotStaticRequest = errors.New("request not a static file request") + +func serverStaticRouter(ctx *context.Context) { + if ctx.Input.Method() != "GET" && ctx.Input.Method() != "HEAD" { + return + } + + forbidden, filePath, fileInfo, err := lookupFile(ctx) + if err == errNotStaticRequest { + return + } + + if forbidden { + exception("403", ctx) + return + } + + if filePath == "" || fileInfo == nil { + if BConfig.RunMode == DEV { + logs.Warn("Can't find/open the file:", filePath, err) + } + http.NotFound(ctx.ResponseWriter, ctx.Request) + return + } + if fileInfo.IsDir() { + requestURL := ctx.Input.URL() + if requestURL[len(requestURL)-1] != '/' { + redirectURL := requestURL + "/" + if ctx.Request.URL.RawQuery != "" { + redirectURL = redirectURL + "?" + ctx.Request.URL.RawQuery + } + ctx.Redirect(302, redirectURL) + } else { + //serveFile will list dir + http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) + } + return + } else if fileInfo.Size() > int64(BConfig.WebConfig.StaticCacheFileSize) { + //over size file serve with http module + http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) + return + } + + var enableCompress = BConfig.EnableGzip && isStaticCompress(filePath) + var acceptEncoding string + if enableCompress { + acceptEncoding = context.ParseEncoding(ctx.Request) + } + b, n, sch, reader, err := openFile(filePath, fileInfo, acceptEncoding) + if err != nil { + if BConfig.RunMode == DEV { + logs.Warn("Can't compress the file:", filePath, err) + } + http.NotFound(ctx.ResponseWriter, ctx.Request) + return + } + + if b { + ctx.Output.Header("Content-Encoding", n) + } else { + ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10)) + } + + http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, reader) +} + +type serveContentHolder struct { + data []byte + modTime time.Time + size int64 + originSize int64 //original file size:to judge file changed + encoding string +} + +type serveContentReader struct { + *bytes.Reader +} + +var ( + staticFileLruCache *lru.Cache + lruLock sync.RWMutex +) + +func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, *serveContentReader, error) { + if staticFileLruCache == nil { + //avoid lru cache error + if BConfig.WebConfig.StaticCacheFileNum >= 1 { + staticFileLruCache, _ = lru.New(BConfig.WebConfig.StaticCacheFileNum) + } else { + staticFileLruCache, _ = lru.New(1) + } + } + mapKey := acceptEncoding + ":" + filePath + lruLock.RLock() + var mapFile *serveContentHolder + if cacheItem, ok := staticFileLruCache.Get(mapKey); ok { + mapFile = cacheItem.(*serveContentHolder) + } + lruLock.RUnlock() + if isOk(mapFile, fi) { + reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} + return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil + } + lruLock.Lock() + defer lruLock.Unlock() + if cacheItem, ok := staticFileLruCache.Get(mapKey); ok { + mapFile = cacheItem.(*serveContentHolder) + } + if !isOk(mapFile, fi) { + file, err := os.Open(filePath) + if err != nil { + return false, "", nil, nil, err + } + defer file.Close() + var bufferWriter bytes.Buffer + _, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file) + if err != nil { + return false, "", nil, nil, err + } + mapFile = &serveContentHolder{data: bufferWriter.Bytes(), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), originSize: fi.Size(), encoding: n} + if isOk(mapFile, fi) { + staticFileLruCache.Add(mapKey, mapFile) + } + } + + reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} + return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil +} + +func isOk(s *serveContentHolder, fi os.FileInfo) bool { + if s == nil { + return false + } else if s.size > int64(BConfig.WebConfig.StaticCacheFileSize) { + return false + } + return s.modTime == fi.ModTime() && s.originSize == fi.Size() +} + +// isStaticCompress detect static files +func isStaticCompress(filePath string) bool { + for _, statExtension := range BConfig.WebConfig.StaticExtensionsToGzip { + if strings.HasSuffix(strings.ToLower(filePath), strings.ToLower(statExtension)) { + return true + } + } + return false +} + +// searchFile search the file by url path +// if none the static file prefix matches ,return notStaticRequestErr +func searchFile(ctx *context.Context) (string, os.FileInfo, error) { + requestPath := filepath.ToSlash(filepath.Clean(ctx.Request.URL.Path)) + // special processing : favicon.ico/robots.txt can be in any static dir + if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { + file := path.Join(".", requestPath) + if fi, _ := os.Stat(file); fi != nil { + return file, fi, nil + } + for _, staticDir := range BConfig.WebConfig.StaticDir { + filePath := path.Join(staticDir, requestPath) + if fi, _ := os.Stat(filePath); fi != nil { + return filePath, fi, nil + } + } + return "", nil, errNotStaticRequest + } + + for prefix, staticDir := range BConfig.WebConfig.StaticDir { + if !strings.Contains(requestPath, prefix) { + continue + } + if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { + continue + } + filePath := path.Join(staticDir, requestPath[len(prefix):]) + if fi, err := os.Stat(filePath); fi != nil { + return filePath, fi, err + } + } + return "", nil, errNotStaticRequest +} + +// lookupFile find the file to serve +// if the file is dir ,search the index.html as default file( MUST NOT A DIR also) +// if the index.html not exist or is a dir, give a forbidden response depending on DirectoryIndex +func lookupFile(ctx *context.Context) (bool, string, os.FileInfo, error) { + fp, fi, err := searchFile(ctx) + if fp == "" || fi == nil { + return false, "", nil, err + } + if !fi.IsDir() { + return false, fp, fi, err + } + if requestURL := ctx.Input.URL(); requestURL[len(requestURL)-1] == '/' { + ifp := filepath.Join(fp, "index.html") + if ifi, _ := os.Stat(ifp); ifi != nil && ifi.Mode().IsRegular() { + return false, ifp, ifi, err + } + } + return !BConfig.WebConfig.DirectoryIndex, fp, fi, err +} diff --git a/pkg/staticfile_test.go b/pkg/staticfile_test.go new file mode 100644 index 00000000..e46c13ec --- /dev/null +++ b/pkg/staticfile_test.go @@ -0,0 +1,99 @@ +package beego + +import ( + "bytes" + "compress/gzip" + "compress/zlib" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "testing" +) + +var currentWorkDir, _ = os.Getwd() +var licenseFile = filepath.Join(currentWorkDir, "LICENSE") + +func testOpenFile(encoding string, content []byte, t *testing.T) { + fi, _ := os.Stat(licenseFile) + b, n, sch, reader, err := openFile(licenseFile, fi, encoding) + if err != nil { + t.Log(err) + t.Fail() + } + + t.Log("open static file encoding "+n, b) + + assetOpenFileAndContent(sch, reader, content, t) +} +func TestOpenStaticFile_1(t *testing.T) { + file, _ := os.Open(licenseFile) + content, _ := ioutil.ReadAll(file) + testOpenFile("", content, t) +} + +func TestOpenStaticFileGzip_1(t *testing.T) { + file, _ := os.Open(licenseFile) + var zipBuf bytes.Buffer + fileWriter, _ := gzip.NewWriterLevel(&zipBuf, gzip.BestCompression) + io.Copy(fileWriter, file) + fileWriter.Close() + content, _ := ioutil.ReadAll(&zipBuf) + + testOpenFile("gzip", content, t) +} +func TestOpenStaticFileDeflate_1(t *testing.T) { + file, _ := os.Open(licenseFile) + var zipBuf bytes.Buffer + fileWriter, _ := zlib.NewWriterLevel(&zipBuf, zlib.BestCompression) + io.Copy(fileWriter, file) + fileWriter.Close() + content, _ := ioutil.ReadAll(&zipBuf) + + testOpenFile("deflate", content, t) +} + +func TestStaticCacheWork(t *testing.T) { + encodings := []string{"", "gzip", "deflate"} + + fi, _ := os.Stat(licenseFile) + for _, encoding := range encodings { + _, _, first, _, err := openFile(licenseFile, fi, encoding) + if err != nil { + t.Error(err) + continue + } + + _, _, second, _, err := openFile(licenseFile, fi, encoding) + if err != nil { + t.Error(err) + continue + } + + address1 := fmt.Sprintf("%p", first) + address2 := fmt.Sprintf("%p", second) + if address1 != address2 { + t.Errorf("encoding '%v' can not hit cache", encoding) + } + } +} + +func assetOpenFileAndContent(sch *serveContentHolder, reader *serveContentReader, content []byte, t *testing.T) { + t.Log(sch.size, len(content)) + if sch.size != int64(len(content)) { + t.Log("static content file size not same") + t.Fail() + } + bs, _ := ioutil.ReadAll(reader) + for i, v := range content { + if v != bs[i] { + t.Log("content not same") + t.Fail() + } + } + if staticFileLruCache.Len() == 0 { + t.Log("men map is empty") + t.Fail() + } +} diff --git a/pkg/swagger/swagger.go b/pkg/swagger/swagger.go new file mode 100644 index 00000000..a55676cd --- /dev/null +++ b/pkg/swagger/swagger.go @@ -0,0 +1,174 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Swagger™ is a project used to describe and document RESTful APIs. +// +// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools. +// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software. + +// Package swagger struct definition +package swagger + +// Swagger list the resource +type Swagger struct { + SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"` + Infos Information `json:"info" yaml:"info"` + Host string `json:"host,omitempty" yaml:"host,omitempty"` + BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"` + Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` + Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` + Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` + Paths map[string]*Item `json:"paths" yaml:"paths"` + Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"` + SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"` + Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` + Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"` + ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` +} + +// Information Provides metadata about the API. The metadata can be used by the clients if needed. +type Information struct { + Title string `json:"title,omitempty" yaml:"title,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Version string `json:"version,omitempty" yaml:"version,omitempty"` + TermsOfService string `json:"termsOfService,omitempty" yaml:"termsOfService,omitempty"` + + Contact Contact `json:"contact,omitempty" yaml:"contact,omitempty"` + License *License `json:"license,omitempty" yaml:"license,omitempty"` +} + +// Contact information for the exposed API. +type Contact struct { + Name string `json:"name,omitempty" yaml:"name,omitempty"` + URL string `json:"url,omitempty" yaml:"url,omitempty"` + EMail string `json:"email,omitempty" yaml:"email,omitempty"` +} + +// License information for the exposed API. +type License struct { + Name string `json:"name,omitempty" yaml:"name,omitempty"` + URL string `json:"url,omitempty" yaml:"url,omitempty"` +} + +// Item Describes the operations available on a single path. +type Item struct { + Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` + Get *Operation `json:"get,omitempty" yaml:"get,omitempty"` + Put *Operation `json:"put,omitempty" yaml:"put,omitempty"` + Post *Operation `json:"post,omitempty" yaml:"post,omitempty"` + Delete *Operation `json:"delete,omitempty" yaml:"delete,omitempty"` + Options *Operation `json:"options,omitempty" yaml:"options,omitempty"` + Head *Operation `json:"head,omitempty" yaml:"head,omitempty"` + Patch *Operation `json:"patch,omitempty" yaml:"patch,omitempty"` +} + +// Operation Describes a single API operation on a path. +type Operation struct { + Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"` + Summary string `json:"summary,omitempty" yaml:"summary,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"` + Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` + Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` + Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` + Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"` + Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"` + Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` + Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"` +} + +// Parameter Describes a single operation parameter. +type Parameter struct { + In string `json:"in,omitempty" yaml:"in,omitempty"` + Name string `json:"name,omitempty" yaml:"name,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Required bool `json:"required,omitempty" yaml:"required,omitempty"` + Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"` + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Format string `json:"format,omitempty" yaml:"format,omitempty"` + Items *ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` + Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` +} + +// ParameterItems A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body". +// http://swagger.io/specification/#itemsObject +type ParameterItems struct { + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Format string `json:"format,omitempty" yaml:"format,omitempty"` + Items []*ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` //Required if type is "array". Describes the type of items in the array. + CollectionFormat string `json:"collectionFormat,omitempty" yaml:"collectionFormat,omitempty"` + Default string `json:"default,omitempty" yaml:"default,omitempty"` +} + +// Schema Object allows the definition of input and output data types. +type Schema struct { + Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` + Title string `json:"title,omitempty" yaml:"title,omitempty"` + Format string `json:"format,omitempty" yaml:"format,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Required []string `json:"required,omitempty" yaml:"required,omitempty"` + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Items *Schema `json:"items,omitempty" yaml:"items,omitempty"` + Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"` + Enum []interface{} `json:"enum,omitempty" yaml:"enum,omitempty"` + Example interface{} `json:"example,omitempty" yaml:"example,omitempty"` +} + +// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification +type Propertie struct { + Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` + Title string `json:"title,omitempty" yaml:"title,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Example interface{} `json:"example,omitempty" yaml:"example,omitempty"` + Required []string `json:"required,omitempty" yaml:"required,omitempty"` + Format string `json:"format,omitempty" yaml:"format,omitempty"` + ReadOnly bool `json:"readOnly,omitempty" yaml:"readOnly,omitempty"` + Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"` + Items *Propertie `json:"items,omitempty" yaml:"items,omitempty"` + AdditionalProperties *Propertie `json:"additionalProperties,omitempty" yaml:"additionalProperties,omitempty"` +} + +// Response as they are returned from executing this operation. +type Response struct { + Description string `json:"description" yaml:"description"` + Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"` + Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` +} + +// Security Allows the definition of a security scheme that can be used by the operations +type Security struct { + Type string `json:"type,omitempty" yaml:"type,omitempty"` // Valid values are "basic", "apiKey" or "oauth2". + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Name string `json:"name,omitempty" yaml:"name,omitempty"` + In string `json:"in,omitempty" yaml:"in,omitempty"` // Valid values are "query" or "header". + Flow string `json:"flow,omitempty" yaml:"flow,omitempty"` // Valid values are "implicit", "password", "application" or "accessCode". + AuthorizationURL string `json:"authorizationUrl,omitempty" yaml:"authorizationUrl,omitempty"` + TokenURL string `json:"tokenUrl,omitempty" yaml:"tokenUrl,omitempty"` + Scopes map[string]string `json:"scopes,omitempty" yaml:"scopes,omitempty"` // The available scopes for the OAuth2 security scheme. +} + +// Tag Allows adding meta data to a single tag that is used by the Operation Object +type Tag struct { + Name string `json:"name,omitempty" yaml:"name,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` +} + +// ExternalDocs include Additional external documentation +type ExternalDocs struct { + Description string `json:"description,omitempty" yaml:"description,omitempty"` + URL string `json:"url,omitempty" yaml:"url,omitempty"` +} diff --git a/pkg/template.go b/pkg/template.go new file mode 100644 index 00000000..59875be7 --- /dev/null +++ b/pkg/template.go @@ -0,0 +1,406 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "errors" + "fmt" + "html/template" + "io" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/utils" +) + +var ( + beegoTplFuncMap = make(template.FuncMap) + beeViewPathTemplateLocked = false + // beeViewPathTemplates caching map and supported template file extensions per view + beeViewPathTemplates = make(map[string]map[string]*template.Template) + templatesLock sync.RWMutex + // beeTemplateExt stores the template extension which will build + beeTemplateExt = []string{"tpl", "html", "gohtml"} + // beeTemplatePreprocessors stores associations of extension -> preprocessor handler + beeTemplateEngines = map[string]templatePreProcessor{} + beeTemplateFS = defaultFSFunc +) + +// ExecuteTemplate applies the template with name to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { + return ExecuteViewPathTemplate(wr, name, BConfig.WebConfig.ViewsPath, data) +} + +// ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error { + if BConfig.RunMode == DEV { + templatesLock.RLock() + defer templatesLock.RUnlock() + } + if beeTemplates, ok := beeViewPathTemplates[viewPath]; ok { + if t, ok := beeTemplates[name]; ok { + var err error + if t.Lookup(name) != nil { + err = t.ExecuteTemplate(wr, name, data) + } else { + err = t.Execute(wr, data) + } + if err != nil { + logs.Trace("template Execute err:", err) + } + return err + } + panic("can't find templatefile in the path:" + viewPath + "/" + name) + } + panic("Unknown view path:" + viewPath) +} + +func init() { + beegoTplFuncMap["dateformat"] = DateFormat + beegoTplFuncMap["date"] = Date + beegoTplFuncMap["compare"] = Compare + beegoTplFuncMap["compare_not"] = CompareNot + beegoTplFuncMap["not_nil"] = NotNil + beegoTplFuncMap["not_null"] = NotNil + beegoTplFuncMap["substr"] = Substr + beegoTplFuncMap["html2str"] = HTML2str + beegoTplFuncMap["str2html"] = Str2html + beegoTplFuncMap["htmlquote"] = Htmlquote + beegoTplFuncMap["htmlunquote"] = Htmlunquote + beegoTplFuncMap["renderform"] = RenderForm + beegoTplFuncMap["assets_js"] = AssetsJs + beegoTplFuncMap["assets_css"] = AssetsCSS + beegoTplFuncMap["config"] = GetConfig + beegoTplFuncMap["map_get"] = MapGet + + // Comparisons + beegoTplFuncMap["eq"] = eq // == + beegoTplFuncMap["ge"] = ge // >= + beegoTplFuncMap["gt"] = gt // > + beegoTplFuncMap["le"] = le // <= + beegoTplFuncMap["lt"] = lt // < + beegoTplFuncMap["ne"] = ne // != + + beegoTplFuncMap["urlfor"] = URLFor // build a URL to match a Controller and it's method +} + +// AddFuncMap let user to register a func in the template. +func AddFuncMap(key string, fn interface{}) error { + beegoTplFuncMap[key] = fn + return nil +} + +type templatePreProcessor func(root, path string, funcs template.FuncMap) (*template.Template, error) + +type templateFile struct { + root string + files map[string][]string +} + +// visit will make the paths into two part,the first is subDir (without tf.root),the second is full path(without tf.root). +// if tf.root="views" and +// paths is "views/errors/404.html",the subDir will be "errors",the file will be "errors/404.html" +// paths is "views/admin/errors/404.html",the subDir will be "admin/errors",the file will be "admin/errors/404.html" +func (tf *templateFile) visit(paths string, f os.FileInfo, err error) error { + if f == nil { + return err + } + if f.IsDir() || (f.Mode()&os.ModeSymlink) > 0 { + return nil + } + if !HasTemplateExt(paths) { + return nil + } + + replace := strings.NewReplacer("\\", "/") + file := strings.TrimLeft(replace.Replace(paths[len(tf.root):]), "/") + subDir := filepath.Dir(file) + + tf.files[subDir] = append(tf.files[subDir], file) + return nil +} + +// HasTemplateExt return this path contains supported template extension of beego or not. +func HasTemplateExt(paths string) bool { + for _, v := range beeTemplateExt { + if strings.HasSuffix(paths, "."+v) { + return true + } + } + return false +} + +// AddTemplateExt add new extension for template. +func AddTemplateExt(ext string) { + for _, v := range beeTemplateExt { + if v == ext { + return + } + } + beeTemplateExt = append(beeTemplateExt, ext) +} + +// AddViewPath adds a new path to the supported view paths. +//Can later be used by setting a controller ViewPath to this folder +//will panic if called after beego.Run() +func AddViewPath(viewPath string) error { + if beeViewPathTemplateLocked { + if _, exist := beeViewPathTemplates[viewPath]; exist { + return nil //Ignore if viewpath already exists + } + panic("Can not add new view paths after beego.Run()") + } + beeViewPathTemplates[viewPath] = make(map[string]*template.Template) + return BuildTemplate(viewPath) +} + +func lockViewPaths() { + beeViewPathTemplateLocked = true +} + +// BuildTemplate will build all template files in a directory. +// it makes beego can render any template file in view directory. +func BuildTemplate(dir string, files ...string) error { + var err error + fs := beeTemplateFS() + f, err := fs.Open(dir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return errors.New("dir open err") + } + defer f.Close() + + beeTemplates, ok := beeViewPathTemplates[dir] + if !ok { + panic("Unknown view path: " + dir) + } + self := &templateFile{ + root: dir, + files: make(map[string][]string), + } + err = Walk(fs, dir, func(path string, f os.FileInfo, err error) error { + return self.visit(path, f, err) + }) + if err != nil { + fmt.Printf("Walk() returned %v\n", err) + return err + } + buildAllFiles := len(files) == 0 + for _, v := range self.files { + for _, file := range v { + if buildAllFiles || utils.InSlice(file, files) { + templatesLock.Lock() + ext := filepath.Ext(file) + var t *template.Template + if len(ext) == 0 { + t, err = getTemplate(self.root, fs, file, v...) + } else if fn, ok := beeTemplateEngines[ext[1:]]; ok { + t, err = fn(self.root, file, beegoTplFuncMap) + } else { + t, err = getTemplate(self.root, fs, file, v...) + } + if err != nil { + logs.Error("parse template err:", file, err) + templatesLock.Unlock() + return err + } + beeTemplates[file] = t + templatesLock.Unlock() + } + } + } + return nil +} + +func getTplDeep(root string, fs http.FileSystem, file string, parent string, t *template.Template) (*template.Template, [][]string, error) { + var fileAbsPath string + var rParent string + var err error + if strings.HasPrefix(file, "../") { + rParent = filepath.Join(filepath.Dir(parent), file) + fileAbsPath = filepath.Join(root, filepath.Dir(parent), file) + } else { + rParent = file + fileAbsPath = filepath.Join(root, file) + } + f, err := fs.Open(fileAbsPath) + if err != nil { + panic("can't find template file:" + file) + } + defer f.Close() + data, err := ioutil.ReadAll(f) + if err != nil { + return nil, [][]string{}, err + } + t, err = t.New(file).Parse(string(data)) + if err != nil { + return nil, [][]string{}, err + } + reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*template[ ]+\"([^\"]+)\"") + allSub := reg.FindAllStringSubmatch(string(data), -1) + for _, m := range allSub { + if len(m) == 2 { + tl := t.Lookup(m[1]) + if tl != nil { + continue + } + if !HasTemplateExt(m[1]) { + continue + } + _, _, err = getTplDeep(root, fs, m[1], rParent, t) + if err != nil { + return nil, [][]string{}, err + } + } + } + return t, allSub, nil +} + +func getTemplate(root string, fs http.FileSystem, file string, others ...string) (t *template.Template, err error) { + t = template.New(file).Delims(BConfig.WebConfig.TemplateLeft, BConfig.WebConfig.TemplateRight).Funcs(beegoTplFuncMap) + var subMods [][]string + t, subMods, err = getTplDeep(root, fs, file, "", t) + if err != nil { + return nil, err + } + t, err = _getTemplate(t, root, fs, subMods, others...) + + if err != nil { + return nil, err + } + return +} + +func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMods [][]string, others ...string) (t *template.Template, err error) { + t = t0 + for _, m := range subMods { + if len(m) == 2 { + tpl := t.Lookup(m[1]) + if tpl != nil { + continue + } + //first check filename + for _, otherFile := range others { + if otherFile == m[1] { + var subMods1 [][]string + t, subMods1, err = getTplDeep(root, fs, otherFile, "", t) + if err != nil { + logs.Trace("template parse file err:", err) + } else if len(subMods1) > 0 { + t, err = _getTemplate(t, root, fs, subMods1, others...) + } + break + } + } + //second check define + for _, otherFile := range others { + var data []byte + fileAbsPath := filepath.Join(root, otherFile) + f, err := fs.Open(fileAbsPath) + if err != nil { + f.Close() + logs.Trace("template file parse error, not success open file:", err) + continue + } + data, err = ioutil.ReadAll(f) + f.Close() + if err != nil { + logs.Trace("template file parse error, not success read file:", err) + continue + } + reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"") + allSub := reg.FindAllStringSubmatch(string(data), -1) + for _, sub := range allSub { + if len(sub) == 2 && sub[1] == m[1] { + var subMods1 [][]string + t, subMods1, err = getTplDeep(root, fs, otherFile, "", t) + if err != nil { + logs.Trace("template parse file err:", err) + } else if len(subMods1) > 0 { + t, err = _getTemplate(t, root, fs, subMods1, others...) + if err != nil { + logs.Trace("template parse file err:", err) + } + } + break + } + } + } + } + + } + return +} + +type templateFSFunc func() http.FileSystem + +func defaultFSFunc() http.FileSystem { + return FileSystem{} +} + +// SetTemplateFSFunc set default filesystem function +func SetTemplateFSFunc(fnt templateFSFunc) { + beeTemplateFS = fnt +} + +// SetViewsPath sets view directory path in beego application. +func SetViewsPath(path string) *App { + BConfig.WebConfig.ViewsPath = path + return BeeApp +} + +// SetStaticPath sets static directory path and proper url pattern in beego application. +// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public". +func SetStaticPath(url string, path string) *App { + if !strings.HasPrefix(url, "/") { + url = "/" + url + } + if url != "/" { + url = strings.TrimRight(url, "/") + } + BConfig.WebConfig.StaticDir[url] = path + return BeeApp +} + +// DelStaticPath removes the static folder setting in this url pattern in beego application. +func DelStaticPath(url string) *App { + if !strings.HasPrefix(url, "/") { + url = "/" + url + } + if url != "/" { + url = strings.TrimRight(url, "/") + } + delete(BConfig.WebConfig.StaticDir, url) + return BeeApp +} + +// AddTemplateEngine add a new templatePreProcessor which support extension +func AddTemplateEngine(extension string, fn templatePreProcessor) *App { + AddTemplateExt(extension) + beeTemplateEngines[extension] = fn + return BeeApp +} diff --git a/pkg/template_test.go b/pkg/template_test.go new file mode 100644 index 00000000..287faadc --- /dev/null +++ b/pkg/template_test.go @@ -0,0 +1,316 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "github.com/astaxie/beego/testdata" + "github.com/elazarl/go-bindata-assetfs" + "net/http" + "os" + "path/filepath" + "testing" +) + +var header = `{{define "header"}} +

Hello, astaxie!

+{{end}}` + +var index = ` + + + beego welcome template + + +{{template "block"}} +{{template "header"}} +{{template "blocks/block.tpl"}} + + +` + +var block = `{{define "block"}} +

Hello, blocks!

+{{end}}` + +func TestTemplate(t *testing.T) { + dir := "_beeTmp" + files := []string{ + "header.tpl", + "index.tpl", + "blocks/block.tpl", + } + if err := os.MkdirAll(dir, 0777); err != nil { + t.Fatal(err) + } + for k, name := range files { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + if k == 0 { + f.WriteString(header) + } else if k == 1 { + f.WriteString(index) + } else if k == 2 { + f.WriteString(block) + } + + f.Close() + } + } + if err := AddViewPath(dir); err != nil { + t.Fatal(err) + } + beeTemplates := beeViewPathTemplates[dir] + if len(beeTemplates) != 3 { + t.Fatalf("should be 3 but got %v", len(beeTemplates)) + } + if err := beeTemplates["index.tpl"].ExecuteTemplate(os.Stdout, "index.tpl", nil); err != nil { + t.Fatal(err) + } + for _, name := range files { + os.RemoveAll(filepath.Join(dir, name)) + } + os.RemoveAll(dir) +} + +var menu = ` +` +var user = ` + + + beego welcome template + + +{{template "../public/menu.tpl"}} + + +` + +func TestRelativeTemplate(t *testing.T) { + dir := "_beeTmp" + + //Just add dir to known viewPaths + if err := AddViewPath(dir); err != nil { + t.Fatal(err) + } + + files := []string{ + "easyui/public/menu.tpl", + "easyui/rbac/user.tpl", + } + if err := os.MkdirAll(dir, 0777); err != nil { + t.Fatal(err) + } + for k, name := range files { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + if k == 0 { + f.WriteString(menu) + } else if k == 1 { + f.WriteString(user) + } + f.Close() + } + } + if err := BuildTemplate(dir, files[1]); err != nil { + t.Fatal(err) + } + beeTemplates := beeViewPathTemplates[dir] + if err := beeTemplates["easyui/rbac/user.tpl"].ExecuteTemplate(os.Stdout, "easyui/rbac/user.tpl", nil); err != nil { + t.Fatal(err) + } + for _, name := range files { + os.RemoveAll(filepath.Join(dir, name)) + } + os.RemoveAll(dir) +} + +var add = `{{ template "layout_blog.tpl" . }} +{{ define "css" }} + +{{ end}} + + +{{ define "content" }} +

{{ .Title }}

+

This is SomeVar: {{ .SomeVar }}

+{{ end }} + +{{ define "js" }} + +{{ end}}` + +var layoutBlog = ` + + + Lin Li + + + + + {{ block "css" . }}{{ end }} + + + +
+ {{ block "content" . }}{{ end }} +
+ + + {{ block "js" . }}{{ end }} + +` + +var output = ` + + + Lin Li + + + + + + + + + + +
+ +

Hello

+

This is SomeVar: val

+ +
+ + + + + + + + + + + + +` + +func TestTemplateLayout(t *testing.T) { + dir := "_beeTmp" + files := []string{ + "add.tpl", + "layout_blog.tpl", + } + if err := os.MkdirAll(dir, 0777); err != nil { + t.Fatal(err) + } + for k, name := range files { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + if k == 0 { + f.WriteString(add) + } else if k == 1 { + f.WriteString(layoutBlog) + } + f.Close() + } + } + if err := AddViewPath(dir); err != nil { + t.Fatal(err) + } + beeTemplates := beeViewPathTemplates[dir] + if len(beeTemplates) != 2 { + t.Fatalf("should be 2 but got %v", len(beeTemplates)) + } + out := bytes.NewBufferString("") + if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { + t.Fatal(err) + } + if out.String() != output { + t.Log(out.String()) + t.Fatal("Compare failed") + } + for _, name := range files { + os.RemoveAll(filepath.Join(dir, name)) + } + os.RemoveAll(dir) +} + +type TestingFileSystem struct { + assetfs *assetfs.AssetFS +} + +func (d TestingFileSystem) Open(name string) (http.File, error) { + return d.assetfs.Open(name) +} + +var outputBinData = ` + + + beego welcome template + + + + +

Hello, blocks!

+ + +

Hello, astaxie!

+ + + +

Hello

+

This is SomeVar: val

+ + +` + +func TestFsBinData(t *testing.T) { + SetTemplateFSFunc(func() http.FileSystem { + return TestingFileSystem{&assetfs.AssetFS{Asset: testdata.Asset, AssetDir: testdata.AssetDir, AssetInfo: testdata.AssetInfo}} + }) + dir := "views" + if err := AddViewPath("views"); err != nil { + t.Fatal(err) + } + beeTemplates := beeViewPathTemplates[dir] + if len(beeTemplates) != 3 { + t.Fatalf("should be 3 but got %v", len(beeTemplates)) + } + if err := beeTemplates["index.tpl"].ExecuteTemplate(os.Stdout, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { + t.Fatal(err) + } + out := bytes.NewBufferString("") + if err := beeTemplates["index.tpl"].ExecuteTemplate(out, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { + t.Fatal(err) + } + + if out.String() != outputBinData { + t.Log(out.String()) + t.Fatal("Compare failed") + } +} diff --git a/pkg/templatefunc.go b/pkg/templatefunc.go new file mode 100644 index 00000000..ba1ec5eb --- /dev/null +++ b/pkg/templatefunc.go @@ -0,0 +1,780 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "errors" + "fmt" + "html" + "html/template" + "net/url" + "reflect" + "regexp" + "strconv" + "strings" + "time" +) + +const ( + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" + formatDateTimeT = "2006-01-02T15:04:05" +) + +// Substr returns the substr from start to length. +func Substr(s string, start, length int) string { + bt := []rune(s) + if start < 0 { + start = 0 + } + if start > len(bt) { + start = start % len(bt) + } + var end int + if (start + length) > (len(bt) - 1) { + end = len(bt) + } else { + end = start + length + } + return string(bt[start:end]) +} + +// HTML2str returns escaping text convert from html. +func HTML2str(html string) string { + + re := regexp.MustCompile(`\<[\S\s]+?\>`) + html = re.ReplaceAllStringFunc(html, strings.ToLower) + + //remove STYLE + re = regexp.MustCompile(`\`) + html = re.ReplaceAllString(html, "") + + //remove SCRIPT + re = regexp.MustCompile(`\`) + html = re.ReplaceAllString(html, "") + + re = regexp.MustCompile(`\<[\S\s]+?\>`) + html = re.ReplaceAllString(html, "\n") + + re = regexp.MustCompile(`\s{2,}`) + html = re.ReplaceAllString(html, "\n") + + return strings.TrimSpace(html) +} + +// DateFormat takes a time and a layout string and returns a string with the formatted date. Used by the template parser as "dateformat" +func DateFormat(t time.Time, layout string) (datestring string) { + datestring = t.Format(layout) + return +} + +// DateFormat pattern rules. +var datePatterns = []string{ + // year + "Y", "2006", // A full numeric representation of a year, 4 digits Examples: 1999 or 2003 + "y", "06", //A two digit representation of a year Examples: 99 or 03 + + // month + "m", "01", // Numeric representation of a month, with leading zeros 01 through 12 + "n", "1", // Numeric representation of a month, without leading zeros 1 through 12 + "M", "Jan", // A short textual representation of a month, three letters Jan through Dec + "F", "January", // A full textual representation of a month, such as January or March January through December + + // day + "d", "02", // Day of the month, 2 digits with leading zeros 01 to 31 + "j", "2", // Day of the month without leading zeros 1 to 31 + + // week + "D", "Mon", // A textual representation of a day, three letters Mon through Sun + "l", "Monday", // A full textual representation of the day of the week Sunday through Saturday + + // time + "g", "3", // 12-hour format of an hour without leading zeros 1 through 12 + "G", "15", // 24-hour format of an hour without leading zeros 0 through 23 + "h", "03", // 12-hour format of an hour with leading zeros 01 through 12 + "H", "15", // 24-hour format of an hour with leading zeros 00 through 23 + + "a", "pm", // Lowercase Ante meridiem and Post meridiem am or pm + "A", "PM", // Uppercase Ante meridiem and Post meridiem AM or PM + + "i", "04", // Minutes with leading zeros 00 to 59 + "s", "05", // Seconds, with leading zeros 00 through 59 + + // time zone + "T", "MST", + "P", "-07:00", + "O", "-0700", + + // RFC 2822 + "r", time.RFC1123Z, +} + +// DateParse Parse Date use PHP time format. +func DateParse(dateString, format string) (time.Time, error) { + replacer := strings.NewReplacer(datePatterns...) + format = replacer.Replace(format) + return time.ParseInLocation(format, dateString, time.Local) +} + +// Date takes a PHP like date func to Go's time format. +func Date(t time.Time, format string) string { + replacer := strings.NewReplacer(datePatterns...) + format = replacer.Replace(format) + return t.Format(format) +} + +// Compare is a quick and dirty comparison function. It will convert whatever you give it to strings and see if the two values are equal. +// Whitespace is trimmed. Used by the template parser as "eq". +func Compare(a, b interface{}) (equal bool) { + equal = false + if strings.TrimSpace(fmt.Sprintf("%v", a)) == strings.TrimSpace(fmt.Sprintf("%v", b)) { + equal = true + } + return +} + +// CompareNot !Compare +func CompareNot(a, b interface{}) (equal bool) { + return !Compare(a, b) +} + +// NotNil the same as CompareNot +func NotNil(a interface{}) (isNil bool) { + return CompareNot(a, nil) +} + +// GetConfig get the Appconfig +func GetConfig(returnType, key string, defaultVal interface{}) (value interface{}, err error) { + switch returnType { + case "String": + value = AppConfig.String(key) + case "Bool": + value, err = AppConfig.Bool(key) + case "Int": + value, err = AppConfig.Int(key) + case "Int64": + value, err = AppConfig.Int64(key) + case "Float": + value, err = AppConfig.Float(key) + case "DIY": + value, err = AppConfig.DIY(key) + default: + err = errors.New("config keys must be of type String, Bool, Int, Int64, Float, or DIY") + } + + if err != nil { + if reflect.TypeOf(returnType) != reflect.TypeOf(defaultVal) { + err = errors.New("defaultVal type does not match returnType") + } else { + value, err = defaultVal, nil + } + } else if reflect.TypeOf(value).Kind() == reflect.String { + if value == "" { + if reflect.TypeOf(defaultVal).Kind() != reflect.String { + err = errors.New("defaultVal type must be a String if the returnType is a String") + } else { + value = defaultVal.(string) + } + } + } + + return +} + +// Str2html Convert string to template.HTML type. +func Str2html(raw string) template.HTML { + return template.HTML(raw) +} + +// Htmlquote returns quoted html string. +func Htmlquote(text string) string { + //HTML编码为实体符号 + /* + Encodes `text` for raw use in HTML. + >>> htmlquote("<'&\\">") + '<'&">' + */ + + text = html.EscapeString(text) + text = strings.NewReplacer( + `“`, "“", + `”`, "”", + ` `, " ", + ).Replace(text) + + return strings.TrimSpace(text) +} + +// Htmlunquote returns unquoted html string. +func Htmlunquote(text string) string { + //实体符号解释为HTML + /* + Decodes `text` that's HTML quoted. + >>> htmlunquote('<'&">') + '<\\'&">' + */ + + text = html.UnescapeString(text) + + return strings.TrimSpace(text) +} + +// URLFor returns url string with another registered controller handler with params. +// usage: +// +// URLFor(".index") +// print URLFor("index") +// router /login +// print URLFor("login") +// print URLFor("login", "next","/"") +// router /profile/:username +// print UrlFor("profile", ":username","John Doe") +// result: +// / +// /login +// /login?next=/ +// /user/John%20Doe +// +// more detail http://beego.me/docs/mvc/controller/urlbuilding.md +func URLFor(endpoint string, values ...interface{}) string { + return BeeApp.Handlers.URLFor(endpoint, values...) +} + +// AssetsJs returns script tag with src string. +func AssetsJs(text string) template.HTML { + + text = "" + + return template.HTML(text) +} + +// AssetsCSS returns stylesheet link tag with src string. +func AssetsCSS(text string) template.HTML { + + text = "" + + return template.HTML(text) +} + +// ParseForm will parse form values to struct via tag. +// Support for anonymous struct. +func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) error { + for i := 0; i < objT.NumField(); i++ { + fieldV := objV.Field(i) + if !fieldV.CanSet() { + continue + } + + fieldT := objT.Field(i) + if fieldT.Anonymous && fieldT.Type.Kind() == reflect.Struct { + err := parseFormToStruct(form, fieldT.Type, fieldV) + if err != nil { + return err + } + continue + } + + tags := strings.Split(fieldT.Tag.Get("form"), ",") + var tag string + if len(tags) == 0 || len(tags[0]) == 0 { + tag = fieldT.Name + } else if tags[0] == "-" { + continue + } else { + tag = tags[0] + } + + formValues := form[tag] + var value string + if len(formValues) == 0 { + defaultValue := fieldT.Tag.Get("default") + if defaultValue != "" { + value = defaultValue + } else { + continue + } + } + if len(formValues) == 1 { + value = formValues[0] + if value == "" { + continue + } + } + + switch fieldT.Type.Kind() { + case reflect.Bool: + if strings.ToLower(value) == "on" || strings.ToLower(value) == "1" || strings.ToLower(value) == "yes" { + fieldV.SetBool(true) + continue + } + if strings.ToLower(value) == "off" || strings.ToLower(value) == "0" || strings.ToLower(value) == "no" { + fieldV.SetBool(false) + continue + } + b, err := strconv.ParseBool(value) + if err != nil { + return err + } + fieldV.SetBool(b) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + x, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + fieldV.SetInt(x) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + x, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return err + } + fieldV.SetUint(x) + case reflect.Float32, reflect.Float64: + x, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + fieldV.SetFloat(x) + case reflect.Interface: + fieldV.Set(reflect.ValueOf(value)) + case reflect.String: + fieldV.SetString(value) + case reflect.Struct: + switch fieldT.Type.String() { + case "time.Time": + var ( + t time.Time + err error + ) + if len(value) >= 25 { + value = value[:25] + t, err = time.ParseInLocation(time.RFC3339, value, time.Local) + } else if strings.HasSuffix(strings.ToUpper(value), "Z") { + t, err = time.ParseInLocation(time.RFC3339, value, time.Local) + } else if len(value) >= 19 { + if strings.Contains(value, "T") { + value = value[:19] + t, err = time.ParseInLocation(formatDateTimeT, value, time.Local) + } else { + value = value[:19] + t, err = time.ParseInLocation(formatDateTime, value, time.Local) + } + } else if len(value) >= 10 { + if len(value) > 10 { + value = value[:10] + } + t, err = time.ParseInLocation(formatDate, value, time.Local) + } else if len(value) >= 8 { + if len(value) > 8 { + value = value[:8] + } + t, err = time.ParseInLocation(formatTime, value, time.Local) + } + if err != nil { + return err + } + fieldV.Set(reflect.ValueOf(t)) + } + case reflect.Slice: + if fieldT.Type == sliceOfInts { + formVals := form[tag] + fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(int(1))), len(formVals), len(formVals))) + for i := 0; i < len(formVals); i++ { + val, err := strconv.Atoi(formVals[i]) + if err != nil { + return err + } + fieldV.Index(i).SetInt(int64(val)) + } + } else if fieldT.Type == sliceOfStrings { + formVals := form[tag] + fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf("")), len(formVals), len(formVals))) + for i := 0; i < len(formVals); i++ { + fieldV.Index(i).SetString(formVals[i]) + } + } + } + } + return nil +} + +// ParseForm will parse form values to struct via tag. +func ParseForm(form url.Values, obj interface{}) error { + objT := reflect.TypeOf(obj) + objV := reflect.ValueOf(obj) + if !isStructPtr(objT) { + return fmt.Errorf("%v must be a struct pointer", obj) + } + objT = objT.Elem() + objV = objV.Elem() + + return parseFormToStruct(form, objT, objV) +} + +var sliceOfInts = reflect.TypeOf([]int(nil)) +var sliceOfStrings = reflect.TypeOf([]string(nil)) + +var unKind = map[reflect.Kind]bool{ + reflect.Uintptr: true, + reflect.Complex64: true, + reflect.Complex128: true, + reflect.Array: true, + reflect.Chan: true, + reflect.Func: true, + reflect.Map: true, + reflect.Ptr: true, + reflect.Slice: true, + reflect.Struct: true, + reflect.UnsafePointer: true, +} + +// RenderForm will render object to form html. +// obj must be a struct pointer. +func RenderForm(obj interface{}) template.HTML { + objT := reflect.TypeOf(obj) + objV := reflect.ValueOf(obj) + if !isStructPtr(objT) { + return template.HTML("") + } + objT = objT.Elem() + objV = objV.Elem() + + var raw []string + for i := 0; i < objT.NumField(); i++ { + fieldV := objV.Field(i) + if !fieldV.CanSet() || unKind[fieldV.Kind()] { + continue + } + + fieldT := objT.Field(i) + + label, name, fType, id, class, ignored, required := parseFormTag(fieldT) + if ignored { + continue + } + + raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class, required)) + } + return template.HTML(strings.Join(raw, "
")) +} + +// renderFormField returns a string containing HTML of a single form field. +func renderFormField(label, name, fType string, value interface{}, id string, class string, required bool) string { + if id != "" { + id = " id=\"" + id + "\"" + } + + if class != "" { + class = " class=\"" + class + "\"" + } + + requiredString := "" + if required { + requiredString = " required" + } + + if isValidForInput(fType) { + return fmt.Sprintf(`%v`, label, id, class, name, fType, value, requiredString) + } + + return fmt.Sprintf(`%v<%v%v%v name="%v"%v>%v`, label, fType, id, class, name, requiredString, value, fType) +} + +// isValidForInput checks if fType is a valid value for the `type` property of an HTML input element. +func isValidForInput(fType string) bool { + validInputTypes := strings.Fields("text password checkbox radio submit reset hidden image file button search email url tel number range date month week time datetime datetime-local color") + for _, validType := range validInputTypes { + if fType == validType { + return true + } + } + return false +} + +// parseFormTag takes the stuct-tag of a StructField and parses the `form` value. +// returned are the form label, name-property, type and wether the field should be ignored. +func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool, required bool) { + tags := strings.Split(fieldT.Tag.Get("form"), ",") + label = fieldT.Name + ": " + name = fieldT.Name + fType = "text" + ignored = false + id = fieldT.Tag.Get("id") + class = fieldT.Tag.Get("class") + + required = false + requiredField := fieldT.Tag.Get("required") + if requiredField != "-" && requiredField != "" { + required, _ = strconv.ParseBool(requiredField) + } + + switch len(tags) { + case 1: + if tags[0] == "-" { + ignored = true + } + if len(tags[0]) > 0 { + name = tags[0] + } + case 2: + if len(tags[0]) > 0 { + name = tags[0] + } + if len(tags[1]) > 0 { + fType = tags[1] + } + case 3: + if len(tags[0]) > 0 { + name = tags[0] + } + if len(tags[1]) > 0 { + fType = tags[1] + } + if len(tags[2]) > 0 { + label = tags[2] + } + } + + return +} + +func isStructPtr(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + +// go1.2 added template funcs. begin +var ( + errBadComparisonType = errors.New("invalid type for comparison") + errBadComparison = errors.New("incompatible types for comparison") + errNoComparison = errors.New("missing argument for comparison") +) + +type kind int + +const ( + invalidKind kind = iota + boolKind + complexKind + intKind + floatKind + stringKind + uintKind +) + +func basicKind(v reflect.Value) (kind, error) { + switch v.Kind() { + case reflect.Bool: + return boolKind, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return intKind, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return uintKind, nil + case reflect.Float32, reflect.Float64: + return floatKind, nil + case reflect.Complex64, reflect.Complex128: + return complexKind, nil + case reflect.String: + return stringKind, nil + } + return invalidKind, errBadComparisonType +} + +// eq evaluates the comparison a == b || a == c || ... +func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) { + v1 := reflect.ValueOf(arg1) + k1, err := basicKind(v1) + if err != nil { + return false, err + } + if len(arg2) == 0 { + return false, errNoComparison + } + for _, arg := range arg2 { + v2 := reflect.ValueOf(arg) + k2, err := basicKind(v2) + if err != nil { + return false, err + } + if k1 != k2 { + return false, errBadComparison + } + truth := false + switch k1 { + case boolKind: + truth = v1.Bool() == v2.Bool() + case complexKind: + truth = v1.Complex() == v2.Complex() + case floatKind: + truth = v1.Float() == v2.Float() + case intKind: + truth = v1.Int() == v2.Int() + case stringKind: + truth = v1.String() == v2.String() + case uintKind: + truth = v1.Uint() == v2.Uint() + default: + panic("invalid kind") + } + if truth { + return true, nil + } + } + return false, nil +} + +// ne evaluates the comparison a != b. +func ne(arg1, arg2 interface{}) (bool, error) { + // != is the inverse of ==. + equal, err := eq(arg1, arg2) + return !equal, err +} + +// lt evaluates the comparison a < b. +func lt(arg1, arg2 interface{}) (bool, error) { + v1 := reflect.ValueOf(arg1) + k1, err := basicKind(v1) + if err != nil { + return false, err + } + v2 := reflect.ValueOf(arg2) + k2, err := basicKind(v2) + if err != nil { + return false, err + } + if k1 != k2 { + return false, errBadComparison + } + truth := false + switch k1 { + case boolKind, complexKind: + return false, errBadComparisonType + case floatKind: + truth = v1.Float() < v2.Float() + case intKind: + truth = v1.Int() < v2.Int() + case stringKind: + truth = v1.String() < v2.String() + case uintKind: + truth = v1.Uint() < v2.Uint() + default: + panic("invalid kind") + } + return truth, nil +} + +// le evaluates the comparison <= b. +func le(arg1, arg2 interface{}) (bool, error) { + // <= is < or ==. + lessThan, err := lt(arg1, arg2) + if lessThan || err != nil { + return lessThan, err + } + return eq(arg1, arg2) +} + +// gt evaluates the comparison a > b. +func gt(arg1, arg2 interface{}) (bool, error) { + // > is the inverse of <=. + lessOrEqual, err := le(arg1, arg2) + if err != nil { + return false, err + } + return !lessOrEqual, nil +} + +// ge evaluates the comparison a >= b. +func ge(arg1, arg2 interface{}) (bool, error) { + // >= is the inverse of <. + lessThan, err := lt(arg1, arg2) + if err != nil { + return false, err + } + return !lessThan, nil +} + +// MapGet getting value from map by keys +// usage: +// Data["m"] = M{ +// "a": 1, +// "1": map[string]float64{ +// "c": 4, +// }, +// } +// +// {{ map_get m "a" }} // return 1 +// {{ map_get m 1 "c" }} // return 4 +func MapGet(arg1 interface{}, arg2 ...interface{}) (interface{}, error) { + arg1Type := reflect.TypeOf(arg1) + arg1Val := reflect.ValueOf(arg1) + + if arg1Type.Kind() == reflect.Map && len(arg2) > 0 { + // check whether arg2[0] type equals to arg1 key type + // if they are different, make conversion + arg2Val := reflect.ValueOf(arg2[0]) + arg2Type := reflect.TypeOf(arg2[0]) + if arg2Type.Kind() != arg1Type.Key().Kind() { + // convert arg2Value to string + var arg2ConvertedVal interface{} + arg2String := fmt.Sprintf("%v", arg2[0]) + + // convert string representation to any other type + switch arg1Type.Key().Kind() { + case reflect.Bool: + arg2ConvertedVal, _ = strconv.ParseBool(arg2String) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + arg2ConvertedVal, _ = strconv.ParseInt(arg2String, 0, 64) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + arg2ConvertedVal, _ = strconv.ParseUint(arg2String, 0, 64) + case reflect.Float32, reflect.Float64: + arg2ConvertedVal, _ = strconv.ParseFloat(arg2String, 64) + case reflect.String: + arg2ConvertedVal = arg2String + default: + arg2ConvertedVal = arg2Val.Interface() + } + arg2Val = reflect.ValueOf(arg2ConvertedVal) + } + + storedVal := arg1Val.MapIndex(arg2Val) + + if storedVal.IsValid() { + var result interface{} + + switch arg1Type.Elem().Kind() { + case reflect.Bool: + result = storedVal.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + result = storedVal.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + result = storedVal.Uint() + case reflect.Float32, reflect.Float64: + result = storedVal.Float() + case reflect.String: + result = storedVal.String() + default: + result = storedVal.Interface() + } + + // if there is more keys, handle this recursively + if len(arg2) > 1 { + return MapGet(result, arg2[1:]...) + } + return result, nil + } + return nil, nil + + } + return nil, nil +} diff --git a/pkg/templatefunc_test.go b/pkg/templatefunc_test.go new file mode 100644 index 00000000..b4c19c2e --- /dev/null +++ b/pkg/templatefunc_test.go @@ -0,0 +1,380 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "html/template" + "net/url" + "reflect" + "testing" + "time" +) + +func TestSubstr(t *testing.T) { + s := `012345` + if Substr(s, 0, 2) != "01" { + t.Error("should be equal") + } + if Substr(s, 0, 100) != "012345" { + t.Error("should be equal") + } + if Substr(s, 12, 100) != "012345" { + t.Error("should be equal") + } +} + +func TestHtml2str(t *testing.T) { + h := `<123> 123\n + + + \n` + if HTML2str(h) != "123\\n\n\\n" { + t.Error("should be equal") + } +} + +func TestDateFormat(t *testing.T) { + ts := "Mon, 01 Jul 2013 13:27:42 CST" + tt, _ := time.Parse(time.RFC1123, ts) + + if ss := DateFormat(tt, "2006-01-02 15:04:05"); ss != "2013-07-01 13:27:42" { + t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) + } +} + +func TestDate(t *testing.T) { + ts := "Mon, 01 Jul 2013 13:27:42 CST" + tt, _ := time.Parse(time.RFC1123, ts) + + if ss := Date(tt, "Y-m-d H:i:s"); ss != "2013-07-01 13:27:42" { + t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) + } + if ss := Date(tt, "y-n-j h:i:s A"); ss != "13-7-1 01:27:42 PM" { + t.Errorf("13-7-1 01:27:42 PM does not equal %v", ss) + } + if ss := Date(tt, "D, d M Y g:i:s a"); ss != "Mon, 01 Jul 2013 1:27:42 pm" { + t.Errorf("Mon, 01 Jul 2013 1:27:42 pm does not equal %v", ss) + } + if ss := Date(tt, "l, d F Y G:i:s"); ss != "Monday, 01 July 2013 13:27:42" { + t.Errorf("Monday, 01 July 2013 13:27:42 does not equal %v", ss) + } +} + +func TestCompareRelated(t *testing.T) { + if !Compare("abc", "abc") { + t.Error("should be equal") + } + if Compare("abc", "aBc") { + t.Error("should be not equal") + } + if !Compare("1", 1) { + t.Error("should be equal") + } + if CompareNot("abc", "abc") { + t.Error("should be equal") + } + if !CompareNot("abc", "aBc") { + t.Error("should be not equal") + } + if !NotNil("a string") { + t.Error("should not be nil") + } +} + +func TestHtmlquote(t *testing.T) { + h := `<' ”“&">` + s := `<' ”“&">` + if Htmlquote(s) != h { + t.Error("should be equal") + } +} + +func TestHtmlunquote(t *testing.T) { + h := `<' ”“&">` + s := `<' ”“&">` + if Htmlunquote(h) != s { + t.Error("should be equal") + } +} + +func TestParseForm(t *testing.T) { + type ExtendInfo struct { + Hobby []string `form:"hobby"` + Memo string + } + + type OtherInfo struct { + Organization string `form:"organization"` + Title string `form:"title"` + ExtendInfo + } + + type user struct { + ID int `form:"-"` + tag string `form:"tag"` + Name interface{} `form:"username"` + Age int `form:"age,text"` + Email string + Intro string `form:",textarea"` + StrBool bool `form:"strbool"` + Date time.Time `form:"date,2006-01-02"` + OtherInfo + } + + u := user{} + form := url.Values{ + "ID": []string{"1"}, + "-": []string{"1"}, + "tag": []string{"no"}, + "username": []string{"test"}, + "age": []string{"40"}, + "Email": []string{"test@gmail.com"}, + "Intro": []string{"I am an engineer!"}, + "strbool": []string{"yes"}, + "date": []string{"2014-11-12"}, + "organization": []string{"beego"}, + "title": []string{"CXO"}, + "hobby": []string{"", "Basketball", "Football"}, + "memo": []string{"nothing"}, + } + if err := ParseForm(form, u); err == nil { + t.Fatal("nothing will be changed") + } + if err := ParseForm(form, &u); err != nil { + t.Fatal(err) + } + if u.ID != 0 { + t.Errorf("ID should equal 0 but got %v", u.ID) + } + if len(u.tag) != 0 { + t.Errorf("tag's length should equal 0 but got %v", len(u.tag)) + } + if u.Name.(string) != "test" { + t.Errorf("Name should equal `test` but got `%v`", u.Name.(string)) + } + if u.Age != 40 { + t.Errorf("Age should equal 40 but got %v", u.Age) + } + if u.Email != "test@gmail.com" { + t.Errorf("Email should equal `test@gmail.com` but got `%v`", u.Email) + } + if u.Intro != "I am an engineer!" { + t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro) + } + if !u.StrBool { + t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool) + } + y, m, d := u.Date.Date() + if y != 2014 || m.String() != "November" || d != 12 { + t.Errorf("Date should equal `2014-11-12`, but got `%v`", u.Date.String()) + } + if u.Organization != "beego" { + t.Errorf("Organization should equal `beego`, but got `%v`", u.Organization) + } + if u.Title != "CXO" { + t.Errorf("Title should equal `CXO`, but got `%v`", u.Title) + } + if u.Hobby[0] != "" { + t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0]) + } + if u.Hobby[1] != "Basketball" { + t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1]) + } + if u.Hobby[2] != "Football" { + t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2]) + } + if len(u.Memo) != 0 { + t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo)) + } +} + +func TestRenderForm(t *testing.T) { + type user struct { + ID int `form:"-"` + Name interface{} `form:"username"` + Age int `form:"age,text,年龄:"` + Sex string + Email []string + Intro string `form:",textarea"` + Ignored string `form:"-"` + } + + u := user{Name: "test", Intro: "Some Text"} + output := RenderForm(u) + if output != template.HTML("") { + t.Errorf("output should be empty but got %v", output) + } + output = RenderForm(&u) + result := template.HTML( + `Name:
` + + `年龄:
` + + `Sex:
` + + `Intro: `) + if output != result { + t.Errorf("output should equal `%v` but got `%v`", result, output) + } +} + +func TestRenderFormField(t *testing.T) { + html := renderFormField("Label: ", "Name", "text", "Value", "", "", false) + if html != `Label: ` { + t.Errorf("Wrong html output for input[type=text]: %v ", html) + } + + html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", false) + if html != `Label: ` { + t.Errorf("Wrong html output for textarea: %v ", html) + } + + html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", true) + if html != `Label: ` { + t.Errorf("Wrong html output for textarea: %v ", html) + } +} + +func TestParseFormTag(t *testing.T) { + // create struct to contain field with different types of struct-tag `form` + type user struct { + All int `form:"name,text,年龄:"` + NoName int `form:",hidden,年龄:"` + OnlyLabel int `form:",,年龄:"` + OnlyName int `form:"name" id:"name" class:"form-name"` + Ignored int `form:"-"` + Required int `form:"name" required:"true"` + IgnoreRequired int `form:"name"` + NotRequired int `form:"name" required:"false"` + } + + objT := reflect.TypeOf(&user{}).Elem() + + label, name, fType, _, _, ignored, _ := parseFormTag(objT.Field(0)) + if !(name == "name" && label == "年龄:" && fType == "text" && !ignored) { + t.Errorf("Form Tag with name, label and type was not correctly parsed.") + } + + label, name, fType, _, _, ignored, _ = parseFormTag(objT.Field(1)) + if !(name == "NoName" && label == "年龄:" && fType == "hidden" && !ignored) { + t.Errorf("Form Tag with label and type but without name was not correctly parsed.") + } + + label, name, fType, _, _, ignored, _ = parseFormTag(objT.Field(2)) + if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && !ignored) { + t.Errorf("Form Tag containing only label was not correctly parsed.") + } + + label, name, fType, id, class, ignored, _ := parseFormTag(objT.Field(3)) + if !(name == "name" && label == "OnlyName: " && fType == "text" && !ignored && + id == "name" && class == "form-name") { + t.Errorf("Form Tag containing only name was not correctly parsed.") + } + + _, _, _, _, _, ignored, _ = parseFormTag(objT.Field(4)) + if !ignored { + t.Errorf("Form Tag that should be ignored was not correctly parsed.") + } + + _, name, _, _, _, _, required := parseFormTag(objT.Field(5)) + if !(name == "name" && required) { + t.Errorf("Form Tag containing only name and required was not correctly parsed.") + } + + _, name, _, _, _, _, required = parseFormTag(objT.Field(6)) + if !(name == "name" && !required) { + t.Errorf("Form Tag containing only name and ignore required was not correctly parsed.") + } + + _, name, _, _, _, _, required = parseFormTag(objT.Field(7)) + if !(name == "name" && !required) { + t.Errorf("Form Tag containing only name and not required was not correctly parsed.") + } + +} + +func TestMapGet(t *testing.T) { + // test one level map + m1 := map[string]int64{ + "a": 1, + "1": 2, + } + + if res, err := MapGet(m1, "a"); err == nil { + if res.(int64) != 1 { + t.Errorf("Should return 1, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + if res, err := MapGet(m1, "1"); err == nil { + if res.(int64) != 2 { + t.Errorf("Should return 2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + if res, err := MapGet(m1, 1); err == nil { + if res.(int64) != 2 { + t.Errorf("Should return 2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // test 2 level map + m2 := M{ + "1": map[string]float64{ + "2": 3.5, + }, + } + + if res, err := MapGet(m2, 1, 2); err == nil { + if res.(float64) != 3.5 { + t.Errorf("Should return 3.5, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // test 5 level map + m5 := M{ + "1": M{ + "2": M{ + "3": M{ + "4": M{ + "5": 1.2, + }, + }, + }, + }, + } + + if res, err := MapGet(m5, 1, 2, 3, 4, 5); err == nil { + if res.(float64) != 1.2 { + t.Errorf("Should return 1.2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // check whether element not exists in map + if res, err := MapGet(m5, 5, 4, 3, 2, 1); err == nil { + if res != nil { + t.Errorf("Should return nil, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } +} diff --git a/pkg/testdata/Makefile b/pkg/testdata/Makefile new file mode 100644 index 00000000..e80e8238 --- /dev/null +++ b/pkg/testdata/Makefile @@ -0,0 +1,2 @@ +build_view: + $(GOPATH)/bin/go-bindata-assetfs -pkg testdata views/... \ No newline at end of file diff --git a/pkg/testdata/bindata.go b/pkg/testdata/bindata.go new file mode 100644 index 00000000..beade103 --- /dev/null +++ b/pkg/testdata/bindata.go @@ -0,0 +1,296 @@ +// Code generated by go-bindata. +// sources: +// views/blocks/block.tpl +// views/header.tpl +// views/index.tpl +// DO NOT EDIT! + +package testdata + +import ( + "bytes" + "compress/gzip" + "fmt" + "github.com/elazarl/go-bindata-assetfs" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "time" +) + +func bindataRead(data []byte, name string) ([]byte, error) { + gz, err := gzip.NewReader(bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("Read %q: %v", name, err) + } + + var buf bytes.Buffer + _, err = io.Copy(&buf, gz) + clErr := gz.Close() + + if err != nil { + return nil, fmt.Errorf("Read %q: %v", name, err) + } + if clErr != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +type asset struct { + bytes []byte + info os.FileInfo +} + +type bindataFileInfo struct { + name string + size int64 + mode os.FileMode + modTime time.Time +} + +func (fi bindataFileInfo) Name() string { + return fi.name +} +func (fi bindataFileInfo) Size() int64 { + return fi.size +} +func (fi bindataFileInfo) Mode() os.FileMode { + return fi.mode +} +func (fi bindataFileInfo) ModTime() time.Time { + return fi.modTime +} +func (fi bindataFileInfo) IsDir() bool { + return false +} +func (fi bindataFileInfo) Sys() interface{} { + return nil +} + +var _viewsBlocksBlockTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xaa\xae\x4e\x49\x4d\xcb\xcc\x4b\x55\x50\x4a\xca\xc9\x4f\xce\x56\xaa\xad\xe5\xb2\xc9\x30\xb4\xf3\x48\xcd\xc9\xc9\xd7\x51\x00\x8b\x15\x2b\xda\xe8\x67\x18\xda\x71\x55\x57\xa7\xe6\xa5\xd4\xd6\x02\x02\x00\x00\xff\xff\xfd\xa1\x7a\xf6\x32\x00\x00\x00") + +func viewsBlocksBlockTplBytes() ([]byte, error) { + return bindataRead( + _viewsBlocksBlockTpl, + "views/blocks/block.tpl", + ) +} + +func viewsBlocksBlockTpl() (*asset, error) { + bytes, err := viewsBlocksBlockTplBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "views/blocks/block.tpl", size: 50, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +var _viewsHeaderTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xaa\xae\x4e\x49\x4d\xcb\xcc\x4b\x55\x50\xca\x48\x4d\x4c\x49\x2d\x52\xaa\xad\xe5\xb2\xc9\x30\xb4\xf3\x48\xcd\xc9\xc9\xd7\x51\x48\x2c\x2e\x49\xac\xc8\x4c\x55\xb4\xd1\xcf\x30\xb4\xe3\xaa\xae\x4e\xcd\x4b\xa9\xad\x05\x04\x00\x00\xff\xff\xe4\x12\x47\x01\x34\x00\x00\x00") + +func viewsHeaderTplBytes() ([]byte, error) { + return bindataRead( + _viewsHeaderTpl, + "views/header.tpl", + ) +} + +func viewsHeaderTpl() (*asset, error) { + bytes, err := viewsHeaderTplBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "views/header.tpl", size: 52, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +var _viewsIndexTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x64\x8f\xbd\x8a\xc3\x30\x10\x84\x6b\xeb\x29\xe6\xfc\x00\x16\xb8\x3c\x16\x35\x77\xa9\x13\x88\x09\xa4\xf4\xcf\x12\x99\x48\x48\xd8\x82\x10\x84\xde\x3d\xc8\x8a\x8b\x90\x6a\xa4\xd9\x6f\xd8\x59\xfa\xf9\x3f\xfe\x75\xd7\xd3\x01\x3a\x58\xa3\x04\x15\x01\x48\x73\x3f\xe5\x07\x40\x61\x0e\x86\xd5\xc0\x7c\x73\x78\xb0\x19\x9d\x65\x04\xb6\xde\xf4\x81\x49\x96\x69\x8e\xc8\x3d\x43\x83\x9b\x9e\x4a\x88\x2a\xc6\x9d\x43\x3d\x18\x37\xde\xeb\x94\x3e\xdd\x1c\xe1\xe5\xcb\xde\xe0\x55\x6e\xd2\x04\x6f\x32\x20\x2a\xd2\xad\x8a\x11\x4d\x97\x57\x22\x25\x92\xba\x55\xa2\x22\xaf\xd0\xe9\x79\xc5\xbc\xe2\xec\x2c\x5f\xfa\xe5\x17\x99\x7b\x7f\x36\xd2\x97\x8a\xa5\x19\xc9\x72\xe7\x2b\x00\x00\xff\xff\xb2\x39\xca\x9f\xff\x00\x00\x00") + +func viewsIndexTplBytes() ([]byte, error) { + return bindataRead( + _viewsIndexTpl, + "views/index.tpl", + ) +} + +func viewsIndexTpl() (*asset, error) { + bytes, err := viewsIndexTplBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "views/index.tpl", size: 255, mode: os.FileMode(436), modTime: time.Unix(1541434906, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +// Asset loads and returns the asset for the given name. +// It returns an error if the asset could not be found or +// could not be loaded. +func Asset(name string) ([]byte, error) { + cannonicalName := strings.Replace(name, "\\", "/", -1) + if f, ok := _bindata[cannonicalName]; ok { + a, err := f() + if err != nil { + return nil, fmt.Errorf("Asset %s can't read by error: %v", name, err) + } + return a.bytes, nil + } + return nil, fmt.Errorf("Asset %s not found", name) +} + +// MustAsset is like Asset but panics when Asset would return an error. +// It simplifies safe initialization of global variables. +func MustAsset(name string) []byte { + a, err := Asset(name) + if err != nil { + panic("asset: Asset(" + name + "): " + err.Error()) + } + + return a +} + +// AssetInfo loads and returns the asset info for the given name. +// It returns an error if the asset could not be found or +// could not be loaded. +func AssetInfo(name string) (os.FileInfo, error) { + cannonicalName := strings.Replace(name, "\\", "/", -1) + if f, ok := _bindata[cannonicalName]; ok { + a, err := f() + if err != nil { + return nil, fmt.Errorf("AssetInfo %s can't read by error: %v", name, err) + } + return a.info, nil + } + return nil, fmt.Errorf("AssetInfo %s not found", name) +} + +// AssetNames returns the names of the assets. +func AssetNames() []string { + names := make([]string, 0, len(_bindata)) + for name := range _bindata { + names = append(names, name) + } + return names +} + +// _bindata is a table, holding each asset generator, mapped to its name. +var _bindata = map[string]func() (*asset, error){ + "views/blocks/block.tpl": viewsBlocksBlockTpl, + "views/header.tpl": viewsHeaderTpl, + "views/index.tpl": viewsIndexTpl, +} + +// AssetDir returns the file names below a certain +// directory embedded in the file by go-bindata. +// For example if you run go-bindata on data/... and data contains the +// following hierarchy: +// data/ +// foo.txt +// img/ +// a.png +// b.png +// then AssetDir("data") would return []string{"foo.txt", "img"} +// AssetDir("data/img") would return []string{"a.png", "b.png"} +// AssetDir("foo.txt") and AssetDir("notexist") would return an error +// AssetDir("") will return []string{"data"}. +func AssetDir(name string) ([]string, error) { + node := _bintree + if len(name) != 0 { + cannonicalName := strings.Replace(name, "\\", "/", -1) + pathList := strings.Split(cannonicalName, "/") + for _, p := range pathList { + node = node.Children[p] + if node == nil { + return nil, fmt.Errorf("Asset %s not found", name) + } + } + } + if node.Func != nil { + return nil, fmt.Errorf("Asset %s not found", name) + } + rv := make([]string, 0, len(node.Children)) + for childName := range node.Children { + rv = append(rv, childName) + } + return rv, nil +} + +type bintree struct { + Func func() (*asset, error) + Children map[string]*bintree +} + +var _bintree = &bintree{nil, map[string]*bintree{ + "views": &bintree{nil, map[string]*bintree{ + "blocks": &bintree{nil, map[string]*bintree{ + "block.tpl": &bintree{viewsBlocksBlockTpl, map[string]*bintree{}}, + }}, + "header.tpl": &bintree{viewsHeaderTpl, map[string]*bintree{}}, + "index.tpl": &bintree{viewsIndexTpl, map[string]*bintree{}}, + }}, +}} + +// RestoreAsset restores an asset under the given directory +func RestoreAsset(dir, name string) error { + data, err := Asset(name) + if err != nil { + return err + } + info, err := AssetInfo(name) + if err != nil { + return err + } + err = os.MkdirAll(_filePath(dir, filepath.Dir(name)), os.FileMode(0755)) + if err != nil { + return err + } + err = ioutil.WriteFile(_filePath(dir, name), data, info.Mode()) + if err != nil { + return err + } + err = os.Chtimes(_filePath(dir, name), info.ModTime(), info.ModTime()) + if err != nil { + return err + } + return nil +} + +// RestoreAssets restores an asset under the given directory recursively +func RestoreAssets(dir, name string) error { + children, err := AssetDir(name) + // File + if err != nil { + return RestoreAsset(dir, name) + } + // Dir + for _, child := range children { + err = RestoreAssets(dir, filepath.Join(name, child)) + if err != nil { + return err + } + } + return nil +} + +func _filePath(dir, name string) string { + cannonicalName := strings.Replace(name, "\\", "/", -1) + return filepath.Join(append([]string{dir}, strings.Split(cannonicalName, "/")...)...) +} + +func assetFS() *assetfs.AssetFS { + assetInfo := func(path string) (os.FileInfo, error) { + return os.Stat(path) + } + for k := range _bintree.Children { + return &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, AssetInfo: assetInfo, Prefix: k} + } + panic("unreachable") +} diff --git a/pkg/testdata/views/blocks/block.tpl b/pkg/testdata/views/blocks/block.tpl new file mode 100644 index 00000000..2a9c57fc --- /dev/null +++ b/pkg/testdata/views/blocks/block.tpl @@ -0,0 +1,3 @@ +{{define "block"}} +

Hello, blocks!

+{{end}} \ No newline at end of file diff --git a/pkg/testdata/views/header.tpl b/pkg/testdata/views/header.tpl new file mode 100644 index 00000000..041fa403 --- /dev/null +++ b/pkg/testdata/views/header.tpl @@ -0,0 +1,3 @@ +{{define "header"}} +

Hello, astaxie!

+{{end}} \ No newline at end of file diff --git a/pkg/testdata/views/index.tpl b/pkg/testdata/views/index.tpl new file mode 100644 index 00000000..21b7fc06 --- /dev/null +++ b/pkg/testdata/views/index.tpl @@ -0,0 +1,15 @@ + + + + beego welcome template + + + + {{template "block"}} + {{template "header"}} + {{template "blocks/block.tpl"}} + +

{{ .Title }}

+

This is SomeVar: {{ .SomeVar }}

+ + diff --git a/pkg/testing/assertions.go b/pkg/testing/assertions.go new file mode 100644 index 00000000..96c5d4dd --- /dev/null +++ b/pkg/testing/assertions.go @@ -0,0 +1,15 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testing diff --git a/pkg/testing/client.go b/pkg/testing/client.go new file mode 100644 index 00000000..c3737e9c --- /dev/null +++ b/pkg/testing/client.go @@ -0,0 +1,65 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testing + +import ( + "github.com/astaxie/beego/config" + "github.com/astaxie/beego/httplib" +) + +var port = "" +var baseURL = "http://localhost:" + +// TestHTTPRequest beego test request client +type TestHTTPRequest struct { + httplib.BeegoHTTPRequest +} + +func getPort() string { + if port == "" { + config, err := config.NewConfig("ini", "../conf/app.conf") + if err != nil { + return "8080" + } + port = config.String("httpport") + return port + } + return port +} + +// Get returns test client in GET method +func Get(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Get(baseURL + getPort() + path)} +} + +// Post returns test client in POST method +func Post(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Post(baseURL + getPort() + path)} +} + +// Put returns test client in PUT method +func Put(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Put(baseURL + getPort() + path)} +} + +// Delete returns test client in DELETE method +func Delete(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Delete(baseURL + getPort() + path)} +} + +// Head returns test client in HEAD method +func Head(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Head(baseURL + getPort() + path)} +} diff --git a/pkg/toolbox/healthcheck.go b/pkg/toolbox/healthcheck.go new file mode 100644 index 00000000..e3544b3a --- /dev/null +++ b/pkg/toolbox/healthcheck.go @@ -0,0 +1,48 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package toolbox healthcheck +// +// type DatabaseCheck struct { +// } +// +// func (dc *DatabaseCheck) Check() error { +// if dc.isConnected() { +// return nil +// } else { +// return errors.New("can't connect database") +// } +// } +// +// AddHealthCheck("database",&DatabaseCheck{}) +// +// more docs: http://beego.me/docs/module/toolbox.md +package toolbox + +// AdminCheckList holds health checker map +var AdminCheckList map[string]HealthChecker + +// HealthChecker health checker interface +type HealthChecker interface { + Check() error +} + +// AddHealthCheck add health checker with name string +func AddHealthCheck(name string, hc HealthChecker) { + AdminCheckList[name] = hc +} + +func init() { + AdminCheckList = make(map[string]HealthChecker) +} diff --git a/pkg/toolbox/profile.go b/pkg/toolbox/profile.go new file mode 100644 index 00000000..06e40ede --- /dev/null +++ b/pkg/toolbox/profile.go @@ -0,0 +1,184 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "fmt" + "io" + "log" + "os" + "path" + "runtime" + "runtime/debug" + "runtime/pprof" + "strconv" + "time" +) + +var startTime = time.Now() +var pid int + +func init() { + pid = os.Getpid() +} + +// ProcessInput parse input command string +func ProcessInput(input string, w io.Writer) { + switch input { + case "lookup goroutine": + p := pprof.Lookup("goroutine") + p.WriteTo(w, 2) + case "lookup heap": + p := pprof.Lookup("heap") + p.WriteTo(w, 2) + case "lookup threadcreate": + p := pprof.Lookup("threadcreate") + p.WriteTo(w, 2) + case "lookup block": + p := pprof.Lookup("block") + p.WriteTo(w, 2) + case "get cpuprof": + GetCPUProfile(w) + case "get memprof": + MemProf(w) + case "gc summary": + PrintGCSummary(w) + } +} + +// MemProf record memory profile in pprof +func MemProf(w io.Writer) { + filename := "mem-" + strconv.Itoa(pid) + ".memprof" + if f, err := os.Create(filename); err != nil { + fmt.Fprintf(w, "create file %s error %s\n", filename, err.Error()) + log.Fatal("record heap profile failed: ", err) + } else { + runtime.GC() + pprof.WriteHeapProfile(f) + f.Close() + fmt.Fprintf(w, "create heap profile %s \n", filename) + _, fl := path.Split(os.Args[0]) + fmt.Fprintf(w, "Now you can use this to check it: go tool pprof %s %s\n", fl, filename) + } +} + +// GetCPUProfile start cpu profile monitor +func GetCPUProfile(w io.Writer) { + sec := 30 + filename := "cpu-" + strconv.Itoa(pid) + ".pprof" + f, err := os.Create(filename) + if err != nil { + fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err) + log.Fatal("record cpu profile failed: ", err) + } + pprof.StartCPUProfile(f) + time.Sleep(time.Duration(sec) * time.Second) + pprof.StopCPUProfile() + + fmt.Fprintf(w, "create cpu profile %s \n", filename) + _, fl := path.Split(os.Args[0]) + fmt.Fprintf(w, "Now you can use this to check it: go tool pprof %s %s\n", fl, filename) +} + +// PrintGCSummary print gc information to io.Writer +func PrintGCSummary(w io.Writer) { + memStats := &runtime.MemStats{} + runtime.ReadMemStats(memStats) + gcstats := &debug.GCStats{PauseQuantiles: make([]time.Duration, 100)} + debug.ReadGCStats(gcstats) + + printGC(memStats, gcstats, w) +} + +func printGC(memStats *runtime.MemStats, gcstats *debug.GCStats, w io.Writer) { + + if gcstats.NumGC > 0 { + lastPause := gcstats.Pause[0] + elapsed := time.Now().Sub(startTime) + overhead := float64(gcstats.PauseTotal) / float64(elapsed) * 100 + allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() + + fmt.Fprintf(w, "NumGC:%d Pause:%s Pause(Avg):%s Overhead:%3.2f%% Alloc:%s Sys:%s Alloc(Rate):%s/s Histogram:%s %s %s \n", + gcstats.NumGC, + toS(lastPause), + toS(avg(gcstats.Pause)), + overhead, + toH(memStats.Alloc), + toH(memStats.Sys), + toH(uint64(allocatedRate)), + toS(gcstats.PauseQuantiles[94]), + toS(gcstats.PauseQuantiles[98]), + toS(gcstats.PauseQuantiles[99])) + } else { + // while GC has disabled + elapsed := time.Now().Sub(startTime) + allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() + + fmt.Fprintf(w, "Alloc:%s Sys:%s Alloc(Rate):%s/s\n", + toH(memStats.Alloc), + toH(memStats.Sys), + toH(uint64(allocatedRate))) + } +} + +func avg(items []time.Duration) time.Duration { + var sum time.Duration + for _, item := range items { + sum += item + } + return time.Duration(int64(sum) / int64(len(items))) +} + +// format bytes number friendly +func toH(bytes uint64) string { + switch { + case bytes < 1024: + return fmt.Sprintf("%dB", bytes) + case bytes < 1024*1024: + return fmt.Sprintf("%.2fK", float64(bytes)/1024) + case bytes < 1024*1024*1024: + return fmt.Sprintf("%.2fM", float64(bytes)/1024/1024) + default: + return fmt.Sprintf("%.2fG", float64(bytes)/1024/1024/1024) + } +} + +// short string format +func toS(d time.Duration) string { + + u := uint64(d) + if u < uint64(time.Second) { + switch { + case u == 0: + return "0" + case u < uint64(time.Microsecond): + return fmt.Sprintf("%.2fns", float64(u)) + case u < uint64(time.Millisecond): + return fmt.Sprintf("%.2fus", float64(u)/1000) + default: + return fmt.Sprintf("%.2fms", float64(u)/1000/1000) + } + } else { + switch { + case u < uint64(time.Minute): + return fmt.Sprintf("%.2fs", float64(u)/1000/1000/1000) + case u < uint64(time.Hour): + return fmt.Sprintf("%.2fm", float64(u)/1000/1000/1000/60) + default: + return fmt.Sprintf("%.2fh", float64(u)/1000/1000/1000/60/60) + } + } + +} diff --git a/pkg/toolbox/profile_test.go b/pkg/toolbox/profile_test.go new file mode 100644 index 00000000..07a20c4e --- /dev/null +++ b/pkg/toolbox/profile_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "os" + "testing" +) + +func TestProcessInput(t *testing.T) { + ProcessInput("lookup goroutine", os.Stdout) + ProcessInput("lookup heap", os.Stdout) + ProcessInput("lookup threadcreate", os.Stdout) + ProcessInput("lookup block", os.Stdout) + ProcessInput("gc summary", os.Stdout) +} diff --git a/pkg/toolbox/statistics.go b/pkg/toolbox/statistics.go new file mode 100644 index 00000000..fd73dfb3 --- /dev/null +++ b/pkg/toolbox/statistics.go @@ -0,0 +1,149 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "fmt" + "sync" + "time" +) + +// Statistics struct +type Statistics struct { + RequestURL string + RequestController string + RequestNum int64 + MinTime time.Duration + MaxTime time.Duration + TotalTime time.Duration +} + +// URLMap contains several statistics struct to log different data +type URLMap struct { + lock sync.RWMutex + LengthLimit int //limit the urlmap's length if it's equal to 0 there's no limit + urlmap map[string]map[string]*Statistics +} + +// AddStatistics add statistics task. +// it needs request method, request url, request controller and statistics time duration +func (m *URLMap) AddStatistics(requestMethod, requestURL, requestController string, requesttime time.Duration) { + m.lock.Lock() + defer m.lock.Unlock() + if method, ok := m.urlmap[requestURL]; ok { + if s, ok := method[requestMethod]; ok { + s.RequestNum++ + if s.MaxTime < requesttime { + s.MaxTime = requesttime + } + if s.MinTime > requesttime { + s.MinTime = requesttime + } + s.TotalTime += requesttime + } else { + nb := &Statistics{ + RequestURL: requestURL, + RequestController: requestController, + RequestNum: 1, + MinTime: requesttime, + MaxTime: requesttime, + TotalTime: requesttime, + } + m.urlmap[requestURL][requestMethod] = nb + } + + } else { + if m.LengthLimit > 0 && m.LengthLimit <= len(m.urlmap) { + return + } + methodmap := make(map[string]*Statistics) + nb := &Statistics{ + RequestURL: requestURL, + RequestController: requestController, + RequestNum: 1, + MinTime: requesttime, + MaxTime: requesttime, + TotalTime: requesttime, + } + methodmap[requestMethod] = nb + m.urlmap[requestURL] = methodmap + } +} + +// GetMap put url statistics result in io.Writer +func (m *URLMap) GetMap() map[string]interface{} { + m.lock.RLock() + defer m.lock.RUnlock() + + var fields = []string{"requestUrl", "method", "times", "used", "max used", "min used", "avg used"} + + var resultLists [][]string + content := make(map[string]interface{}) + content["Fields"] = fields + + for k, v := range m.urlmap { + for kk, vv := range v { + result := []string{ + fmt.Sprintf("% -50s", k), + fmt.Sprintf("% -10s", kk), + fmt.Sprintf("% -16d", vv.RequestNum), + fmt.Sprintf("%d", vv.TotalTime), + fmt.Sprintf("% -16s", toS(vv.TotalTime)), + fmt.Sprintf("%d", vv.MaxTime), + fmt.Sprintf("% -16s", toS(vv.MaxTime)), + fmt.Sprintf("%d", vv.MinTime), + fmt.Sprintf("% -16s", toS(vv.MinTime)), + fmt.Sprintf("%d", time.Duration(int64(vv.TotalTime)/vv.RequestNum)), + fmt.Sprintf("% -16s", toS(time.Duration(int64(vv.TotalTime)/vv.RequestNum))), + } + resultLists = append(resultLists, result) + } + } + content["Data"] = resultLists + return content +} + +// GetMapData return all mapdata +func (m *URLMap) GetMapData() []map[string]interface{} { + m.lock.RLock() + defer m.lock.RUnlock() + + var resultLists []map[string]interface{} + + for k, v := range m.urlmap { + for kk, vv := range v { + result := map[string]interface{}{ + "request_url": k, + "method": kk, + "times": vv.RequestNum, + "total_time": toS(vv.TotalTime), + "max_time": toS(vv.MaxTime), + "min_time": toS(vv.MinTime), + "avg_time": toS(time.Duration(int64(vv.TotalTime) / vv.RequestNum)), + } + resultLists = append(resultLists, result) + } + } + return resultLists +} + +// StatisticsMap hosld global statistics data map +var StatisticsMap *URLMap + +func init() { + StatisticsMap = &URLMap{ + urlmap: make(map[string]map[string]*Statistics), + } +} diff --git a/pkg/toolbox/statistics_test.go b/pkg/toolbox/statistics_test.go new file mode 100644 index 00000000..ac29476c --- /dev/null +++ b/pkg/toolbox/statistics_test.go @@ -0,0 +1,40 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "encoding/json" + "testing" + "time" +) + +func TestStatics(t *testing.T) { + StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(2000)) + StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(120000)) + StatisticsMap.AddStatistics("GET", "/api/user", "&admin.user", time.Duration(13000)) + StatisticsMap.AddStatistics("POST", "/api/admin", "&admin.user", time.Duration(14000)) + StatisticsMap.AddStatistics("POST", "/api/user/astaxie", "&admin.user", time.Duration(12000)) + StatisticsMap.AddStatistics("POST", "/api/user/xiemengjun", "&admin.user", time.Duration(13000)) + StatisticsMap.AddStatistics("DELETE", "/api/user", "&admin.user", time.Duration(1400)) + t.Log(StatisticsMap.GetMap()) + + data := StatisticsMap.GetMapData() + b, err := json.Marshal(data) + if err != nil { + t.Errorf(err.Error()) + } + + t.Log(string(b)) +} diff --git a/pkg/toolbox/task.go b/pkg/toolbox/task.go new file mode 100644 index 00000000..c902fdfc --- /dev/null +++ b/pkg/toolbox/task.go @@ -0,0 +1,640 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "log" + "math" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +// bounds provides a range of acceptable values (plus a map of name to value). +type bounds struct { + min, max uint + names map[string]uint +} + +// The bounds for each field. +var ( + AdminTaskList map[string]Tasker + taskLock sync.RWMutex + stop chan bool + changed chan bool + isstart bool + seconds = bounds{0, 59, nil} + minutes = bounds{0, 59, nil} + hours = bounds{0, 23, nil} + days = bounds{1, 31, nil} + months = bounds{1, 12, map[string]uint{ + "jan": 1, + "feb": 2, + "mar": 3, + "apr": 4, + "may": 5, + "jun": 6, + "jul": 7, + "aug": 8, + "sep": 9, + "oct": 10, + "nov": 11, + "dec": 12, + }} + weeks = bounds{0, 6, map[string]uint{ + "sun": 0, + "mon": 1, + "tue": 2, + "wed": 3, + "thu": 4, + "fri": 5, + "sat": 6, + }} +) + +const ( + // Set the top bit if a star was included in the expression. + starBit = 1 << 63 +) + +// Schedule time taks schedule +type Schedule struct { + Second uint64 + Minute uint64 + Hour uint64 + Day uint64 + Month uint64 + Week uint64 +} + +// TaskFunc task func type +type TaskFunc func() error + +// Tasker task interface +type Tasker interface { + GetSpec() string + GetStatus() string + Run() error + SetNext(time.Time) + GetNext() time.Time + SetPrev(time.Time) + GetPrev() time.Time +} + +// task error +type taskerr struct { + t time.Time + errinfo string +} + +// Task task struct +// It's not a thread-safe structure. +// Only nearest errors will be saved in ErrList +type Task struct { + Taskname string + Spec *Schedule + SpecStr string + DoFunc TaskFunc + Prev time.Time + Next time.Time + Errlist []*taskerr // like errtime:errinfo + ErrLimit int // max length for the errlist, 0 stand for no limit + errCnt int // records the error count during the execution +} + +// NewTask add new task with name, time and func +func NewTask(tname string, spec string, f TaskFunc) *Task { + + task := &Task{ + Taskname: tname, + DoFunc: f, + // Make configurable + ErrLimit: 100, + SpecStr: spec, + // we only store the pointer, so it won't use too many space + Errlist: make([]*taskerr, 100, 100), + } + task.SetCron(spec) + return task +} + +// GetSpec get spec string +func (t *Task) GetSpec() string { + return t.SpecStr +} + +// GetStatus get current task status +func (t *Task) GetStatus() string { + var str string + for _, v := range t.Errlist { + str += v.t.String() + ":" + v.errinfo + "
" + } + return str +} + +// Run run all tasks +func (t *Task) Run() error { + err := t.DoFunc() + if err != nil { + index := t.errCnt % t.ErrLimit + t.Errlist[index] = &taskerr{t: t.Next, errinfo: err.Error()} + t.errCnt++ + } + return err +} + +// SetNext set next time for this task +func (t *Task) SetNext(now time.Time) { + t.Next = t.Spec.Next(now) +} + +// GetNext get the next call time of this task +func (t *Task) GetNext() time.Time { + return t.Next +} + +// SetPrev set prev time of this task +func (t *Task) SetPrev(now time.Time) { + t.Prev = now +} + +// GetPrev get prev time of this task +func (t *Task) GetPrev() time.Time { + return t.Prev +} + +// six columns mean: +// second:0-59 +// minute:0-59 +// hour:1-23 +// day:1-31 +// month:1-12 +// week:0-6(0 means Sunday) + +// SetCron some signals: +// *: any time +// ,:  separate signal +//   -:duration +// /n : do as n times of time duration +///////////////////////////////////////////////////////// +// 0/30 * * * * * every 30s +// 0 43 21 * * * 21:43 +// 0 15 05 * * *    05:15 +// 0 0 17 * * * 17:00 +// 0 0 17 * * 1 17:00 in every Monday +// 0 0,10 17 * * 0,2,3 17:00 and 17:10 in every Sunday, Tuesday and Wednesday +// 0 0-10 17 1 * * 17:00 to 17:10 in 1 min duration each time on the first day of month +// 0 0 0 1,15 * 1 0:00 on the 1st day and 15th day of month +// 0 42 4 1 * *     4:42 on the 1st day of month +// 0 0 21 * * 1-6   21:00 from Monday to Saturday +// 0 0,10,20,30,40,50 * * * *  every 10 min duration +// 0 */10 * * * *        every 10 min duration +// 0 * 1 * * *         1:00 to 1:59 in 1 min duration each time +// 0 0 1 * * *         1:00 +// 0 0 */1 * * *        0 min of hour in 1 hour duration +// 0 0 * * * *         0 min of hour in 1 hour duration +// 0 2 8-20/3 * * *       8:02, 11:02, 14:02, 17:02, 20:02 +// 0 30 5 1,15 * *       5:30 on the 1st day and 15th day of month +func (t *Task) SetCron(spec string) { + t.Spec = t.parse(spec) +} + +func (t *Task) parse(spec string) *Schedule { + if len(spec) > 0 && spec[0] == '@' { + return t.parseSpec(spec) + } + // Split on whitespace. We require 5 or 6 fields. + // (second) (minute) (hour) (day of month) (month) (day of week, optional) + fields := strings.Fields(spec) + if len(fields) != 5 && len(fields) != 6 { + log.Panicf("Expected 5 or 6 fields, found %d: %s", len(fields), spec) + } + + // If a sixth field is not provided (DayOfWeek), then it is equivalent to star. + if len(fields) == 5 { + fields = append(fields, "*") + } + + schedule := &Schedule{ + Second: getField(fields[0], seconds), + Minute: getField(fields[1], minutes), + Hour: getField(fields[2], hours), + Day: getField(fields[3], days), + Month: getField(fields[4], months), + Week: getField(fields[5], weeks), + } + + return schedule +} + +func (t *Task) parseSpec(spec string) *Schedule { + switch spec { + case "@yearly", "@annually": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Day: 1 << days.min, + Month: 1 << months.min, + Week: all(weeks), + } + + case "@monthly": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Day: 1 << days.min, + Month: all(months), + Week: all(weeks), + } + + case "@weekly": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Day: all(days), + Month: all(months), + Week: 1 << weeks.min, + } + + case "@daily", "@midnight": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Day: all(days), + Month: all(months), + Week: all(weeks), + } + + case "@hourly": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: all(hours), + Day: all(days), + Month: all(months), + Week: all(weeks), + } + } + log.Panicf("Unrecognized descriptor: %s", spec) + return nil +} + +// Next set schedule to next time +func (s *Schedule) Next(t time.Time) time.Time { + + // Start at the earliest possible time (the upcoming second). + t = t.Add(1*time.Second - time.Duration(t.Nanosecond())*time.Nanosecond) + + // This flag indicates whether a field has been incremented. + added := false + + // If no time is found within five years, return zero. + yearLimit := t.Year() + 5 + +WRAP: + if t.Year() > yearLimit { + return time.Time{} + } + + // Find the first applicable month. + // If it's this month, then do nothing. + for 1< 0 + dowMatch = 1< 0 + ) + + if s.Day&starBit > 0 || s.Week&starBit > 0 { + return domMatch && dowMatch + } + return domMatch || dowMatch +} + +// StartTask start all tasks +func StartTask() { + taskLock.Lock() + defer taskLock.Unlock() + if isstart { + //If already started, no need to start another goroutine. + return + } + isstart = true + go run() +} + +func run() { + now := time.Now().Local() + for _, t := range AdminTaskList { + t.SetNext(now) + } + + for { + // we only use RLock here because NewMapSorter copy the reference, do not change any thing + taskLock.RLock() + sortList := NewMapSorter(AdminTaskList) + taskLock.RUnlock() + sortList.Sort() + var effective time.Time + if len(AdminTaskList) == 0 || sortList.Vals[0].GetNext().IsZero() { + // If there are no entries yet, just sleep - it still handles new entries + // and stop requests. + effective = now.AddDate(10, 0, 0) + } else { + effective = sortList.Vals[0].GetNext() + } + select { + case now = <-time.After(effective.Sub(now)): + // Run every entry whose next time was this effective time. + for _, e := range sortList.Vals { + if e.GetNext() != effective { + break + } + go e.Run() + e.SetPrev(e.GetNext()) + e.SetNext(effective) + } + continue + case <-changed: + now = time.Now().Local() + taskLock.Lock() + for _, t := range AdminTaskList { + t.SetNext(now) + } + taskLock.Unlock() + continue + case <-stop: + return + } + } +} + +// StopTask stop all tasks +func StopTask() { + taskLock.Lock() + defer taskLock.Unlock() + if isstart { + isstart = false + stop <- true + } + +} + +// AddTask add task with name +func AddTask(taskname string, t Tasker) { + taskLock.Lock() + defer taskLock.Unlock() + t.SetNext(time.Now().Local()) + AdminTaskList[taskname] = t + if isstart { + changed <- true + } +} + +// DeleteTask delete task with name +func DeleteTask(taskname string) { + taskLock.Lock() + defer taskLock.Unlock() + delete(AdminTaskList, taskname) + if isstart { + changed <- true + } +} + +// MapSorter sort map for tasker +type MapSorter struct { + Keys []string + Vals []Tasker +} + +// NewMapSorter create new tasker map +func NewMapSorter(m map[string]Tasker) *MapSorter { + ms := &MapSorter{ + Keys: make([]string, 0, len(m)), + Vals: make([]Tasker, 0, len(m)), + } + for k, v := range m { + ms.Keys = append(ms.Keys, k) + ms.Vals = append(ms.Vals, v) + } + return ms +} + +// Sort sort tasker map +func (ms *MapSorter) Sort() { + sort.Sort(ms) +} + +func (ms *MapSorter) Len() int { return len(ms.Keys) } +func (ms *MapSorter) Less(i, j int) bool { + if ms.Vals[i].GetNext().IsZero() { + return false + } + if ms.Vals[j].GetNext().IsZero() { + return true + } + return ms.Vals[i].GetNext().Before(ms.Vals[j].GetNext()) +} +func (ms *MapSorter) Swap(i, j int) { + ms.Vals[i], ms.Vals[j] = ms.Vals[j], ms.Vals[i] + ms.Keys[i], ms.Keys[j] = ms.Keys[j], ms.Keys[i] +} + +func getField(field string, r bounds) uint64 { + // list = range {"," range} + var bits uint64 + ranges := strings.FieldsFunc(field, func(r rune) bool { return r == ',' }) + for _, expr := range ranges { + bits |= getRange(expr, r) + } + return bits +} + +// getRange returns the bits indicated by the given expression: +// number | number "-" number [ "/" number ] +func getRange(expr string, r bounds) uint64 { + + var ( + start, end, step uint + rangeAndStep = strings.Split(expr, "/") + lowAndHigh = strings.Split(rangeAndStep[0], "-") + singleDigit = len(lowAndHigh) == 1 + ) + + var extrastar uint64 + if lowAndHigh[0] == "*" || lowAndHigh[0] == "?" { + start = r.min + end = r.max + extrastar = starBit + } else { + start = parseIntOrName(lowAndHigh[0], r.names) + switch len(lowAndHigh) { + case 1: + end = start + case 2: + end = parseIntOrName(lowAndHigh[1], r.names) + default: + log.Panicf("Too many hyphens: %s", expr) + } + } + + switch len(rangeAndStep) { + case 1: + step = 1 + case 2: + step = mustParseInt(rangeAndStep[1]) + + // Special handling: "N/step" means "N-max/step". + if singleDigit { + end = r.max + } + default: + log.Panicf("Too many slashes: %s", expr) + } + + if start < r.min { + log.Panicf("Beginning of range (%d) below minimum (%d): %s", start, r.min, expr) + } + if end > r.max { + log.Panicf("End of range (%d) above maximum (%d): %s", end, r.max, expr) + } + if start > end { + log.Panicf("Beginning of range (%d) beyond end of range (%d): %s", start, end, expr) + } + + return getBits(start, end, step) | extrastar +} + +// parseIntOrName returns the (possibly-named) integer contained in expr. +func parseIntOrName(expr string, names map[string]uint) uint { + if names != nil { + if namedInt, ok := names[strings.ToLower(expr)]; ok { + return namedInt + } + } + return mustParseInt(expr) +} + +// mustParseInt parses the given expression as an int or panics. +func mustParseInt(expr string) uint { + num, err := strconv.Atoi(expr) + if err != nil { + log.Panicf("Failed to parse int from %s: %s", expr, err) + } + if num < 0 { + log.Panicf("Negative number (%d) not allowed: %s", num, expr) + } + + return uint(num) +} + +// getBits sets all bits in the range [min, max], modulo the given step size. +func getBits(min, max, step uint) uint64 { + var bits uint64 + + // If step is 1, use shifts. + if step == 1 { + return ^(math.MaxUint64 << (max + 1)) & (math.MaxUint64 << min) + } + + // Else, use a simple loop. + for i := min; i <= max; i += step { + bits |= 1 << i + } + return bits +} + +// all returns all bits within the given bounds. (plus the star bit) +func all(r bounds) uint64 { + return getBits(r.min, r.max, 1) | starBit +} + +func init() { + AdminTaskList = make(map[string]Tasker) + stop = make(chan bool) + changed = make(chan bool) +} diff --git a/pkg/toolbox/task_test.go b/pkg/toolbox/task_test.go new file mode 100644 index 00000000..3a4cce2f --- /dev/null +++ b/pkg/toolbox/task_test.go @@ -0,0 +1,85 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) + err := tk.Run() + if err != nil { + t.Fatal(err) + } + AddTask("taska", tk) + StartTask() + time.Sleep(6 * time.Second) + StopTask() +} + +func TestSpec(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + tk1 := NewTask("tk1", "0 12 * * * *", func() error { fmt.Println("tk1"); return nil }) + tk2 := NewTask("tk2", "0,10,20 * * * * *", func() error { fmt.Println("tk2"); wg.Done(); return nil }) + tk3 := NewTask("tk3", "0 10 * * * *", func() error { fmt.Println("tk3"); wg.Done(); return nil }) + + AddTask("tk1", tk1) + AddTask("tk2", tk2) + AddTask("tk3", tk3) + StartTask() + defer StopTask() + + select { + case <-time.After(200 * time.Second): + t.FailNow() + case <-wait(wg): + } +} + +func TestTask_Run(t *testing.T) { + cnt := -1 + task := func() error { + cnt ++ + fmt.Printf("Hello, world! %d \n", cnt) + return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) + } + tk := NewTask("taska", "0/30 * * * * *", task) + for i := 0; i < 200 ; i ++ { + e := tk.Run() + assert.NotNil(t, e) + } + + l := tk.Errlist + assert.Equal(t, 100, len(l)) + assert.Equal(t, "Hello, world! 100", l[0].errinfo) + assert.Equal(t, "Hello, world! 101", l[1].errinfo) +} + +func wait(wg *sync.WaitGroup) chan bool { + ch := make(chan bool) + go func() { + wg.Wait() + ch <- true + }() + return ch +} diff --git a/pkg/tree.go b/pkg/tree.go new file mode 100644 index 00000000..9e53003b --- /dev/null +++ b/pkg/tree.go @@ -0,0 +1,585 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "path" + "regexp" + "strings" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/utils" +) + +var ( + allowSuffixExt = []string{".json", ".xml", ".html"} +) + +// Tree has three elements: FixRouter/wildcard/leaves +// fixRouter stores Fixed Router +// wildcard stores params +// leaves store the endpoint information +type Tree struct { + //prefix set for static router + prefix string + //search fix route first + fixrouters []*Tree + //if set, failure to match fixrouters search then search wildcard + wildcard *Tree + //if set, failure to match wildcard search + leaves []*leafInfo +} + +// NewTree return a new Tree +func NewTree() *Tree { + return &Tree{} +} + +// AddTree will add tree to the exist Tree +// prefix should has no params +func (t *Tree) AddTree(prefix string, tree *Tree) { + t.addtree(splitPath(prefix), tree, nil, "") +} + +func (t *Tree) addtree(segments []string, tree *Tree, wildcards []string, reg string) { + if len(segments) == 0 { + panic("prefix should has path") + } + seg := segments[0] + iswild, params, regexpStr := splitSegment(seg) + // if it's ? meaning can igone this, so add one more rule for it + if len(params) > 0 && params[0] == ":" { + params = params[1:] + if len(segments[1:]) > 0 { + t.addtree(segments[1:], tree, append(wildcards, params...), reg) + } else { + filterTreeWithPrefix(tree, wildcards, reg) + } + } + //Rule: /login/*/access match /login/2009/11/access + //if already has *, and when loop the access, should as a regexpStr + if !iswild && utils.InSlice(":splat", wildcards) { + iswild = true + regexpStr = seg + } + //Rule: /user/:id/* + if seg == "*" && len(wildcards) > 0 && reg == "" { + regexpStr = "(.+)" + } + if len(segments) == 1 { + if iswild { + if regexpStr != "" { + if reg == "" { + rr := "" + for _, w := range wildcards { + if w == ":splat" { + rr = rr + "(.+)/" + } else { + rr = rr + "([^/]+)/" + } + } + regexpStr = rr + regexpStr + } else { + regexpStr = "/" + regexpStr + } + } else if reg != "" { + if seg == "*.*" { + regexpStr = "([^.]+).(.+)" + } else { + for _, w := range params { + if w == "." || w == ":" { + continue + } + regexpStr = "([^/]+)/" + regexpStr + } + } + } + reg = strings.Trim(reg+"/"+regexpStr, "/") + filterTreeWithPrefix(tree, append(wildcards, params...), reg) + t.wildcard = tree + } else { + reg = strings.Trim(reg+"/"+regexpStr, "/") + filterTreeWithPrefix(tree, append(wildcards, params...), reg) + tree.prefix = seg + t.fixrouters = append(t.fixrouters, tree) + } + return + } + + if iswild { + if t.wildcard == nil { + t.wildcard = NewTree() + } + if regexpStr != "" { + if reg == "" { + rr := "" + for _, w := range wildcards { + if w == ":splat" { + rr = rr + "(.+)/" + } else { + rr = rr + "([^/]+)/" + } + } + regexpStr = rr + regexpStr + } else { + regexpStr = "/" + regexpStr + } + } else if reg != "" { + if seg == "*.*" { + regexpStr = "([^.]+).(.+)" + params = params[1:] + } else { + for range params { + regexpStr = "([^/]+)/" + regexpStr + } + } + } else { + if seg == "*.*" { + params = params[1:] + } + } + reg = strings.TrimRight(strings.TrimRight(reg, "/")+"/"+regexpStr, "/") + t.wildcard.addtree(segments[1:], tree, append(wildcards, params...), reg) + } else { + subTree := NewTree() + subTree.prefix = seg + t.fixrouters = append(t.fixrouters, subTree) + subTree.addtree(segments[1:], tree, append(wildcards, params...), reg) + } +} + +func filterTreeWithPrefix(t *Tree, wildcards []string, reg string) { + for _, v := range t.fixrouters { + filterTreeWithPrefix(v, wildcards, reg) + } + if t.wildcard != nil { + filterTreeWithPrefix(t.wildcard, wildcards, reg) + } + for _, l := range t.leaves { + if reg != "" { + if l.regexps != nil { + l.wildcards = append(wildcards, l.wildcards...) + l.regexps = regexp.MustCompile("^" + reg + "/" + strings.Trim(l.regexps.String(), "^$") + "$") + } else { + for _, v := range l.wildcards { + if v == ":splat" { + reg = reg + "/(.+)" + } else { + reg = reg + "/([^/]+)" + } + } + l.regexps = regexp.MustCompile("^" + reg + "$") + l.wildcards = append(wildcards, l.wildcards...) + } + } else { + l.wildcards = append(wildcards, l.wildcards...) + if l.regexps != nil { + for _, w := range wildcards { + if w == ":splat" { + reg = "(.+)/" + reg + } else { + reg = "([^/]+)/" + reg + } + } + l.regexps = regexp.MustCompile("^" + reg + strings.Trim(l.regexps.String(), "^$") + "$") + } + } + } +} + +// AddRouter call addseg function +func (t *Tree) AddRouter(pattern string, runObject interface{}) { + t.addseg(splitPath(pattern), runObject, nil, "") +} + +// "/" +// "admin" -> +func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, reg string) { + if len(segments) == 0 { + if reg != "" { + t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}) + } else { + t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards}) + } + } else { + seg := segments[0] + iswild, params, regexpStr := splitSegment(seg) + // if it's ? meaning can igone this, so add one more rule for it + if len(params) > 0 && params[0] == ":" { + t.addseg(segments[1:], route, wildcards, reg) + params = params[1:] + } + //Rule: /login/*/access match /login/2009/11/access + //if already has *, and when loop the access, should as a regexpStr + if !iswild && utils.InSlice(":splat", wildcards) { + iswild = true + regexpStr = seg + } + //Rule: /user/:id/* + if seg == "*" && len(wildcards) > 0 && reg == "" { + regexpStr = "(.+)" + } + if iswild { + if t.wildcard == nil { + t.wildcard = NewTree() + } + if regexpStr != "" { + if reg == "" { + rr := "" + for _, w := range wildcards { + if w == ":splat" { + rr = rr + "(.+)/" + } else { + rr = rr + "([^/]+)/" + } + } + regexpStr = rr + regexpStr + } else { + regexpStr = "/" + regexpStr + } + } else if reg != "" { + if seg == "*.*" { + regexpStr = "/([^.]+).(.+)" + params = params[1:] + } else { + for range params { + regexpStr = "/([^/]+)" + regexpStr + } + } + } else { + if seg == "*.*" { + params = params[1:] + } + } + t.wildcard.addseg(segments[1:], route, append(wildcards, params...), reg+regexpStr) + } else { + var subTree *Tree + for _, sub := range t.fixrouters { + if sub.prefix == seg { + subTree = sub + break + } + } + if subTree == nil { + subTree = NewTree() + subTree.prefix = seg + t.fixrouters = append(t.fixrouters, subTree) + } + subTree.addseg(segments[1:], route, wildcards, reg) + } + } +} + +// Match router to runObject & params +func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) { + if len(pattern) == 0 || pattern[0] != '/' { + return nil + } + w := make([]string, 0, 20) + return t.match(pattern[1:], pattern, w, ctx) +} + +func (t *Tree) match(treePattern string, pattern string, wildcardValues []string, ctx *context.Context) (runObject interface{}) { + if len(pattern) > 0 { + i := 0 + for ; i < len(pattern) && pattern[i] == '/'; i++ { + } + pattern = pattern[i:] + } + // Handle leaf nodes: + if len(pattern) == 0 { + for _, l := range t.leaves { + if ok := l.match(treePattern, wildcardValues, ctx); ok { + return l.runObject + } + } + if t.wildcard != nil { + for _, l := range t.wildcard.leaves { + if ok := l.match(treePattern, wildcardValues, ctx); ok { + return l.runObject + } + } + } + return nil + } + var seg string + i, l := 0, len(pattern) + for ; i < l && pattern[i] != '/'; i++ { + } + if i == 0 { + seg = pattern + pattern = "" + } else { + seg = pattern[:i] + pattern = pattern[i:] + } + for _, subTree := range t.fixrouters { + if subTree.prefix == seg { + if len(pattern) != 0 && pattern[0] == '/' { + treePattern = pattern[1:] + } else { + treePattern = pattern + } + runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) + if runObject != nil { + break + } + } + } + if runObject == nil && len(t.fixrouters) > 0 { + // Filter the .json .xml .html extension + for _, str := range allowSuffixExt { + if strings.HasSuffix(seg, str) { + for _, subTree := range t.fixrouters { + if subTree.prefix == seg[:len(seg)-len(str)] { + runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) + if runObject != nil { + ctx.Input.SetParam(":ext", str[1:]) + } + } + } + } + } + } + if runObject == nil && t.wildcard != nil { + runObject = t.wildcard.match(treePattern, pattern, append(wildcardValues, seg), ctx) + } + + if runObject == nil && len(t.leaves) > 0 { + wildcardValues = append(wildcardValues, seg) + start, i := 0, 0 + for ; i < len(pattern); i++ { + if pattern[i] == '/' { + if i != 0 && start < len(pattern) { + wildcardValues = append(wildcardValues, pattern[start:i]) + } + start = i + 1 + continue + } + } + if start > 0 { + wildcardValues = append(wildcardValues, pattern[start:i]) + } + for _, l := range t.leaves { + if ok := l.match(treePattern, wildcardValues, ctx); ok { + return l.runObject + } + } + } + return runObject +} + +type leafInfo struct { + // names of wildcards that lead to this leaf. eg, ["id" "name"] for the wildcard ":id" and ":name" + wildcards []string + + // if the leaf is regexp + regexps *regexp.Regexp + + runObject interface{} +} + +func (leaf *leafInfo) match(treePattern string, wildcardValues []string, ctx *context.Context) (ok bool) { + //fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) + if leaf.regexps == nil { + if len(wildcardValues) == 0 && len(leaf.wildcards) == 0 { // static path + return true + } + // match * + if len(leaf.wildcards) == 1 && leaf.wildcards[0] == ":splat" { + ctx.Input.SetParam(":splat", treePattern) + return true + } + // match *.* or :id + if len(leaf.wildcards) >= 2 && leaf.wildcards[len(leaf.wildcards)-2] == ":path" && leaf.wildcards[len(leaf.wildcards)-1] == ":ext" { + if len(leaf.wildcards) == 2 { + lastone := wildcardValues[len(wildcardValues)-1] + strs := strings.SplitN(lastone, ".", 2) + if len(strs) == 2 { + ctx.Input.SetParam(":ext", strs[1]) + } + ctx.Input.SetParam(":path", path.Join(path.Join(wildcardValues[:len(wildcardValues)-1]...), strs[0])) + return true + } else if len(wildcardValues) < 2 { + return false + } + var index int + for index = 0; index < len(leaf.wildcards)-2; index++ { + ctx.Input.SetParam(leaf.wildcards[index], wildcardValues[index]) + } + lastone := wildcardValues[len(wildcardValues)-1] + strs := strings.SplitN(lastone, ".", 2) + if len(strs) == 2 { + ctx.Input.SetParam(":ext", strs[1]) + } + if index > (len(wildcardValues) - 1) { + ctx.Input.SetParam(":path", "") + } else { + ctx.Input.SetParam(":path", path.Join(path.Join(wildcardValues[index:len(wildcardValues)-1]...), strs[0])) + } + return true + } + // match :id + if len(leaf.wildcards) != len(wildcardValues) { + return false + } + for j, v := range leaf.wildcards { + ctx.Input.SetParam(v, wildcardValues[j]) + } + return true + } + + if !leaf.regexps.MatchString(path.Join(wildcardValues...)) { + return false + } + matches := leaf.regexps.FindStringSubmatch(path.Join(wildcardValues...)) + for i, match := range matches[1:] { + if i < len(leaf.wildcards) { + ctx.Input.SetParam(leaf.wildcards[i], match) + } + } + return true +} + +// "/" -> [] +// "/admin" -> ["admin"] +// "/admin/" -> ["admin"] +// "/admin/users" -> ["admin", "users"] +func splitPath(key string) []string { + key = strings.Trim(key, "/ ") + if key == "" { + return []string{} + } + return strings.Split(key, "/") +} + +// "admin" -> false, nil, "" +// ":id" -> true, [:id], "" +// "?:id" -> true, [: :id], "" : meaning can empty +// ":id:int" -> true, [:id], ([0-9]+) +// ":name:string" -> true, [:name], ([\w]+) +// ":id([0-9]+)" -> true, [:id], ([0-9]+) +// ":id([0-9]+)_:name" -> true, [:id :name], ([0-9]+)_(.+) +// "cms_:id_:page.html" -> true, [:id_ :page], cms_(.+)(.+).html +// "cms_:id(.+)_:page.html" -> true, [:id :page], cms_(.+)_(.+).html +// "*" -> true, [:splat], "" +// "*.*" -> true,[. :path :ext], "" . meaning separator +func splitSegment(key string) (bool, []string, string) { + if strings.HasPrefix(key, "*") { + if key == "*.*" { + return true, []string{".", ":path", ":ext"}, "" + } + return true, []string{":splat"}, "" + } + if strings.ContainsAny(key, ":") { + var paramsNum int + var out []rune + var start bool + var startexp bool + var param []rune + var expt []rune + var skipnum int + params := []string{} + reg := regexp.MustCompile(`[a-zA-Z0-9_]+`) + for i, v := range key { + if skipnum > 0 { + skipnum-- + continue + } + if start { + //:id:int and :name:string + if v == ':' { + if len(key) >= i+4 { + if key[i+1:i+4] == "int" { + out = append(out, []rune("([0-9]+)")...) + params = append(params, ":"+string(param)) + start = false + startexp = false + skipnum = 3 + param = make([]rune, 0) + paramsNum++ + continue + } + } + if len(key) >= i+7 { + if key[i+1:i+7] == "string" { + out = append(out, []rune(`([\w]+)`)...) + params = append(params, ":"+string(param)) + paramsNum++ + start = false + startexp = false + skipnum = 6 + param = make([]rune, 0) + continue + } + } + } + // params only support a-zA-Z0-9 + if reg.MatchString(string(v)) { + param = append(param, v) + continue + } + if v != '(' { + out = append(out, []rune(`(.+)`)...) + params = append(params, ":"+string(param)) + param = make([]rune, 0) + paramsNum++ + start = false + startexp = false + } + } + if startexp { + if v != ')' { + expt = append(expt, v) + continue + } + } + // Escape Sequence '\' + if i > 0 && key[i-1] == '\\' { + out = append(out, v) + } else if v == ':' { + param = make([]rune, 0) + start = true + } else if v == '(' { + startexp = true + start = false + if len(param) > 0 { + params = append(params, ":"+string(param)) + param = make([]rune, 0) + } + paramsNum++ + expt = make([]rune, 0) + expt = append(expt, '(') + } else if v == ')' { + startexp = false + expt = append(expt, ')') + out = append(out, expt...) + param = make([]rune, 0) + } else if v == '?' { + params = append(params, ":") + } else { + out = append(out, v) + } + } + if len(param) > 0 { + if paramsNum > 0 { + out = append(out, []rune(`(.+)`)...) + } + params = append(params, ":"+string(param)) + } + return true, params, string(out) + } + return false, nil, "" +} diff --git a/pkg/tree_test.go b/pkg/tree_test.go new file mode 100644 index 00000000..d412a348 --- /dev/null +++ b/pkg/tree_test.go @@ -0,0 +1,306 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "strings" + "testing" + + "github.com/astaxie/beego/context" +) + +type testinfo struct { + url string + requesturl string + params map[string]string +} + +var routers []testinfo + +func init() { + routers = make([]testinfo, 0) + routers = append(routers, testinfo{"/topic/?:auth:int", "/topic", nil}) + routers = append(routers, testinfo{"/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1", map[string]string{":id": "1"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1/2", map[string]string{":id": "1", ":auth": "2"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1", map[string]string{":id": "1"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1/123", map[string]string{":id": "1", ":auth": "123"}}) + routers = append(routers, testinfo{"/:id", "/123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/hello/?:id", "/hello", map[string]string{":id": ""}}) + routers = append(routers, testinfo{"/", "/", nil}) + routers = append(routers, testinfo{"/customer/login", "/customer/login", nil}) + routers = append(routers, testinfo{"/customer/login", "/customer/login.json", map[string]string{":ext": "json"}}) + routers = append(routers, testinfo{"/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"}}) + routers = append(routers, testinfo{"/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}}) + routers = append(routers, testinfo{"/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}}) + routers = append(routers, testinfo{"/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}}) + routers = append(routers, testinfo{"/cc/:id/*", "/cc/2009/11/dd", map[string]string{":id": "2009", ":splat": "11/dd"}}) + routers = append(routers, testinfo{"/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"}}) + routers = append(routers, testinfo{"/thumbnail/:size/uploads/*", + "/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg", + map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"}}) + routers = append(routers, testinfo{"/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"}}) + routers = append(routers, testinfo{"/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) + routers = append(routers, testinfo{"/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) + routers = append(routers, testinfo{"/dl/:width:int/:height:int/*.*", + "/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg", + map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"}}) + routers = append(routers, testinfo{"/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(a)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(b)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(c)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"}}) + routers = append(routers, testinfo{"/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}}) + routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}}) +} + +func TestTreeRouters(t *testing.T) { + for _, r := range routers { + tr := NewTree() + tr.AddRouter(r.url, "astaxie") + ctx := context.NewContext() + obj := tr.Match(r.requesturl, ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal(r.url+" can't get obj, Expect ", r.requesturl) + } + if r.params != nil { + for k, v := range r.params { + if vv := ctx.Input.Param(k); vv != v { + t.Fatal("The Rule: " + r.url + "\nThe RequestURL:" + r.requesturl + "\nThe Key is " + k + ", The Value should be: " + v + ", but get: " + vv) + } else if vv == "" && v != "" { + t.Fatal(r.url + " " + r.requesturl + " get param empty:" + k) + } + } + } + } +} + +func TestStaticPath(t *testing.T) { + tr := NewTree() + tr.AddRouter("/topic/:id", "wildcard") + tr.AddRouter("/topic", "static") + ctx := context.NewContext() + obj := tr.Match("/topic", ctx) + if obj == nil || obj.(string) != "static" { + t.Fatal("/topic is a static route") + } + obj = tr.Match("/topic/1", ctx) + if obj == nil || obj.(string) != "wildcard" { + t.Fatal("/topic/1 is a wildcard route") + } +} + +func TestAddTree(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t1 := NewTree() + t1.AddTree("/v1/zl", tr) + ctx := context.NewContext() + obj := t1.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" { + t.Fatal("get :id param error") + } + ctx.Input.Reset(ctx) + obj = t1.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl//shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" { + t.Fatal("get :sd :id :page param error") + } + + t2 := NewTree() + t2.AddTree("/v1/:shopid", tr) + ctx.Input.Reset(ctx) + obj = t2.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :id :shopid param error") + } + ctx.Input.Reset(ctx) + obj = t2.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get :shopid param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :sd :id :page :shopid param error") + } +} + +func TestAddTree2(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t3 := NewTree() + t3.AddTree("/:version(v1|v2)/:prefix", tr) + ctx := context.NewContext() + obj := t3.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:version(v1|v2)/:prefix/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":prefix") != "zl" || ctx.Input.Param(":version") != "v1" { + t.Fatal("get :id :prefix :version param error") + } +} + +func TestAddTree3(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/account", "astaxie") + t3 := NewTree() + t3.AddTree("/table/:num", tr) + ctx := context.NewContext() + obj := t3.Match("/table/123/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/shop/:sd/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":num") != "123" || ctx.Input.Param(":sd") != "123" { + t.Fatal("get :num :sd param error") + } + ctx.Input.Reset(ctx) + obj = t3.Match("/table/123/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/create can't get obj ") + } +} + +func TestAddTree4(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/:account", "astaxie") + t4 := NewTree() + t4.AddTree("/:info:int/:num/:id", tr) + ctx := context.NewContext() + obj := t4.Match("/12/123/456/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/shop/:sd/:account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":info") != "12" || ctx.Input.Param(":num") != "123" || + ctx.Input.Param(":id") != "456" || ctx.Input.Param(":sd") != "123" || + ctx.Input.Param(":account") != "account" { + t.Fatal("get :info :num :id :sd :account param error") + } + ctx.Input.Reset(ctx) + obj = t4.Match("/12/123/456/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/create can't get obj ") + } +} + +// Test for issue #1595 +func TestAddTree5(t *testing.T) { + tr := NewTree() + tr.AddRouter("/v1/shop/:id", "shopdetail") + tr.AddRouter("/v1/shop/", "shophome") + ctx := context.NewContext() + obj := tr.Match("/v1/shop/", ctx) + if obj == nil || obj.(string) != "shophome" { + t.Fatal("url /v1/shop/ need match router /v1/shop/ ") + } +} + +func TestSplitPath(t *testing.T) { + a := splitPath("") + if len(a) != 0 { + t.Fatal("/ should retrun []") + } + a = splitPath("/") + if len(a) != 0 { + t.Fatal("/ should retrun []") + } + a = splitPath("/admin") + if len(a) != 1 || a[0] != "admin" { + t.Fatal("/admin should retrun [admin]") + } + a = splitPath("/admin/") + if len(a) != 1 || a[0] != "admin" { + t.Fatal("/admin/ should retrun [admin]") + } + a = splitPath("/admin/users") + if len(a) != 2 || a[0] != "admin" || a[1] != "users" { + t.Fatal("/admin should retrun [admin users]") + } + a = splitPath("/admin/:id:int") + if len(a) != 2 || a[0] != "admin" || a[1] != ":id:int" { + t.Fatal("/admin should retrun [admin :id:int]") + } +} + +func TestSplitSegment(t *testing.T) { + + items := map[string]struct { + isReg bool + params []string + regStr string + }{ + "admin": {false, nil, ""}, + "*": {true, []string{":splat"}, ""}, + "*.*": {true, []string{".", ":path", ":ext"}, ""}, + ":id": {true, []string{":id"}, ""}, + "?:id": {true, []string{":", ":id"}, ""}, + ":id:int": {true, []string{":id"}, "([0-9]+)"}, + ":name:string": {true, []string{":name"}, `([\w]+)`}, + ":id([0-9]+)": {true, []string{":id"}, `([0-9]+)`}, + ":id([0-9]+)_:name": {true, []string{":id", ":name"}, `([0-9]+)_(.+)`}, + ":id(.+)_cms.html": {true, []string{":id"}, `(.+)_cms.html`}, + "cms_:id(.+)_:page(.+).html": {true, []string{":id", ":page"}, `cms_(.+)_(.+).html`}, + `:app(a|b|c)`: {true, []string{":app"}, `(a|b|c)`}, + `:app\((a|b|c)\)`: {true, []string{":app"}, `(.+)\((a|b|c)\)`}, + } + + for pattern, v := range items { + b, w, r := splitSegment(pattern) + if b != v.isReg || r != v.regStr || strings.Join(w, ",") != strings.Join(v.params, ",") { + t.Fatalf("%s should return %t,%s,%q, got %t,%s,%q", pattern, v.isReg, v.params, v.regStr, b, w, r) + } + } +} diff --git a/pkg/unregroute_test.go b/pkg/unregroute_test.go new file mode 100644 index 00000000..08b1b77b --- /dev/null +++ b/pkg/unregroute_test.go @@ -0,0 +1,226 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// +// The unregroute_test.go contains tests for the unregister route +// functionality, that allows overriding route paths in children project +// that embed parent routers. +// + +const contentRootOriginal = "ok-original-root" +const contentLevel1Original = "ok-original-level1" +const contentLevel2Original = "ok-original-level2" + +const contentRootReplacement = "ok-replacement-root" +const contentLevel1Replacement = "ok-replacement-level1" +const contentLevel2Replacement = "ok-replacement-level2" + +// TestPreUnregController will supply content for the original routes, +// before unregistration +type TestPreUnregController struct { + Controller +} + +func (tc *TestPreUnregController) GetFixedRoot() { + tc.Ctx.Output.Body([]byte(contentRootOriginal)) +} +func (tc *TestPreUnregController) GetFixedLevel1() { + tc.Ctx.Output.Body([]byte(contentLevel1Original)) +} +func (tc *TestPreUnregController) GetFixedLevel2() { + tc.Ctx.Output.Body([]byte(contentLevel2Original)) +} + +// TestPostUnregController will supply content for the overriding routes, +// after the original ones are unregistered. +type TestPostUnregController struct { + Controller +} + +func (tc *TestPostUnregController) GetFixedRoot() { + tc.Ctx.Output.Body([]byte(contentRootReplacement)) +} +func (tc *TestPostUnregController) GetFixedLevel1() { + tc.Ctx.Output.Body([]byte(contentLevel1Replacement)) +} +func (tc *TestPostUnregController) GetFixedLevel2() { + tc.Ctx.Output.Body([]byte(contentLevel2Replacement)) +} + +// TestUnregisterFixedRouteRoot replaces just the root fixed route path. +// In this case, for a path like "/level1/level2" or "/level1", those actions +// should remain intact, and continue to serve the original content. +func TestUnregisterFixedRouteRoot(t *testing.T) { + + var method = "GET" + + handler := NewControllerRegister() + handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") + handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + + // Test original root + testHelperFnContentCheck(t, handler, "Test original root", + method, "/", contentRootOriginal) + + // Test original level 1 + testHelperFnContentCheck(t, handler, "Test original level 1", + method, "/level1", contentLevel1Original) + + // Test original level 2 + testHelperFnContentCheck(t, handler, "Test original level 2", + method, "/level1/level2", contentLevel2Original) + + // Remove only the root path + findAndRemoveSingleTree(handler.routers[method]) + + // Replace the root path TestPreUnregController action with the action from + // TestPostUnregController + handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot") + + // Test replacement root (expect change) + testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) + + // Test level 1 (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test level 1 (expect no change from the original)", method, "/level1", contentLevel1Original) + + // Test level 2 (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original) + +} + +// TestUnregisterFixedRouteLevel1 replaces just the "/level1" fixed route path. +// In this case, for a path like "/level1/level2" or "/", those actions +// should remain intact, and continue to serve the original content. +func TestUnregisterFixedRouteLevel1(t *testing.T) { + + var method = "GET" + + handler := NewControllerRegister() + handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") + handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + + // Test original root + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original root", + method, "/", contentRootOriginal) + + // Test original level 1 + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original level 1", + method, "/level1", contentLevel1Original) + + // Test original level 2 + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original level 2", + method, "/level1/level2", contentLevel2Original) + + // Remove only the level1 path + subPaths := splitPath("/level1") + if handler.routers[method].prefix == strings.Trim("/level1", "/ ") { + findAndRemoveSingleTree(handler.routers[method]) + } else { + findAndRemoveTree(subPaths, handler.routers[method], method) + } + + // Replace the "level1" path TestPreUnregController action with the action from + // TestPostUnregController + handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1") + + // Test replacement root (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) + + // Test level 1 (expect change) + testHelperFnContentCheck(t, handler, "Test level 1 (expect change)", method, "/level1", contentLevel1Replacement) + + // Test level 2 (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original) + +} + +// TestUnregisterFixedRouteLevel2 unregisters just the "/level1/level2" fixed +// route path. In this case, for a path like "/level1" or "/", those actions +// should remain intact, and continue to serve the original content. +func TestUnregisterFixedRouteLevel2(t *testing.T) { + + var method = "GET" + + handler := NewControllerRegister() + handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") + handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + + // Test original root + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original root", + method, "/", contentRootOriginal) + + // Test original level 1 + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original level 1", + method, "/level1", contentLevel1Original) + + // Test original level 2 + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original level 2", + method, "/level1/level2", contentLevel2Original) + + // Remove only the level2 path + subPaths := splitPath("/level1/level2") + if handler.routers[method].prefix == strings.Trim("/level1/level2", "/ ") { + findAndRemoveSingleTree(handler.routers[method]) + } else { + findAndRemoveTree(subPaths, handler.routers[method], method) + } + + // Replace the "/level1/level2" path TestPreUnregController action with the action from + // TestPostUnregController + handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2") + + // Test replacement root (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) + + // Test level 1 (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test level 1 (expect no change from the original)", method, "/level1", contentLevel1Original) + + // Test level 2 (expect change) + testHelperFnContentCheck(t, handler, "Test level 2 (expect change)", method, "/level1/level2", contentLevel2Replacement) + +} + +func testHelperFnContentCheck(t *testing.T, handler *ControllerRegister, + testName, method, path, expectedBodyContent string) { + + r, err := http.NewRequest(method, path, nil) + if err != nil { + t.Errorf("httpRecorderBodyTest NewRequest error: %v", err) + return + } + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + body := w.Body.String() + if body != expectedBodyContent { + t.Errorf("%s: expected [%s], got [%s];", testName, expectedBodyContent, body) + } +} diff --git a/pkg/utils/caller.go b/pkg/utils/caller.go new file mode 100644 index 00000000..73c52a62 --- /dev/null +++ b/pkg/utils/caller.go @@ -0,0 +1,25 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "reflect" + "runtime" +) + +// GetFuncName get function name +func GetFuncName(i interface{}) string { + return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() +} diff --git a/pkg/utils/caller_test.go b/pkg/utils/caller_test.go new file mode 100644 index 00000000..0675f0aa --- /dev/null +++ b/pkg/utils/caller_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "strings" + "testing" +) + +func TestGetFuncName(t *testing.T) { + name := GetFuncName(TestGetFuncName) + t.Log(name) + if !strings.HasSuffix(name, ".TestGetFuncName") { + t.Error("get func name error") + } +} diff --git a/pkg/utils/captcha/LICENSE b/pkg/utils/captcha/LICENSE new file mode 100644 index 00000000..0ad73ae0 --- /dev/null +++ b/pkg/utils/captcha/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2011-2014 Dmitry Chestnykh + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/pkg/utils/captcha/README.md b/pkg/utils/captcha/README.md new file mode 100644 index 00000000..dbc2026b --- /dev/null +++ b/pkg/utils/captcha/README.md @@ -0,0 +1,45 @@ +# Captcha + +an example for use captcha + +``` +package controllers + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/utils/captcha" +) + +var cpt *captcha.Captcha + +func init() { + // use beego cache system store the captcha data + store := cache.NewMemoryCache() + cpt = captcha.NewWithFilter("/captcha/", store) +} + +type MainController struct { + beego.Controller +} + +func (this *MainController) Get() { + this.TplName = "index.tpl" +} + +func (this *MainController) Post() { + this.TplName = "index.tpl" + + this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +} +``` + +template usage + +``` +{{.Success}} +
+ {{create_captcha}} + +
+``` diff --git a/pkg/utils/captcha/captcha.go b/pkg/utils/captcha/captcha.go new file mode 100644 index 00000000..42ac70d3 --- /dev/null +++ b/pkg/utils/captcha/captcha.go @@ -0,0 +1,270 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package captcha implements generation and verification of image CAPTCHAs. +// an example for use captcha +// +// ``` +// package controllers +// +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/cache" +// "github.com/astaxie/beego/utils/captcha" +// ) +// +// var cpt *captcha.Captcha +// +// func init() { +// // use beego cache system store the captcha data +// store := cache.NewMemoryCache() +// cpt = captcha.NewWithFilter("/captcha/", store) +// } +// +// type MainController struct { +// beego.Controller +// } +// +// func (this *MainController) Get() { +// this.TplName = "index.tpl" +// } +// +// func (this *MainController) Post() { +// this.TplName = "index.tpl" +// +// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +// } +// ``` +// +// template usage +// +// ``` +// {{.Success}} +//
+// {{create_captcha}} +// +//
+// ``` +package captcha + +import ( + "fmt" + "html/template" + "net/http" + "path" + "strings" + "time" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/utils" +) + +var ( + defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +) + +const ( + // default captcha attributes + challengeNums = 6 + expiration = 600 * time.Second + fieldIDName = "captcha_id" + fieldCaptchaName = "captcha" + cachePrefix = "captcha_" + defaultURLPrefix = "/captcha/" +) + +// Captcha struct +type Captcha struct { + // beego cache store + store cache.Cache + + // url prefix for captcha image + URLPrefix string + + // specify captcha id input field name + FieldIDName string + // specify captcha result input field name + FieldCaptchaName string + + // captcha image width and height + StdWidth int + StdHeight int + + // captcha chars nums + ChallengeNums int + + // captcha expiration seconds + Expiration time.Duration + + // cache key prefix + CachePrefix string +} + +// generate key string +func (c *Captcha) key(id string) string { + return c.CachePrefix + id +} + +// generate rand chars with default chars +func (c *Captcha) genRandChars() []byte { + return utils.RandomCreateBytes(c.ChallengeNums, defaultChars...) +} + +// Handler beego filter handler for serve captcha image +func (c *Captcha) Handler(ctx *context.Context) { + var chars []byte + + id := path.Base(ctx.Request.RequestURI) + if i := strings.Index(id, "."); i != -1 { + id = id[:i] + } + + key := c.key(id) + + if len(ctx.Input.Query("reload")) > 0 { + chars = c.genRandChars() + if err := c.store.Put(key, chars, c.Expiration); err != nil { + ctx.Output.SetStatus(500) + ctx.WriteString("captcha reload error") + logs.Error("Reload Create Captcha Error:", err) + return + } + } else { + if v, ok := c.store.Get(key).([]byte); ok { + chars = v + } else { + ctx.Output.SetStatus(404) + ctx.WriteString("captcha not found") + return + } + } + + img := NewImage(chars, c.StdWidth, c.StdHeight) + if _, err := img.WriteTo(ctx.ResponseWriter); err != nil { + logs.Error("Write Captcha Image Error:", err) + } +} + +// CreateCaptchaHTML template func for output html +func (c *Captcha) CreateCaptchaHTML() template.HTML { + value, err := c.CreateCaptcha() + if err != nil { + logs.Error("Create Captcha Error:", err) + return "" + } + + // create html + return template.HTML(fmt.Sprintf(``+ + ``+ + ``+ + ``, c.FieldIDName, value, c.URLPrefix, value, c.URLPrefix, value)) +} + +// CreateCaptcha create a new captcha id +func (c *Captcha) CreateCaptcha() (string, error) { + // generate captcha id + id := string(utils.RandomCreateBytes(15)) + + // get the captcha chars + chars := c.genRandChars() + + // save to store + if err := c.store.Put(c.key(id), chars, c.Expiration); err != nil { + return "", err + } + + return id, nil +} + +// VerifyReq verify from a request +func (c *Captcha) VerifyReq(req *http.Request) bool { + req.ParseForm() + return c.Verify(req.Form.Get(c.FieldIDName), req.Form.Get(c.FieldCaptchaName)) +} + +// Verify direct verify id and challenge string +func (c *Captcha) Verify(id string, challenge string) (success bool) { + if len(challenge) == 0 || len(id) == 0 { + return + } + + var chars []byte + + key := c.key(id) + + if v, ok := c.store.Get(key).([]byte); ok { + chars = v + } else { + return + } + + defer func() { + // finally remove it + c.store.Delete(key) + }() + + if len(chars) != len(challenge) { + return + } + // verify challenge + for i, c := range chars { + if c != challenge[i]-48 { + return + } + } + + return true +} + +// NewCaptcha create a new captcha.Captcha +func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { + cpt := &Captcha{} + cpt.store = store + cpt.FieldIDName = fieldIDName + cpt.FieldCaptchaName = fieldCaptchaName + cpt.ChallengeNums = challengeNums + cpt.Expiration = expiration + cpt.CachePrefix = cachePrefix + cpt.StdWidth = stdWidth + cpt.StdHeight = stdHeight + + if len(urlPrefix) == 0 { + urlPrefix = defaultURLPrefix + } + + if urlPrefix[len(urlPrefix)-1] != '/' { + urlPrefix += "/" + } + + cpt.URLPrefix = urlPrefix + + return cpt +} + +// NewWithFilter create a new captcha.Captcha and auto AddFilter for serve captacha image +// and add a template func for output html +func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { + cpt := NewCaptcha(urlPrefix, store) + + // create filter for serve captcha image + beego.InsertFilter(cpt.URLPrefix+"*", beego.BeforeRouter, cpt.Handler) + + // add to template func map + beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHTML) + + return cpt +} diff --git a/pkg/utils/captcha/image.go b/pkg/utils/captcha/image.go new file mode 100644 index 00000000..c3c9a83a --- /dev/null +++ b/pkg/utils/captcha/image.go @@ -0,0 +1,501 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "bytes" + "image" + "image/color" + "image/png" + "io" + "math" +) + +const ( + fontWidth = 11 + fontHeight = 18 + blackChar = 1 + + // Standard width and height of a captcha image. + stdWidth = 240 + stdHeight = 80 + // Maximum absolute skew factor of a single digit. + maxSkew = 0.7 + // Number of background circles. + circleCount = 20 +) + +var font = [][]byte{ + { // 0 + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 1 + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + }, + { // 2 + 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + }, + { // 3 + 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 4 + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, + 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + }, + { // 5 + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 6 + 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, + 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, + 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 7 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, + }, + { // 8 + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 9 + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, + 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, + }, +} + +// Image struct +type Image struct { + *image.Paletted + numWidth int + numHeight int + dotSize int +} + +var prng = &siprng{} + +// randIntn returns a pseudorandom non-negative int in range [0, n). +func randIntn(n int) int { + return prng.Intn(n) +} + +// randInt returns a pseudorandom int in range [from, to]. +func randInt(from, to int) int { + return prng.Intn(to+1-from) + from +} + +// randFloat returns a pseudorandom float64 in range [from, to]. +func randFloat(from, to float64) float64 { + return (to-from)*prng.Float64() + from +} + +func randomPalette() color.Palette { + p := make([]color.Color, circleCount+1) + // Transparent color. + p[0] = color.RGBA{0xFF, 0xFF, 0xFF, 0x00} + // Primary color. + prim := color.RGBA{ + uint8(randIntn(129)), + uint8(randIntn(129)), + uint8(randIntn(129)), + 0xFF, + } + p[1] = prim + // Circle colors. + for i := 2; i <= circleCount; i++ { + p[i] = randomBrightness(prim, 255) + } + return p +} + +// NewImage returns a new captcha image of the given width and height with the +// given digits, where each digit must be in range 0-9. +func NewImage(digits []byte, width, height int) *Image { + m := new(Image) + m.Paletted = image.NewPaletted(image.Rect(0, 0, width, height), randomPalette()) + m.calculateSizes(width, height, len(digits)) + // Randomly position captcha inside the image. + maxx := width - (m.numWidth+m.dotSize)*len(digits) - m.dotSize + maxy := height - m.numHeight - m.dotSize*2 + var border int + if width > height { + border = height / 5 + } else { + border = width / 5 + } + x := randInt(border, maxx-border) + y := randInt(border, maxy-border) + // Draw digits. + for _, n := range digits { + m.drawDigit(font[n], x, y) + x += m.numWidth + m.dotSize + } + // Draw strike-through line. + m.strikeThrough() + // Apply wave distortion. + m.distort(randFloat(5, 10), randFloat(100, 200)) + // Fill image with random circles. + m.fillWithCircles(circleCount, m.dotSize) + return m +} + +// encodedPNG encodes an image to PNG and returns +// the result as a byte slice. +func (m *Image) encodedPNG() []byte { + var buf bytes.Buffer + if err := png.Encode(&buf, m.Paletted); err != nil { + panic(err.Error()) + } + return buf.Bytes() +} + +// WriteTo writes captcha image in PNG format into the given writer. +func (m *Image) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(m.encodedPNG()) + return int64(n), err +} + +func (m *Image) calculateSizes(width, height, ncount int) { + // Goal: fit all digits inside the image. + var border int + if width > height { + border = height / 4 + } else { + border = width / 4 + } + // Convert everything to floats for calculations. + w := float64(width - border*2) + h := float64(height - border*2) + // fw takes into account 1-dot spacing between digits. + fw := float64(fontWidth + 1) + fh := float64(fontHeight) + nc := float64(ncount) + // Calculate the width of a single digit taking into account only the + // width of the image. + nw := w / nc + // Calculate the height of a digit from this width. + nh := nw * fh / fw + // Digit too high? + if nh > h { + // Fit digits based on height. + nh = h + nw = fw / fh * nh + } + // Calculate dot size. + m.dotSize = int(nh / fh) + if m.dotSize < 1 { + m.dotSize = 1 + } + // Save everything, making the actual width smaller by 1 dot to account + // for spacing between digits. + m.numWidth = int(nw) - m.dotSize + m.numHeight = int(nh) +} + +func (m *Image) drawHorizLine(fromX, toX, y int, colorIdx uint8) { + for x := fromX; x <= toX; x++ { + m.SetColorIndex(x, y, colorIdx) + } +} + +func (m *Image) drawCircle(x, y, radius int, colorIdx uint8) { + f := 1 - radius + dfx := 1 + dfy := -2 * radius + xo := 0 + yo := radius + + m.SetColorIndex(x, y+radius, colorIdx) + m.SetColorIndex(x, y-radius, colorIdx) + m.drawHorizLine(x-radius, x+radius, y, colorIdx) + + for xo < yo { + if f >= 0 { + yo-- + dfy += 2 + f += dfy + } + xo++ + dfx += 2 + f += dfx + m.drawHorizLine(x-xo, x+xo, y+yo, colorIdx) + m.drawHorizLine(x-xo, x+xo, y-yo, colorIdx) + m.drawHorizLine(x-yo, x+yo, y+xo, colorIdx) + m.drawHorizLine(x-yo, x+yo, y-xo, colorIdx) + } +} + +func (m *Image) fillWithCircles(n, maxradius int) { + maxx := m.Bounds().Max.X + maxy := m.Bounds().Max.Y + for i := 0; i < n; i++ { + colorIdx := uint8(randInt(1, circleCount-1)) + r := randInt(1, maxradius) + m.drawCircle(randInt(r, maxx-r), randInt(r, maxy-r), r, colorIdx) + } +} + +func (m *Image) strikeThrough() { + maxx := m.Bounds().Max.X + maxy := m.Bounds().Max.Y + y := randInt(maxy/3, maxy-maxy/3) + amplitude := randFloat(5, 20) + period := randFloat(80, 180) + dx := 2.0 * math.Pi / period + for x := 0; x < maxx; x++ { + xo := amplitude * math.Cos(float64(y)*dx) + yo := amplitude * math.Sin(float64(x)*dx) + for yn := 0; yn < m.dotSize; yn++ { + r := randInt(0, m.dotSize) + m.drawCircle(x+int(xo), y+int(yo)+(yn*m.dotSize), r/2, 1) + } + } +} + +func (m *Image) drawDigit(digit []byte, x, y int) { + skf := randFloat(-maxSkew, maxSkew) + xs := float64(x) + r := m.dotSize / 2 + y += randInt(-r, r) + for yo := 0; yo < fontHeight; yo++ { + for xo := 0; xo < fontWidth; xo++ { + if digit[yo*fontWidth+xo] != blackChar { + continue + } + m.drawCircle(x+xo*m.dotSize, y+yo*m.dotSize, r, 1) + } + xs += skf + x = int(xs) + } +} + +func (m *Image) distort(amplude float64, period float64) { + w := m.Bounds().Max.X + h := m.Bounds().Max.Y + + oldm := m.Paletted + newm := image.NewPaletted(image.Rect(0, 0, w, h), oldm.Palette) + + dx := 2.0 * math.Pi / period + for x := 0; x < w; x++ { + for y := 0; y < h; y++ { + xo := amplude * math.Sin(float64(y)*dx) + yo := amplude * math.Cos(float64(x)*dx) + newm.SetColorIndex(x, y, oldm.ColorIndexAt(x+int(xo), y+int(yo))) + } + } + m.Paletted = newm +} + +func randomBrightness(c color.RGBA, max uint8) color.RGBA { + minc := min3(c.R, c.G, c.B) + maxc := max3(c.R, c.G, c.B) + if maxc > max { + return c + } + n := randIntn(int(max-maxc)) - int(minc) + return color.RGBA{ + uint8(int(c.R) + n), + uint8(int(c.G) + n), + uint8(int(c.B) + n), + c.A, + } +} + +func min3(x, y, z uint8) (m uint8) { + m = x + if y < m { + m = y + } + if z < m { + m = z + } + return +} + +func max3(x, y, z uint8) (m uint8) { + m = x + if y > m { + m = y + } + if z > m { + m = z + } + return +} diff --git a/pkg/utils/captcha/image_test.go b/pkg/utils/captcha/image_test.go new file mode 100644 index 00000000..5e35b7f7 --- /dev/null +++ b/pkg/utils/captcha/image_test.go @@ -0,0 +1,52 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "testing" + + "github.com/astaxie/beego/utils" +) + +type byteCounter struct { + n int64 +} + +func (bc *byteCounter) Write(b []byte) (int, error) { + bc.n += int64(len(b)) + return len(b), nil +} + +func BenchmarkNewImage(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + for i := 0; i < b.N; i++ { + NewImage(d, stdWidth, stdHeight) + } +} + +func BenchmarkImageWriteTo(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + counter := &byteCounter{} + for i := 0; i < b.N; i++ { + img := NewImage(d, stdWidth, stdHeight) + img.WriteTo(counter) + b.SetBytes(counter.n) + counter.n = 0 + } +} diff --git a/pkg/utils/captcha/siprng.go b/pkg/utils/captcha/siprng.go new file mode 100644 index 00000000..5e256cf9 --- /dev/null +++ b/pkg/utils/captcha/siprng.go @@ -0,0 +1,277 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "crypto/rand" + "encoding/binary" + "io" + "sync" +) + +// siprng is PRNG based on SipHash-2-4. +type siprng struct { + mu sync.Mutex + k0, k1, ctr uint64 +} + +// siphash implements SipHash-2-4, accepting a uint64 as a message. +func siphash(k0, k1, m uint64) uint64 { + // Initialization. + v0 := k0 ^ 0x736f6d6570736575 + v1 := k1 ^ 0x646f72616e646f6d + v2 := k0 ^ 0x6c7967656e657261 + v3 := k1 ^ 0x7465646279746573 + t := uint64(8) << 56 + + // Compression. + v3 ^= m + + // Round 1. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 2. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + v0 ^= m + + // Compress last block. + v3 ^= t + + // Round 1. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 2. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + v0 ^= t + + // Finalization. + v2 ^= 0xff + + // Round 1. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 2. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 3. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 4. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + return v0 ^ v1 ^ v2 ^ v3 +} + +// rekey sets a new PRNG key, which is read from crypto/rand. +func (p *siprng) rekey() { + var k [16]byte + if _, err := io.ReadFull(rand.Reader, k[:]); err != nil { + panic(err.Error()) + } + p.k0 = binary.LittleEndian.Uint64(k[0:8]) + p.k1 = binary.LittleEndian.Uint64(k[8:16]) + p.ctr = 1 +} + +// Uint64 returns a new pseudorandom uint64. +// It rekeys PRNG on the first call and every 64 MB of generated data. +func (p *siprng) Uint64() uint64 { + p.mu.Lock() + if p.ctr == 0 || p.ctr > 8*1024*1024 { + p.rekey() + } + v := siphash(p.k0, p.k1, p.ctr) + p.ctr++ + p.mu.Unlock() + return v +} + +func (p *siprng) Int63() int64 { + return int64(p.Uint64() & 0x7fffffffffffffff) +} + +func (p *siprng) Uint32() uint32 { + return uint32(p.Uint64()) +} + +func (p *siprng) Int31() int32 { + return int32(p.Uint32() & 0x7fffffff) +} + +func (p *siprng) Intn(n int) int { + if n <= 0 { + panic("invalid argument to Intn") + } + if n <= 1<<31-1 { + return int(p.Int31n(int32(n))) + } + return int(p.Int63n(int64(n))) +} + +func (p *siprng) Int63n(n int64) int64 { + if n <= 0 { + panic("invalid argument to Int63n") + } + max := int64((1 << 63) - 1 - (1<<63)%uint64(n)) + v := p.Int63() + for v > max { + v = p.Int63() + } + return v % n +} + +func (p *siprng) Int31n(n int32) int32 { + if n <= 0 { + panic("invalid argument to Int31n") + } + max := int32((1 << 31) - 1 - (1<<31)%uint32(n)) + v := p.Int31() + for v > max { + v = p.Int31() + } + return v % n +} + +func (p *siprng) Float64() float64 { return float64(p.Int63()) / (1 << 63) } diff --git a/pkg/utils/captcha/siprng_test.go b/pkg/utils/captcha/siprng_test.go new file mode 100644 index 00000000..189d3d3c --- /dev/null +++ b/pkg/utils/captcha/siprng_test.go @@ -0,0 +1,33 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import "testing" + +func TestSiphash(t *testing.T) { + good := uint64(0xe849e8bb6ffe2567) + cur := siphash(0, 0, 0) + if cur != good { + t.Fatalf("siphash: expected %x, got %x", good, cur) + } +} + +func BenchmarkSiprng(b *testing.B) { + b.SetBytes(8) + p := &siprng{} + for i := 0; i < b.N; i++ { + p.Uint64() + } +} diff --git a/pkg/utils/debug.go b/pkg/utils/debug.go new file mode 100644 index 00000000..93c27b70 --- /dev/null +++ b/pkg/utils/debug.go @@ -0,0 +1,478 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "bytes" + "fmt" + "log" + "reflect" + "runtime" +) + +var ( + dunno = []byte("???") + centerDot = []byte("·") + dot = []byte(".") +) + +type pointerInfo struct { + prev *pointerInfo + n int + addr uintptr + pos int + used []int +} + +// Display print the data in console +func Display(data ...interface{}) { + display(true, data...) +} + +// GetDisplayString return data print string +func GetDisplayString(data ...interface{}) string { + return display(false, data...) +} + +func display(displayed bool, data ...interface{}) string { + var pc, file, line, ok = runtime.Caller(2) + + if !ok { + return "" + } + + var buf = new(bytes.Buffer) + + fmt.Fprintf(buf, "[Debug] at %s() [%s:%d]\n", function(pc), file, line) + + fmt.Fprintf(buf, "\n[Variables]\n") + + for i := 0; i < len(data); i += 2 { + var output = fomateinfo(len(data[i].(string))+3, data[i+1]) + fmt.Fprintf(buf, "%s = %s", data[i], output) + } + + if displayed { + log.Print(buf) + } + return buf.String() +} + +// return data dump and format bytes +func fomateinfo(headlen int, data ...interface{}) []byte { + var buf = new(bytes.Buffer) + + if len(data) > 1 { + fmt.Fprint(buf, " ") + + fmt.Fprint(buf, "[") + + fmt.Fprintln(buf) + } + + for k, v := range data { + var buf2 = new(bytes.Buffer) + var pointers *pointerInfo + var interfaces = make([]reflect.Value, 0, 10) + + printKeyValue(buf2, reflect.ValueOf(v), &pointers, &interfaces, nil, true, " ", 1) + + if k < len(data)-1 { + fmt.Fprint(buf2, ", ") + } + + fmt.Fprintln(buf2) + + buf.Write(buf2.Bytes()) + } + + if len(data) > 1 { + fmt.Fprintln(buf) + + fmt.Fprint(buf, " ") + + fmt.Fprint(buf, "]") + } + + return buf.Bytes() +} + +// check data is golang basic type +func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo, interfaces *[]reflect.Value) bool { + switch kind { + case reflect.Bool: + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + return true + case reflect.Float32, reflect.Float64: + return true + case reflect.Complex64, reflect.Complex128: + return true + case reflect.String: + return true + case reflect.Chan: + return true + case reflect.Invalid: + return true + case reflect.Interface: + for _, in := range *interfaces { + if reflect.DeepEqual(in, val) { + return true + } + } + return false + case reflect.UnsafePointer: + if val.IsNil() { + return true + } + + var elem = val.Elem() + + if isSimpleType(elem, elem.Kind(), pointers, interfaces) { + return true + } + + var addr = val.Elem().UnsafeAddr() + + for p := *pointers; p != nil; p = p.prev { + if addr == p.addr { + return true + } + } + + return false + } + + return false +} + +// dump value +func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, interfaces *[]reflect.Value, structFilter func(string, string) bool, formatOutput bool, indent string, level int) { + var t = val.Kind() + + switch t { + case reflect.Bool: + fmt.Fprint(buf, val.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fmt.Fprint(buf, val.Int()) + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + fmt.Fprint(buf, val.Uint()) + case reflect.Float32, reflect.Float64: + fmt.Fprint(buf, val.Float()) + case reflect.Complex64, reflect.Complex128: + fmt.Fprint(buf, val.Complex()) + case reflect.UnsafePointer: + fmt.Fprintf(buf, "unsafe.Pointer(0x%X)", val.Pointer()) + case reflect.Ptr: + if val.IsNil() { + fmt.Fprint(buf, "nil") + return + } + + var addr = val.Elem().UnsafeAddr() + + for p := *pointers; p != nil; p = p.prev { + if addr == p.addr { + p.used = append(p.used, buf.Len()) + fmt.Fprintf(buf, "0x%X", addr) + return + } + } + + *pointers = &pointerInfo{ + prev: *pointers, + addr: addr, + pos: buf.Len(), + used: make([]int, 0), + } + + fmt.Fprint(buf, "&") + + printKeyValue(buf, val.Elem(), pointers, interfaces, structFilter, formatOutput, indent, level) + case reflect.String: + fmt.Fprint(buf, "\"", val.String(), "\"") + case reflect.Interface: + var value = val.Elem() + + if !value.IsValid() { + fmt.Fprint(buf, "nil") + } else { + for _, in := range *interfaces { + if reflect.DeepEqual(in, val) { + fmt.Fprint(buf, "repeat") + return + } + } + + *interfaces = append(*interfaces, val) + + printKeyValue(buf, value, pointers, interfaces, structFilter, formatOutput, indent, level+1) + } + case reflect.Struct: + var t = val.Type() + + fmt.Fprint(buf, t) + fmt.Fprint(buf, "{") + + for i := 0; i < val.NumField(); i++ { + if formatOutput { + fmt.Fprintln(buf) + } else { + fmt.Fprint(buf, " ") + } + + var name = t.Field(i).Name + + if formatOutput { + for ind := 0; ind < level; ind++ { + fmt.Fprint(buf, indent) + } + } + + fmt.Fprint(buf, name) + fmt.Fprint(buf, ": ") + + if structFilter != nil && structFilter(t.String(), name) { + fmt.Fprint(buf, "ignore") + } else { + printKeyValue(buf, val.Field(i), pointers, interfaces, structFilter, formatOutput, indent, level+1) + } + + fmt.Fprint(buf, ",") + } + + if formatOutput { + fmt.Fprintln(buf) + + for ind := 0; ind < level-1; ind++ { + fmt.Fprint(buf, indent) + } + } else { + fmt.Fprint(buf, " ") + } + + fmt.Fprint(buf, "}") + case reflect.Array, reflect.Slice: + fmt.Fprint(buf, val.Type()) + fmt.Fprint(buf, "{") + + var allSimple = true + + for i := 0; i < val.Len(); i++ { + var elem = val.Index(i) + + var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces) + + if !isSimple { + allSimple = false + } + + if formatOutput && !isSimple { + fmt.Fprintln(buf) + } else { + fmt.Fprint(buf, " ") + } + + if formatOutput && !isSimple { + for ind := 0; ind < level; ind++ { + fmt.Fprint(buf, indent) + } + } + + printKeyValue(buf, elem, pointers, interfaces, structFilter, formatOutput, indent, level+1) + + if i != val.Len()-1 || !allSimple { + fmt.Fprint(buf, ",") + } + } + + if formatOutput && !allSimple { + fmt.Fprintln(buf) + + for ind := 0; ind < level-1; ind++ { + fmt.Fprint(buf, indent) + } + } else { + fmt.Fprint(buf, " ") + } + + fmt.Fprint(buf, "}") + case reflect.Map: + var t = val.Type() + var keys = val.MapKeys() + + fmt.Fprint(buf, t) + fmt.Fprint(buf, "{") + + var allSimple = true + + for i := 0; i < len(keys); i++ { + var elem = val.MapIndex(keys[i]) + + var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces) + + if !isSimple { + allSimple = false + } + + if formatOutput && !isSimple { + fmt.Fprintln(buf) + } else { + fmt.Fprint(buf, " ") + } + + if formatOutput && !isSimple { + for ind := 0; ind <= level; ind++ { + fmt.Fprint(buf, indent) + } + } + + printKeyValue(buf, keys[i], pointers, interfaces, structFilter, formatOutput, indent, level+1) + fmt.Fprint(buf, ": ") + printKeyValue(buf, elem, pointers, interfaces, structFilter, formatOutput, indent, level+1) + + if i != val.Len()-1 || !allSimple { + fmt.Fprint(buf, ",") + } + } + + if formatOutput && !allSimple { + fmt.Fprintln(buf) + + for ind := 0; ind < level-1; ind++ { + fmt.Fprint(buf, indent) + } + } else { + fmt.Fprint(buf, " ") + } + + fmt.Fprint(buf, "}") + case reflect.Chan: + fmt.Fprint(buf, val.Type()) + case reflect.Invalid: + fmt.Fprint(buf, "invalid") + default: + fmt.Fprint(buf, "unknow") + } +} + +// PrintPointerInfo dump pointer value +func PrintPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) { + var anyused = false + var pointerNum = 0 + + for p := pointers; p != nil; p = p.prev { + if len(p.used) > 0 { + anyused = true + } + pointerNum++ + p.n = pointerNum + } + + if anyused { + var pointerBufs = make([][]rune, pointerNum+1) + + for i := 0; i < len(pointerBufs); i++ { + var pointerBuf = make([]rune, buf.Len()+headlen) + + for j := 0; j < len(pointerBuf); j++ { + pointerBuf[j] = ' ' + } + + pointerBufs[i] = pointerBuf + } + + for pn := 0; pn <= pointerNum; pn++ { + for p := pointers; p != nil; p = p.prev { + if len(p.used) > 0 && p.n >= pn { + if pn == p.n { + pointerBufs[pn][p.pos+headlen] = '└' + + var maxpos = 0 + + for i, pos := range p.used { + if i < len(p.used)-1 { + pointerBufs[pn][pos+headlen] = '┴' + } else { + pointerBufs[pn][pos+headlen] = '┘' + } + + maxpos = pos + } + + for i := 0; i < maxpos-p.pos-1; i++ { + if pointerBufs[pn][i+p.pos+headlen+1] == ' ' { + pointerBufs[pn][i+p.pos+headlen+1] = '─' + } + } + } else { + pointerBufs[pn][p.pos+headlen] = '│' + + for _, pos := range p.used { + if pointerBufs[pn][pos+headlen] == ' ' { + pointerBufs[pn][pos+headlen] = '│' + } else { + pointerBufs[pn][pos+headlen] = '┼' + } + } + } + } + } + + buf.WriteString(string(pointerBufs[pn]) + "\n") + } + } +} + +// Stack get stack bytes +func Stack(skip int, indent string) []byte { + var buf = new(bytes.Buffer) + + for i := skip; ; i++ { + var pc, file, line, ok = runtime.Caller(i) + + if !ok { + break + } + + buf.WriteString(indent) + + fmt.Fprintf(buf, "at %s() [%s:%d]\n", function(pc), file, line) + } + + return buf.Bytes() +} + +// return the name of the function containing the PC if possible, +func function(pc uintptr) []byte { + fn := runtime.FuncForPC(pc) + if fn == nil { + return dunno + } + name := []byte(fn.Name()) + // The name includes the path name to the package, which is unnecessary + // since the file name is already included. Plus, it has center dots. + // That is, we see + // runtime/debug.*T·ptrmethod + // and want + // *T.ptrmethod + if period := bytes.Index(name, dot); period >= 0 { + name = name[period+1:] + } + name = bytes.Replace(name, centerDot, dot, -1) + return name +} diff --git a/pkg/utils/debug_test.go b/pkg/utils/debug_test.go new file mode 100644 index 00000000..efb8924e --- /dev/null +++ b/pkg/utils/debug_test.go @@ -0,0 +1,46 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +type mytype struct { + next *mytype + prev *mytype +} + +func TestPrint(t *testing.T) { + Display("v1", 1, "v2", 2, "v3", 3) +} + +func TestPrintPoint(t *testing.T) { + var v1 = new(mytype) + var v2 = new(mytype) + + v1.prev = nil + v1.next = v2 + + v2.prev = v1 + v2.next = nil + + Display("v1", v1, "v2", v2) +} + +func TestPrintString(t *testing.T) { + str := GetDisplayString("v1", 1, "v2", 2) + println(str) +} diff --git a/pkg/utils/file.go b/pkg/utils/file.go new file mode 100644 index 00000000..6090eb17 --- /dev/null +++ b/pkg/utils/file.go @@ -0,0 +1,101 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "bufio" + "errors" + "io" + "os" + "path/filepath" + "regexp" +) + +// SelfPath gets compiled executable file absolute path +func SelfPath() string { + path, _ := filepath.Abs(os.Args[0]) + return path +} + +// SelfDir gets compiled executable file directory +func SelfDir() string { + return filepath.Dir(SelfPath()) +} + +// FileExists reports whether the named file or directory exists. +func FileExists(name string) bool { + if _, err := os.Stat(name); err != nil { + if os.IsNotExist(err) { + return false + } + } + return true +} + +// SearchFile Search a file in paths. +// this is often used in search config file in /etc ~/ +func SearchFile(filename string, paths ...string) (fullpath string, err error) { + for _, path := range paths { + if fullpath = filepath.Join(path, filename); FileExists(fullpath) { + return + } + } + err = errors.New(fullpath + " not found in paths") + return +} + +// GrepFile like command grep -E +// for example: GrepFile(`^hello`, "hello.txt") +// \n is striped while read +func GrepFile(patten string, filename string) (lines []string, err error) { + re, err := regexp.Compile(patten) + if err != nil { + return + } + + fd, err := os.Open(filename) + if err != nil { + return + } + lines = make([]string, 0) + reader := bufio.NewReader(fd) + prefix := "" + var isLongLine bool + for { + byteLine, isPrefix, er := reader.ReadLine() + if er != nil && er != io.EOF { + return nil, er + } + if er == io.EOF { + break + } + line := string(byteLine) + if isPrefix { + prefix += line + continue + } else { + isLongLine = true + } + + line = prefix + line + if isLongLine { + prefix = "" + } + if re.MatchString(line) { + lines = append(lines, line) + } + } + return lines, nil +} diff --git a/pkg/utils/file_test.go b/pkg/utils/file_test.go new file mode 100644 index 00000000..b2644157 --- /dev/null +++ b/pkg/utils/file_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "path/filepath" + "reflect" + "testing" +) + +var noExistedFile = "/tmp/not_existed_file" + +func TestSelfPath(t *testing.T) { + path := SelfPath() + if path == "" { + t.Error("path cannot be empty") + } + t.Logf("SelfPath: %s", path) +} + +func TestSelfDir(t *testing.T) { + dir := SelfDir() + t.Logf("SelfDir: %s", dir) +} + +func TestFileExists(t *testing.T) { + if !FileExists("./file.go") { + t.Errorf("./file.go should exists, but it didn't") + } + + if FileExists(noExistedFile) { + t.Errorf("Weird, how could this file exists: %s", noExistedFile) + } +} + +func TestSearchFile(t *testing.T) { + path, err := SearchFile(filepath.Base(SelfPath()), SelfDir()) + if err != nil { + t.Error(err) + } + t.Log(path) + + _, err = SearchFile(noExistedFile, ".") + if err == nil { + t.Errorf("err shouldnt be nil, got path: %s", SelfDir()) + } +} + +func TestGrepFile(t *testing.T) { + _, err := GrepFile("", noExistedFile) + if err == nil { + t.Error("expect file-not-existed error, but got nothing") + } + + path := filepath.Join(".", "testdata", "grepe.test") + lines, err := GrepFile(`^\s*[^#]+`, path) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(lines, []string{"hello", "world"}) { + t.Errorf("expect [hello world], but receive %v", lines) + } +} diff --git a/pkg/utils/mail.go b/pkg/utils/mail.go new file mode 100644 index 00000000..80a366ca --- /dev/null +++ b/pkg/utils/mail.go @@ -0,0 +1,424 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "mime" + "mime/multipart" + "net/mail" + "net/smtp" + "net/textproto" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "sync" +) + +const ( + maxLineLength = 76 + + upperhex = "0123456789ABCDEF" +) + +// Email is the type used for email messages +type Email struct { + Auth smtp.Auth + Identity string `json:"identity"` + Username string `json:"username"` + Password string `json:"password"` + Host string `json:"host"` + Port int `json:"port"` + From string `json:"from"` + To []string + Bcc []string + Cc []string + Subject string + Text string // Plaintext message (optional) + HTML string // Html message (optional) + Headers textproto.MIMEHeader + Attachments []*Attachment + ReadReceipt []string +} + +// Attachment is a struct representing an email attachment. +// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question +type Attachment struct { + Filename string + Header textproto.MIMEHeader + Content []byte +} + +// NewEMail create new Email struct with config json. +// config json is followed from Email struct fields. +func NewEMail(config string) *Email { + e := new(Email) + e.Headers = textproto.MIMEHeader{} + err := json.Unmarshal([]byte(config), e) + if err != nil { + return nil + } + return e +} + +// Bytes Make all send information to byte +func (e *Email) Bytes() ([]byte, error) { + buff := &bytes.Buffer{} + w := multipart.NewWriter(buff) + // Set the appropriate headers (overwriting any conflicts) + // Leave out Bcc (only included in envelope headers) + e.Headers.Set("To", strings.Join(e.To, ",")) + if e.Cc != nil { + e.Headers.Set("Cc", strings.Join(e.Cc, ",")) + } + e.Headers.Set("From", e.From) + e.Headers.Set("Subject", e.Subject) + if len(e.ReadReceipt) != 0 { + e.Headers.Set("Disposition-Notification-To", strings.Join(e.ReadReceipt, ",")) + } + e.Headers.Set("MIME-Version", "1.0") + + // Write the envelope headers (including any custom headers) + if err := headerToBytes(buff, e.Headers); err != nil { + return nil, fmt.Errorf("Failed to render message headers: %s", err) + } + + e.Headers.Set("Content-Type", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary())) + fmt.Fprintf(buff, "%s:", "Content-Type") + fmt.Fprintf(buff, " %s\r\n", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary())) + + // Start the multipart/mixed part + fmt.Fprintf(buff, "--%s\r\n", w.Boundary()) + header := textproto.MIMEHeader{} + // Check to see if there is a Text or HTML field + if e.Text != "" || e.HTML != "" { + subWriter := multipart.NewWriter(buff) + // Create the multipart alternative part + header.Set("Content-Type", fmt.Sprintf("multipart/alternative;\r\n boundary=%s\r\n", subWriter.Boundary())) + // Write the header + if err := headerToBytes(buff, header); err != nil { + return nil, fmt.Errorf("Failed to render multipart message headers: %s", err) + } + // Create the body sections + if e.Text != "" { + header.Set("Content-Type", fmt.Sprintf("text/plain; charset=UTF-8")) + header.Set("Content-Transfer-Encoding", "quoted-printable") + if _, err := subWriter.CreatePart(header); err != nil { + return nil, err + } + // Write the text + if err := quotePrintEncode(buff, e.Text); err != nil { + return nil, err + } + } + if e.HTML != "" { + header.Set("Content-Type", fmt.Sprintf("text/html; charset=UTF-8")) + header.Set("Content-Transfer-Encoding", "quoted-printable") + if _, err := subWriter.CreatePart(header); err != nil { + return nil, err + } + // Write the text + if err := quotePrintEncode(buff, e.HTML); err != nil { + return nil, err + } + } + if err := subWriter.Close(); err != nil { + return nil, err + } + } + // Create attachment part, if necessary + for _, a := range e.Attachments { + ap, err := w.CreatePart(a.Header) + if err != nil { + return nil, err + } + // Write the base64Wrapped content to the part + base64Wrap(ap, a.Content) + } + if err := w.Close(); err != nil { + return nil, err + } + return buff.Bytes(), nil +} + +// AttachFile Add attach file to the send mail +func (e *Email) AttachFile(args ...string) (a *Attachment, err error) { + if len(args) < 1 || len(args) > 2 { // change && to || + err = errors.New("Must specify a file name and number of parameters can not exceed at least two") + return + } + filename := args[0] + id := "" + if len(args) > 1 { + id = args[1] + } + f, err := os.Open(filename) + if err != nil { + return + } + defer f.Close() + ct := mime.TypeByExtension(filepath.Ext(filename)) + basename := path.Base(filename) + return e.Attach(f, basename, ct, id) +} + +// Attach is used to attach content from an io.Reader to the email. +// Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type. +func (e *Email) Attach(r io.Reader, filename string, args ...string) (a *Attachment, err error) { + if len(args) < 1 || len(args) > 2 { // change && to || + err = errors.New("Must specify the file type and number of parameters can not exceed at least two") + return + } + c := args[0] //Content-Type + id := "" + if len(args) > 1 { + id = args[1] //Content-ID + } + var buffer bytes.Buffer + if _, err = io.Copy(&buffer, r); err != nil { + return + } + at := &Attachment{ + Filename: filename, + Header: textproto.MIMEHeader{}, + Content: buffer.Bytes(), + } + // Get the Content-Type to be used in the MIMEHeader + if c != "" { + at.Header.Set("Content-Type", c) + } else { + // If the Content-Type is blank, set the Content-Type to "application/octet-stream" + at.Header.Set("Content-Type", "application/octet-stream") + } + if id != "" { + at.Header.Set("Content-Disposition", fmt.Sprintf("inline;\r\n filename=\"%s\"", filename)) + at.Header.Set("Content-ID", fmt.Sprintf("<%s>", id)) + } else { + at.Header.Set("Content-Disposition", fmt.Sprintf("attachment;\r\n filename=\"%s\"", filename)) + } + at.Header.Set("Content-Transfer-Encoding", "base64") + e.Attachments = append(e.Attachments, at) + return at, nil +} + +// Send will send out the mail +func (e *Email) Send() error { + if e.Auth == nil { + e.Auth = smtp.PlainAuth(e.Identity, e.Username, e.Password, e.Host) + } + // Merge the To, Cc, and Bcc fields + to := make([]string, 0, len(e.To)+len(e.Cc)+len(e.Bcc)) + to = append(append(append(to, e.To...), e.Cc...), e.Bcc...) + // Check to make sure there is at least one recipient and one "From" address + if len(to) == 0 { + return errors.New("Must specify at least one To address") + } + + // Use the username if no From is provided + if len(e.From) == 0 { + e.From = e.Username + } + + from, err := mail.ParseAddress(e.From) + if err != nil { + return err + } + + // use mail's RFC 2047 to encode any string + e.Subject = qEncode("utf-8", e.Subject) + + raw, err := e.Bytes() + if err != nil { + return err + } + return smtp.SendMail(e.Host+":"+strconv.Itoa(e.Port), e.Auth, from.Address, to, raw) +} + +// quotePrintEncode writes the quoted-printable text to the IO Writer (according to RFC 2045) +func quotePrintEncode(w io.Writer, s string) error { + var buf [3]byte + mc := 0 + for i := 0; i < len(s); i++ { + c := s[i] + // We're assuming Unix style text formats as input (LF line break), and + // quoted-printble uses CRLF line breaks. (Literal CRs will become + // "=0D", but probably shouldn't be there to begin with!) + if c == '\n' { + io.WriteString(w, "\r\n") + mc = 0 + continue + } + + var nextOut []byte + if isPrintable(c) { + nextOut = append(buf[:0], c) + } else { + nextOut = buf[:] + qpEscape(nextOut, c) + } + + // Add a soft line break if the next (encoded) byte would push this line + // to or past the limit. + if mc+len(nextOut) >= maxLineLength { + if _, err := io.WriteString(w, "=\r\n"); err != nil { + return err + } + mc = 0 + } + + if _, err := w.Write(nextOut); err != nil { + return err + } + mc += len(nextOut) + } + // No trailing end-of-line?? Soft line break, then. TODO: is this sane? + if mc > 0 { + io.WriteString(w, "=\r\n") + } + return nil +} + +// isPrintable returns true if the rune given is "printable" according to RFC 2045, false otherwise +func isPrintable(c byte) bool { + return (c >= '!' && c <= '<') || (c >= '>' && c <= '~') || (c == ' ' || c == '\n' || c == '\t') +} + +// qpEscape is a helper function for quotePrintEncode which escapes a +// non-printable byte. Expects len(dest) == 3. +func qpEscape(dest []byte, c byte) { + const nums = "0123456789ABCDEF" + dest[0] = '=' + dest[1] = nums[(c&0xf0)>>4] + dest[2] = nums[(c & 0xf)] +} + +// headerToBytes enumerates the key and values in the header, and writes the results to the IO Writer +func headerToBytes(w io.Writer, t textproto.MIMEHeader) error { + for k, v := range t { + // Write the header key + _, err := fmt.Fprintf(w, "%s:", k) + if err != nil { + return err + } + // Write each value in the header + for _, c := range v { + _, err := fmt.Fprintf(w, " %s\r\n", c) + if err != nil { + return err + } + } + } + return nil +} + +// base64Wrap encodes the attachment content, and wraps it according to RFC 2045 standards (every 76 chars) +// The output is then written to the specified io.Writer +func base64Wrap(w io.Writer, b []byte) { + // 57 raw bytes per 76-byte base64 line. + const maxRaw = 57 + // Buffer for each line, including trailing CRLF. + var buffer [maxLineLength + len("\r\n")]byte + copy(buffer[maxLineLength:], "\r\n") + // Process raw chunks until there's no longer enough to fill a line. + for len(b) >= maxRaw { + base64.StdEncoding.Encode(buffer[:], b[:maxRaw]) + w.Write(buffer[:]) + b = b[maxRaw:] + } + // Handle the last chunk of bytes. + if len(b) > 0 { + out := buffer[:base64.StdEncoding.EncodedLen(len(b))] + base64.StdEncoding.Encode(out, b) + out = append(out, "\r\n"...) + w.Write(out) + } +} + +// Encode returns the encoded-word form of s. If s is ASCII without special +// characters, it is returned unchanged. The provided charset is the IANA +// charset name of s. It is case insensitive. +// RFC 2047 encoded-word +func qEncode(charset, s string) string { + if !needsEncoding(s) { + return s + } + return encodeWord(charset, s) +} + +func needsEncoding(s string) bool { + for _, b := range s { + if (b < ' ' || b > '~') && b != '\t' { + return true + } + } + return false +} + +// encodeWord encodes a string into an encoded-word. +func encodeWord(charset, s string) string { + buf := getBuffer() + + buf.WriteString("=?") + buf.WriteString(charset) + buf.WriteByte('?') + buf.WriteByte('q') + buf.WriteByte('?') + + enc := make([]byte, 3) + for i := 0; i < len(s); i++ { + b := s[i] + switch { + case b == ' ': + buf.WriteByte('_') + case b <= '~' && b >= '!' && b != '=' && b != '?' && b != '_': + buf.WriteByte(b) + default: + enc[0] = '=' + enc[1] = upperhex[b>>4] + enc[2] = upperhex[b&0x0f] + buf.Write(enc) + } + } + buf.WriteString("?=") + + es := buf.String() + putBuffer(buf) + return es +} + +var bufPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + +func getBuffer() *bytes.Buffer { + return bufPool.Get().(*bytes.Buffer) +} + +func putBuffer(buf *bytes.Buffer) { + if buf.Len() > 1024 { + return + } + buf.Reset() + bufPool.Put(buf) +} diff --git a/pkg/utils/mail_test.go b/pkg/utils/mail_test.go new file mode 100644 index 00000000..c38356a2 --- /dev/null +++ b/pkg/utils/mail_test.go @@ -0,0 +1,41 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestMail(t *testing.T) { + config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}` + mail := NewEMail(config) + if mail.Username != "astaxie@gmail.com" { + t.Fatal("email parse get username error") + } + if mail.Password != "astaxie" { + t.Fatal("email parse get password error") + } + if mail.Host != "smtp.gmail.com" { + t.Fatal("email parse get host error") + } + if mail.Port != 587 { + t.Fatal("email parse get port error") + } + mail.To = []string{"xiemengjun@gmail.com"} + mail.From = "astaxie@gmail.com" + mail.Subject = "hi, just from beego!" + mail.Text = "Text Body is, of course, supported!" + mail.HTML = "

Fancy Html is supported, too!

" + mail.AttachFile("/Users/astaxie/github/beego/beego.go") + mail.Send() +} diff --git a/pkg/utils/pagination/controller.go b/pkg/utils/pagination/controller.go new file mode 100644 index 00000000..2f022d0c --- /dev/null +++ b/pkg/utils/pagination/controller.go @@ -0,0 +1,26 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "github.com/astaxie/beego/context" +) + +// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). +func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) { + paginator = NewPaginator(context.Request, per, nums) + context.Input.SetData("paginator", &paginator) + return +} diff --git a/pkg/utils/pagination/doc.go b/pkg/utils/pagination/doc.go new file mode 100644 index 00000000..9abc6d78 --- /dev/null +++ b/pkg/utils/pagination/doc.go @@ -0,0 +1,58 @@ +/* +Package pagination provides utilities to setup a paginator within the +context of a http request. + +Usage + +In your beego.Controller: + + package controllers + + import "github.com/astaxie/beego/utils/pagination" + + type PostsController struct { + beego.Controller + } + + func (this *PostsController) ListAllPosts() { + // sets this.Data["paginator"] with the current offset (from the url query param) + postsPerPage := 20 + paginator := pagination.SetPaginator(this.Ctx, postsPerPage, CountPosts()) + + // fetch the next 20 posts + this.Data["posts"] = ListPostsByOffsetAndLimit(paginator.Offset(), postsPerPage) + } + + +In your view templates: + + {{if .paginator.HasPages}} + + {{end}} + +See also + +http://beego.me/docs/mvc/view/page.md + +*/ +package pagination diff --git a/pkg/utils/pagination/paginator.go b/pkg/utils/pagination/paginator.go new file mode 100644 index 00000000..c6db31e0 --- /dev/null +++ b/pkg/utils/pagination/paginator.go @@ -0,0 +1,189 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "math" + "net/http" + "net/url" + "strconv" +) + +// Paginator within the state of a http request. +type Paginator struct { + Request *http.Request + PerPageNums int + MaxPages int + + nums int64 + pageRange []int + pageNums int + page int +} + +// PageNums Returns the total number of pages. +func (p *Paginator) PageNums() int { + if p.pageNums != 0 { + return p.pageNums + } + pageNums := math.Ceil(float64(p.nums) / float64(p.PerPageNums)) + if p.MaxPages > 0 { + pageNums = math.Min(pageNums, float64(p.MaxPages)) + } + p.pageNums = int(pageNums) + return p.pageNums +} + +// Nums Returns the total number of items (e.g. from doing SQL count). +func (p *Paginator) Nums() int64 { + return p.nums +} + +// SetNums Sets the total number of items. +func (p *Paginator) SetNums(nums interface{}) { + p.nums, _ = toInt64(nums) +} + +// Page Returns the current page. +func (p *Paginator) Page() int { + if p.page != 0 { + return p.page + } + if p.Request.Form == nil { + p.Request.ParseForm() + } + p.page, _ = strconv.Atoi(p.Request.Form.Get("p")) + if p.page > p.PageNums() { + p.page = p.PageNums() + } + if p.page <= 0 { + p.page = 1 + } + return p.page +} + +// Pages Returns a list of all pages. +// +// Usage (in a view template): +// +// {{range $index, $page := .paginator.Pages}} +// +// {{$page}} +// +// {{end}} +func (p *Paginator) Pages() []int { + if p.pageRange == nil && p.nums > 0 { + var pages []int + pageNums := p.PageNums() + page := p.Page() + switch { + case page >= pageNums-4 && pageNums > 9: + start := pageNums - 9 + 1 + pages = make([]int, 9) + for i := range pages { + pages[i] = start + i + } + case page >= 5 && pageNums > 9: + start := page - 5 + 1 + pages = make([]int, int(math.Min(9, float64(page+4+1)))) + for i := range pages { + pages[i] = start + i + } + default: + pages = make([]int, int(math.Min(9, float64(pageNums)))) + for i := range pages { + pages[i] = i + 1 + } + } + p.pageRange = pages + } + return p.pageRange +} + +// PageLink Returns URL for a given page index. +func (p *Paginator) PageLink(page int) string { + link, _ := url.ParseRequestURI(p.Request.URL.String()) + values := link.Query() + if page == 1 { + values.Del("p") + } else { + values.Set("p", strconv.Itoa(page)) + } + link.RawQuery = values.Encode() + return link.String() +} + +// PageLinkPrev Returns URL to the previous page. +func (p *Paginator) PageLinkPrev() (link string) { + if p.HasPrev() { + link = p.PageLink(p.Page() - 1) + } + return +} + +// PageLinkNext Returns URL to the next page. +func (p *Paginator) PageLinkNext() (link string) { + if p.HasNext() { + link = p.PageLink(p.Page() + 1) + } + return +} + +// PageLinkFirst Returns URL to the first page. +func (p *Paginator) PageLinkFirst() (link string) { + return p.PageLink(1) +} + +// PageLinkLast Returns URL to the last page. +func (p *Paginator) PageLinkLast() (link string) { + return p.PageLink(p.PageNums()) +} + +// HasPrev Returns true if the current page has a predecessor. +func (p *Paginator) HasPrev() bool { + return p.Page() > 1 +} + +// HasNext Returns true if the current page has a successor. +func (p *Paginator) HasNext() bool { + return p.Page() < p.PageNums() +} + +// IsActive Returns true if the given page index points to the current page. +func (p *Paginator) IsActive(page int) bool { + return p.Page() == page +} + +// Offset Returns the current offset. +func (p *Paginator) Offset() int { + return (p.Page() - 1) * p.PerPageNums +} + +// HasPages Returns true if there is more than one page. +func (p *Paginator) HasPages() bool { + return p.PageNums() > 1 +} + +// NewPaginator Instantiates a paginator struct for the current http request. +func NewPaginator(req *http.Request, per int, nums interface{}) *Paginator { + p := Paginator{} + p.Request = req + if per <= 0 { + per = 10 + } + p.PerPageNums = per + p.SetNums(nums) + return &p +} diff --git a/pkg/utils/pagination/utils.go b/pkg/utils/pagination/utils.go new file mode 100644 index 00000000..686e68b0 --- /dev/null +++ b/pkg/utils/pagination/utils.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "fmt" + "reflect" +) + +// ToInt64 convert any numeric value to int64 +func toInt64(value interface{}) (d int64, err error) { + val := reflect.ValueOf(value) + switch value.(type) { + case int, int8, int16, int32, int64: + d = val.Int() + case uint, uint8, uint16, uint32, uint64: + d = int64(val.Uint()) + default: + err = fmt.Errorf("ToInt64 need numeric not `%T`", value) + } + return +} diff --git a/pkg/utils/rand.go b/pkg/utils/rand.go new file mode 100644 index 00000000..344d1cd5 --- /dev/null +++ b/pkg/utils/rand.go @@ -0,0 +1,44 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "crypto/rand" + r "math/rand" + "time" +) + +var alphaNum = []byte(`0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz`) + +// RandomCreateBytes generate random []byte by specify chars. +func RandomCreateBytes(n int, alphabets ...byte) []byte { + if len(alphabets) == 0 { + alphabets = alphaNum + } + var bytes = make([]byte, n) + var randBy bool + if num, err := rand.Read(bytes); num != n || err != nil { + r.Seed(time.Now().UnixNano()) + randBy = true + } + for i, b := range bytes { + if randBy { + bytes[i] = alphabets[r.Intn(len(alphabets))] + } else { + bytes[i] = alphabets[b%byte(len(alphabets))] + } + } + return bytes +} diff --git a/pkg/utils/rand_test.go b/pkg/utils/rand_test.go new file mode 100644 index 00000000..6c238b5e --- /dev/null +++ b/pkg/utils/rand_test.go @@ -0,0 +1,33 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestRand_01(t *testing.T) { + bs0 := RandomCreateBytes(16) + bs1 := RandomCreateBytes(16) + + t.Log(string(bs0), string(bs1)) + if string(bs0) == string(bs1) { + t.FailNow() + } + + bs0 = RandomCreateBytes(4, []byte(`a`)...) + + if string(bs0) != "aaaa" { + t.FailNow() + } +} diff --git a/pkg/utils/safemap.go b/pkg/utils/safemap.go new file mode 100644 index 00000000..1793030a --- /dev/null +++ b/pkg/utils/safemap.go @@ -0,0 +1,91 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "sync" +) + +// BeeMap is a map with lock +type BeeMap struct { + lock *sync.RWMutex + bm map[interface{}]interface{} +} + +// NewBeeMap return new safemap +func NewBeeMap() *BeeMap { + return &BeeMap{ + lock: new(sync.RWMutex), + bm: make(map[interface{}]interface{}), + } +} + +// Get from maps return the k's value +func (m *BeeMap) Get(k interface{}) interface{} { + m.lock.RLock() + defer m.lock.RUnlock() + if val, ok := m.bm[k]; ok { + return val + } + return nil +} + +// Set Maps the given key and value. Returns false +// if the key is already in the map and changes nothing. +func (m *BeeMap) Set(k interface{}, v interface{}) bool { + m.lock.Lock() + defer m.lock.Unlock() + if val, ok := m.bm[k]; !ok { + m.bm[k] = v + } else if val != v { + m.bm[k] = v + } else { + return false + } + return true +} + +// Check Returns true if k is exist in the map. +func (m *BeeMap) Check(k interface{}) bool { + m.lock.RLock() + defer m.lock.RUnlock() + _, ok := m.bm[k] + return ok +} + +// Delete the given key and value. +func (m *BeeMap) Delete(k interface{}) { + m.lock.Lock() + defer m.lock.Unlock() + delete(m.bm, k) +} + +// Items returns all items in safemap. +func (m *BeeMap) Items() map[interface{}]interface{} { + m.lock.RLock() + defer m.lock.RUnlock() + r := make(map[interface{}]interface{}) + for k, v := range m.bm { + r[k] = v + } + return r +} + +// Count returns the number of items within the map. +func (m *BeeMap) Count() int { + m.lock.RLock() + defer m.lock.RUnlock() + return len(m.bm) +} diff --git a/pkg/utils/safemap_test.go b/pkg/utils/safemap_test.go new file mode 100644 index 00000000..65085195 --- /dev/null +++ b/pkg/utils/safemap_test.go @@ -0,0 +1,89 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +var safeMap *BeeMap + +func TestNewBeeMap(t *testing.T) { + safeMap = NewBeeMap() + if safeMap == nil { + t.Fatal("expected to return non-nil BeeMap", "got", safeMap) + } +} + +func TestSet(t *testing.T) { + safeMap = NewBeeMap() + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } +} + +func TestReSet(t *testing.T) { + safeMap := NewBeeMap() + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } + // set diff value + if ok := safeMap.Set("astaxie", -1); !ok { + t.Error("expected", true, "got", false) + } + + // set same value + if ok := safeMap.Set("astaxie", -1); ok { + t.Error("expected", false, "got", true) + } +} + +func TestCheck(t *testing.T) { + if exists := safeMap.Check("astaxie"); !exists { + t.Error("expected", true, "got", false) + } +} + +func TestGet(t *testing.T) { + if val := safeMap.Get("astaxie"); val.(int) != 1 { + t.Error("expected value", 1, "got", val) + } +} + +func TestDelete(t *testing.T) { + safeMap.Delete("astaxie") + if exists := safeMap.Check("astaxie"); exists { + t.Error("expected element to be deleted") + } +} + +func TestItems(t *testing.T) { + safeMap := NewBeeMap() + safeMap.Set("astaxie", "hello") + for k, v := range safeMap.Items() { + key := k.(string) + value := v.(string) + if key != "astaxie" { + t.Error("expected the key should be astaxie") + } + if value != "hello" { + t.Error("expected the value should be hello") + } + } +} + +func TestCount(t *testing.T) { + if count := safeMap.Count(); count != 0 { + t.Error("expected count to be", 0, "got", count) + } +} diff --git a/pkg/utils/slice.go b/pkg/utils/slice.go new file mode 100644 index 00000000..8f2cef98 --- /dev/null +++ b/pkg/utils/slice.go @@ -0,0 +1,170 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "math/rand" + "time" +) + +type reducetype func(interface{}) interface{} +type filtertype func(interface{}) bool + +// InSlice checks given string in string slice or not. +func InSlice(v string, sl []string) bool { + for _, vv := range sl { + if vv == v { + return true + } + } + return false +} + +// InSliceIface checks given interface in interface slice. +func InSliceIface(v interface{}, sl []interface{}) bool { + for _, vv := range sl { + if vv == v { + return true + } + } + return false +} + +// SliceRandList generate an int slice from min to max. +func SliceRandList(min, max int) []int { + if max < min { + min, max = max, min + } + length := max - min + 1 + t0 := time.Now() + rand.Seed(int64(t0.Nanosecond())) + list := rand.Perm(length) + for index := range list { + list[index] += min + } + return list +} + +// SliceMerge merges interface slices to one slice. +func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) { + c = append(slice1, slice2...) + return +} + +// SliceReduce generates a new slice after parsing every value by reduce function +func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) { + for _, v := range slice { + dslice = append(dslice, a(v)) + } + return +} + +// SliceRand returns random one from slice. +func SliceRand(a []interface{}) (b interface{}) { + randnum := rand.Intn(len(a)) + b = a[randnum] + return +} + +// SliceSum sums all values in int64 slice. +func SliceSum(intslice []int64) (sum int64) { + for _, v := range intslice { + sum += v + } + return +} + +// SliceFilter generates a new slice after filter function. +func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) { + for _, v := range slice { + if a(v) { + ftslice = append(ftslice, v) + } + } + return +} + +// SliceDiff returns diff slice of slice1 - slice2. +func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) { + for _, v := range slice1 { + if !InSliceIface(v, slice2) { + diffslice = append(diffslice, v) + } + } + return +} + +// SliceIntersect returns slice that are present in all the slice1 and slice2. +func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) { + for _, v := range slice1 { + if InSliceIface(v, slice2) { + diffslice = append(diffslice, v) + } + } + return +} + +// SliceChunk separates one slice to some sized slice. +func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) { + if size >= len(slice) { + chunkslice = append(chunkslice, slice) + return + } + end := size + for i := 0; i <= (len(slice) - size); i += size { + chunkslice = append(chunkslice, slice[i:end]) + end += size + } + return +} + +// SliceRange generates a new slice from begin to end with step duration of int64 number. +func SliceRange(start, end, step int64) (intslice []int64) { + for i := start; i <= end; i += step { + intslice = append(intslice, i) + } + return +} + +// SlicePad prepends size number of val into slice. +func SlicePad(slice []interface{}, size int, val interface{}) []interface{} { + if size <= len(slice) { + return slice + } + for i := 0; i < (size - len(slice)); i++ { + slice = append(slice, val) + } + return slice +} + +// SliceUnique cleans repeated values in slice. +func SliceUnique(slice []interface{}) (uniqueslice []interface{}) { + for _, v := range slice { + if !InSliceIface(v, uniqueslice) { + uniqueslice = append(uniqueslice, v) + } + } + return +} + +// SliceShuffle shuffles a slice. +func SliceShuffle(slice []interface{}) []interface{} { + for i := 0; i < len(slice); i++ { + a := rand.Intn(len(slice)) + b := rand.Intn(len(slice)) + slice[a], slice[b] = slice[b], slice[a] + } + return slice +} diff --git a/pkg/utils/slice_test.go b/pkg/utils/slice_test.go new file mode 100644 index 00000000..142dec96 --- /dev/null +++ b/pkg/utils/slice_test.go @@ -0,0 +1,29 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +func TestInSlice(t *testing.T) { + sl := []string{"A", "b"} + if !InSlice("A", sl) { + t.Error("should be true") + } + if InSlice("B", sl) { + t.Error("should be false") + } +} diff --git a/pkg/utils/testdata/grepe.test b/pkg/utils/testdata/grepe.test new file mode 100644 index 00000000..6c014c40 --- /dev/null +++ b/pkg/utils/testdata/grepe.test @@ -0,0 +1,7 @@ +# empty lines + + + +hello +# comment +world diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go new file mode 100644 index 00000000..3874b803 --- /dev/null +++ b/pkg/utils/utils.go @@ -0,0 +1,89 @@ +package utils + +import ( + "os" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" +) + +// GetGOPATHs returns all paths in GOPATH variable. +func GetGOPATHs() []string { + gopath := os.Getenv("GOPATH") + if gopath == "" && compareGoVersion(runtime.Version(), "go1.8") >= 0 { + gopath = defaultGOPATH() + } + return filepath.SplitList(gopath) +} + +func compareGoVersion(a, b string) int { + reg := regexp.MustCompile("^\\d*") + + a = strings.TrimPrefix(a, "go") + b = strings.TrimPrefix(b, "go") + + versionsA := strings.Split(a, ".") + versionsB := strings.Split(b, ".") + + for i := 0; i < len(versionsA) && i < len(versionsB); i++ { + versionA := versionsA[i] + versionB := versionsB[i] + + vA, err := strconv.Atoi(versionA) + if err != nil { + str := reg.FindString(versionA) + if str != "" { + vA, _ = strconv.Atoi(str) + } else { + vA = -1 + } + } + + vB, err := strconv.Atoi(versionB) + if err != nil { + str := reg.FindString(versionB) + if str != "" { + vB, _ = strconv.Atoi(str) + } else { + vB = -1 + } + } + + if vA > vB { + // vA = 12, vB = 8 + return 1 + } else if vA < vB { + // vA = 6, vB = 8 + return -1 + } else if vA == -1 { + // vA = rc1, vB = rc3 + return strings.Compare(versionA, versionB) + } + + // vA = vB = 8 + continue + } + + if len(versionsA) > len(versionsB) { + return 1 + } else if len(versionsA) == len(versionsB) { + return 0 + } + + return -1 +} + +func defaultGOPATH() string { + env := "HOME" + if runtime.GOOS == "windows" { + env = "USERPROFILE" + } else if runtime.GOOS == "plan9" { + env = "home" + } + if home := os.Getenv(env); home != "" { + return filepath.Join(home, "go") + } + return "" +} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go new file mode 100644 index 00000000..ced6f63f --- /dev/null +++ b/pkg/utils/utils_test.go @@ -0,0 +1,36 @@ +package utils + +import ( + "testing" +) + +func TestCompareGoVersion(t *testing.T) { + targetVersion := "go1.8" + if compareGoVersion("go1.12.4", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8.7", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8", targetVersion) != 0 { + t.Error("should be 0") + } + + if compareGoVersion("go1.7.6", targetVersion) != -1 { + t.Error("should be -1") + } + + if compareGoVersion("go1.12.1rc1", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8rc1", targetVersion) != 0 { + t.Error("should be 0") + } + + if compareGoVersion("go1.7rc1", targetVersion) != -1 { + t.Error("should be -1") + } +} diff --git a/pkg/validation/README.md b/pkg/validation/README.md new file mode 100644 index 00000000..43373e47 --- /dev/null +++ b/pkg/validation/README.md @@ -0,0 +1,147 @@ +validation +============== + +validation is a form validation for a data validation and error collecting using Go. + +## Installation and tests + +Install: + + go get github.com/astaxie/beego/validation + +Test: + + go test github.com/astaxie/beego/validation + +## Example + +Direct Use: + + import ( + "github.com/astaxie/beego/validation" + "log" + ) + + type User struct { + Name string + Age int + } + + func main() { + u := User{"man", 40} + valid := validation.Validation{} + valid.Required(u.Name, "name") + valid.MaxSize(u.Name, 15, "nameMax") + valid.Range(u.Age, 0, 140, "age") + if valid.HasErrors() { + // validation does not pass + // print invalid message + for _, err := range valid.Errors { + log.Println(err.Key, err.Message) + } + } + // or use like this + if v := valid.Max(u.Age, 140, "ageMax"); !v.Ok { + log.Println(v.Error.Key, v.Error.Message) + } + } + +Struct Tag Use: + + import ( + "github.com/astaxie/beego/validation" + ) + + // validation function follow with "valid" tag + // functions divide with ";" + // parameters in parentheses "()" and divide with "," + // Match function's pattern string must in "//" + type user struct { + Id int + Name string `valid:"Required;Match(/^(test)?\\w*@;com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + + func main() { + valid := validation.Validation{} + // ignore empty field valid + // see CanSkipFuncs + // valid := validation.Validation{RequiredFirst:true} + u := user{Name: "test", Age: 40} + b, err := valid.Valid(u) + if err != nil { + // handle error + } + if !b { + // validation does not pass + // blabla... + } + } + +Use custom function: + + import ( + "github.com/astaxie/beego/validation" + ) + + type user struct { + Id int + Name string `valid:"Required;IsMe"` + Age int `valid:"Required;Range(1, 140)"` + } + + func IsMe(v *validation.Validation, obj interface{}, key string) { + name, ok:= obj.(string) + if !ok { + // wrong use case? + return + } + + if name != "me" { + // valid false + v.SetError("Name", "is not me!") + } + } + + func main() { + valid := validation.Validation{} + if err := validation.AddCustomFunc("IsMe", IsMe); err != nil { + // hadle error + } + u := user{Name: "test", Age: 40} + b, err := valid.Valid(u) + if err != nil { + // handle error + } + if !b { + // validation does not pass + // blabla... + } + } + +Struct Tag Functions: + + Required + Min(min int) + Max(max int) + Range(min, max int) + MinSize(min int) + MaxSize(max int) + Length(length int) + Alpha + Numeric + AlphaNumeric + Match(pattern string) + AlphaDash + Email + IP + Base64 + Mobile + Tel + Phone + ZipCode + + +## LICENSE + +BSD License http://creativecommons.org/licenses/BSD/ diff --git a/pkg/validation/util.go b/pkg/validation/util.go new file mode 100644 index 00000000..82206f4f --- /dev/null +++ b/pkg/validation/util.go @@ -0,0 +1,298 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "fmt" + "reflect" + "regexp" + "strconv" + "strings" +) + +const ( + // ValidTag struct tag + ValidTag = "valid" + + LabelTag = "label" + + wordsize = 32 << (^uint(0) >> 32 & 1) +) + +var ( + // key: function name + // value: the number of parameters + funcs = make(Funcs) + + // doesn't belong to validation functions + unFuncs = map[string]bool{ + "Clear": true, + "HasErrors": true, + "ErrorMap": true, + "Error": true, + "apply": true, + "Check": true, + "Valid": true, + "NoMatch": true, + } + // ErrInt64On32 show 32 bit platform not support int64 + ErrInt64On32 = fmt.Errorf("not support int64 on 32-bit platform") +) + +func init() { + v := &Validation{} + t := reflect.TypeOf(v) + for i := 0; i < t.NumMethod(); i++ { + m := t.Method(i) + if !unFuncs[m.Name] { + funcs[m.Name] = m.Func + } + } +} + +// CustomFunc is for custom validate function +type CustomFunc func(v *Validation, obj interface{}, key string) + +// AddCustomFunc Add a custom function to validation +// The name can not be: +// Clear +// HasErrors +// ErrorMap +// Error +// Check +// Valid +// NoMatch +// If the name is same with exists function, it will replace the origin valid function +func AddCustomFunc(name string, f CustomFunc) error { + if unFuncs[name] { + return fmt.Errorf("invalid function name: %s", name) + } + + funcs[name] = reflect.ValueOf(f) + return nil +} + +// ValidFunc Valid function type +type ValidFunc struct { + Name string + Params []interface{} +} + +// Funcs Validate function map +type Funcs map[string]reflect.Value + +// Call validate values with named type string +func (f Funcs) Call(name string, params ...interface{}) (result []reflect.Value, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + } + }() + if _, ok := f[name]; !ok { + err = fmt.Errorf("%s does not exist", name) + return + } + if len(params) != f[name].Type().NumIn() { + err = fmt.Errorf("The number of params is not adapted") + return + } + in := make([]reflect.Value, len(params)) + for k, param := range params { + in[k] = reflect.ValueOf(param) + } + result = f[name].Call(in) + return +} + +func isStruct(t reflect.Type) bool { + return t.Kind() == reflect.Struct +} + +func isStructPtr(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + +func getValidFuncs(f reflect.StructField) (vfs []ValidFunc, err error) { + tag := f.Tag.Get(ValidTag) + label := f.Tag.Get(LabelTag) + if len(tag) == 0 { + return + } + if vfs, tag, err = getRegFuncs(tag, f.Name); err != nil { + return + } + fs := strings.Split(tag, ";") + for _, vfunc := range fs { + var vf ValidFunc + if len(vfunc) == 0 { + continue + } + vf, err = parseFunc(vfunc, f.Name, label) + if err != nil { + return + } + vfs = append(vfs, vf) + } + return +} + +// Get Match function +// May be get NoMatch function in the future +func getRegFuncs(tag, key string) (vfs []ValidFunc, str string, err error) { + tag = strings.TrimSpace(tag) + index := strings.Index(tag, "Match(/") + if index == -1 { + str = tag + return + } + end := strings.LastIndex(tag, "/)") + if end < index { + err = fmt.Errorf("invalid Match function") + return + } + reg, err := regexp.Compile(tag[index+len("Match(/") : end]) + if err != nil { + return + } + vfs = []ValidFunc{{"Match", []interface{}{reg, key + ".Match"}}} + str = strings.TrimSpace(tag[:index]) + strings.TrimSpace(tag[end+len("/)"):]) + return +} + +func parseFunc(vfunc, key string, label string) (v ValidFunc, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + } + }() + + vfunc = strings.TrimSpace(vfunc) + start := strings.Index(vfunc, "(") + var num int + + // doesn't need parameter valid function + if start == -1 { + if num, err = numIn(vfunc); err != nil { + return + } + if num != 0 { + err = fmt.Errorf("%s require %d parameters", vfunc, num) + return + } + v = ValidFunc{vfunc, []interface{}{key + "." + vfunc + "." + label}} + return + } + + end := strings.Index(vfunc, ")") + if end == -1 { + err = fmt.Errorf("invalid valid function") + return + } + + name := strings.TrimSpace(vfunc[:start]) + if num, err = numIn(name); err != nil { + return + } + + params := strings.Split(vfunc[start+1:end], ",") + // the num of param must be equal + if num != len(params) { + err = fmt.Errorf("%s require %d parameters", name, num) + return + } + + tParams, err := trim(name, key+"."+ name + "." + label, params) + if err != nil { + return + } + v = ValidFunc{name, tParams} + return +} + +func numIn(name string) (num int, err error) { + fn, ok := funcs[name] + if !ok { + err = fmt.Errorf("doesn't exists %s valid function", name) + return + } + // sub *Validation obj and key + num = fn.Type().NumIn() - 3 + return +} + +func trim(name, key string, s []string) (ts []interface{}, err error) { + ts = make([]interface{}, len(s), len(s)+1) + fn, ok := funcs[name] + if !ok { + err = fmt.Errorf("doesn't exists %s valid function", name) + return + } + for i := 0; i < len(s); i++ { + var param interface{} + // skip *Validation and obj params + if param, err = parseParam(fn.Type().In(i+2), strings.TrimSpace(s[i])); err != nil { + return + } + ts[i] = param + } + ts = append(ts, key) + return +} + +// modify the parameters's type to adapt the function input parameters' type +func parseParam(t reflect.Type, s string) (i interface{}, err error) { + switch t.Kind() { + case reflect.Int: + i, err = strconv.Atoi(s) + case reflect.Int64: + if wordsize == 32 { + return nil, ErrInt64On32 + } + i, err = strconv.ParseInt(s, 10, 64) + case reflect.Int32: + var v int64 + v, err = strconv.ParseInt(s, 10, 32) + if err == nil { + i = int32(v) + } + case reflect.Int16: + var v int64 + v, err = strconv.ParseInt(s, 10, 16) + if err == nil { + i = int16(v) + } + case reflect.Int8: + var v int64 + v, err = strconv.ParseInt(s, 10, 8) + if err == nil { + i = int8(v) + } + case reflect.String: + i = s + case reflect.Ptr: + if t.Elem().String() != "regexp.Regexp" { + err = fmt.Errorf("not support %s", t.Elem().String()) + return + } + i, err = regexp.Compile(s) + default: + err = fmt.Errorf("not support %s", t.Kind().String()) + } + return +} + +func mergeParam(v *Validation, obj interface{}, params []interface{}) []interface{} { + return append([]interface{}{v, obj}, params...) +} diff --git a/pkg/validation/util_test.go b/pkg/validation/util_test.go new file mode 100644 index 00000000..58ca38db --- /dev/null +++ b/pkg/validation/util_test.go @@ -0,0 +1,128 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "log" + "reflect" + "testing" +) + +type user struct { + ID int + Tag string `valid:"Maxx(aa)"` + Name string `valid:"Required;"` + Age int `valid:"Required; Range(1, 140)"` + match string `valid:"Required; Match(/^(test)?\\w*@(/test/);com$/);Max(2)"` +} + +func TestGetValidFuncs(t *testing.T) { + u := user{Name: "test", Age: 1} + tf := reflect.TypeOf(u) + var vfs []ValidFunc + var err error + + f, _ := tf.FieldByName("ID") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + if len(vfs) != 0 { + t.Fatal("should get none ValidFunc") + } + + f, _ = tf.FieldByName("Tag") + if _, err = getValidFuncs(f); err.Error() != "doesn't exists Maxx valid function" { + t.Fatal(err) + } + + f, _ = tf.FieldByName("Name") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + if len(vfs) != 1 { + t.Fatal("should get 1 ValidFunc") + } + if vfs[0].Name != "Required" && len(vfs[0].Params) != 0 { + t.Error("Required funcs should be got") + } + + f, _ = tf.FieldByName("Age") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + if len(vfs) != 2 { + t.Fatal("should get 2 ValidFunc") + } + if vfs[0].Name != "Required" && len(vfs[0].Params) != 0 { + t.Error("Required funcs should be got") + } + if vfs[1].Name != "Range" && len(vfs[1].Params) != 2 { + t.Error("Range funcs should be got") + } + + f, _ = tf.FieldByName("match") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + if len(vfs) != 3 { + t.Fatal("should get 3 ValidFunc but now is", len(vfs)) + } +} + +type User struct { + Name string `valid:"Required;MaxSize(5)" ` + Sex string `valid:"Required;" label:"sex_label"` + Age int `valid:"Required;Range(1, 140);" label:"age_label"` +} + +func TestValidation(t *testing.T) { + u := User{"man1238888456", "", 1140} + valid := Validation{} + b, err := valid.Valid(&u) + if err != nil { + // handle error + } + if !b { + // validation does not pass + // blabla... + for _, err := range valid.Errors { + log.Println(err.Key, err.Message) + } + if len(valid.Errors) != 3 { + t.Error("must be has 3 error") + } + } else { + t.Error("must be has 3 error") + } +} + +func TestCall(t *testing.T) { + u := user{Name: "test", Age: 180} + tf := reflect.TypeOf(u) + var vfs []ValidFunc + var err error + f, _ := tf.FieldByName("Age") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + valid := &Validation{} + vfs[1].Params = append([]interface{}{valid, u.Age}, vfs[1].Params...) + if _, err = funcs.Call(vfs[1].Name, vfs[1].Params...); err != nil { + t.Fatal(err) + } + if len(valid.Errors) != 1 { + t.Error("age out of range should be has an error") + } +} diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go new file mode 100644 index 00000000..190e0f0e --- /dev/null +++ b/pkg/validation/validation.go @@ -0,0 +1,456 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package validation for validations +// +// import ( +// "github.com/astaxie/beego/validation" +// "log" +// ) +// +// type User struct { +// Name string +// Age int +// } +// +// func main() { +// u := User{"man", 40} +// valid := validation.Validation{} +// valid.Required(u.Name, "name") +// valid.MaxSize(u.Name, 15, "nameMax") +// valid.Range(u.Age, 0, 140, "age") +// if valid.HasErrors() { +// // validation does not pass +// // print invalid message +// for _, err := range valid.Errors { +// log.Println(err.Key, err.Message) +// } +// } +// // or use like this +// if v := valid.Max(u.Age, 140, "ageMax"); !v.Ok { +// log.Println(v.Error.Key, v.Error.Message) +// } +// } +// +// more info: http://beego.me/docs/mvc/controller/validation.md +package validation + +import ( + "fmt" + "reflect" + "regexp" + "strings" +) + +// ValidFormer valid interface +type ValidFormer interface { + Valid(*Validation) +} + +// Error show the error +type Error struct { + Message, Key, Name, Field, Tmpl string + Value interface{} + LimitValue interface{} +} + +// String Returns the Message. +func (e *Error) String() string { + if e == nil { + return "" + } + return e.Message +} + +// Implement Error interface. +// Return e.String() +func (e *Error) Error() string { return e.String() } + +// Result is returned from every validation method. +// It provides an indication of success, and a pointer to the Error (if any). +type Result struct { + Error *Error + Ok bool +} + +// Key Get Result by given key string. +func (r *Result) Key(key string) *Result { + if r.Error != nil { + r.Error.Key = key + } + return r +} + +// Message Set Result message by string or format string with args +func (r *Result) Message(message string, args ...interface{}) *Result { + if r.Error != nil { + if len(args) == 0 { + r.Error.Message = message + } else { + r.Error.Message = fmt.Sprintf(message, args...) + } + } + return r +} + +// A Validation context manages data validation and error messages. +type Validation struct { + // if this field set true, in struct tag valid + // if the struct field vale is empty + // it will skip those valid functions, see CanSkipFuncs + RequiredFirst bool + + Errors []*Error + ErrorsMap map[string][]*Error +} + +// Clear Clean all ValidationError. +func (v *Validation) Clear() { + v.Errors = []*Error{} + v.ErrorsMap = nil +} + +// HasErrors Has ValidationError nor not. +func (v *Validation) HasErrors() bool { + return len(v.Errors) > 0 +} + +// ErrorMap Return the errors mapped by key. +// If there are multiple validation errors associated with a single key, the +// first one "wins". (Typically the first validation will be the more basic). +func (v *Validation) ErrorMap() map[string][]*Error { + return v.ErrorsMap +} + +// Error Add an error to the validation context. +func (v *Validation) Error(message string, args ...interface{}) *Result { + result := (&Result{ + Ok: false, + Error: &Error{}, + }).Message(message, args...) + v.Errors = append(v.Errors, result.Error) + return result +} + +// Required Test that the argument is non-nil and non-empty (if string or list) +func (v *Validation) Required(obj interface{}, key string) *Result { + return v.apply(Required{key}, obj) +} + +// Min Test that the obj is greater than min if obj's type is int +func (v *Validation) Min(obj interface{}, min int, key string) *Result { + return v.apply(Min{min, key}, obj) +} + +// Max Test that the obj is less than max if obj's type is int +func (v *Validation) Max(obj interface{}, max int, key string) *Result { + return v.apply(Max{max, key}, obj) +} + +// Range Test that the obj is between mni and max if obj's type is int +func (v *Validation) Range(obj interface{}, min, max int, key string) *Result { + return v.apply(Range{Min{Min: min}, Max{Max: max}, key}, obj) +} + +// MinSize Test that the obj is longer than min size if type is string or slice +func (v *Validation) MinSize(obj interface{}, min int, key string) *Result { + return v.apply(MinSize{min, key}, obj) +} + +// MaxSize Test that the obj is shorter than max size if type is string or slice +func (v *Validation) MaxSize(obj interface{}, max int, key string) *Result { + return v.apply(MaxSize{max, key}, obj) +} + +// Length Test that the obj is same length to n if type is string or slice +func (v *Validation) Length(obj interface{}, n int, key string) *Result { + return v.apply(Length{n, key}, obj) +} + +// Alpha Test that the obj is [a-zA-Z] if type is string +func (v *Validation) Alpha(obj interface{}, key string) *Result { + return v.apply(Alpha{key}, obj) +} + +// Numeric Test that the obj is [0-9] if type is string +func (v *Validation) Numeric(obj interface{}, key string) *Result { + return v.apply(Numeric{key}, obj) +} + +// AlphaNumeric Test that the obj is [0-9a-zA-Z] if type is string +func (v *Validation) AlphaNumeric(obj interface{}, key string) *Result { + return v.apply(AlphaNumeric{key}, obj) +} + +// Match Test that the obj matches regexp if type is string +func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *Result { + return v.apply(Match{regex, key}, obj) +} + +// NoMatch Test that the obj doesn't match regexp if type is string +func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *Result { + return v.apply(NoMatch{Match{Regexp: regex}, key}, obj) +} + +// AlphaDash Test that the obj is [0-9a-zA-Z_-] if type is string +func (v *Validation) AlphaDash(obj interface{}, key string) *Result { + return v.apply(AlphaDash{NoMatch{Match: Match{Regexp: alphaDashPattern}}, key}, obj) +} + +// Email Test that the obj is email address if type is string +func (v *Validation) Email(obj interface{}, key string) *Result { + return v.apply(Email{Match{Regexp: emailPattern}, key}, obj) +} + +// IP Test that the obj is IP address if type is string +func (v *Validation) IP(obj interface{}, key string) *Result { + return v.apply(IP{Match{Regexp: ipPattern}, key}, obj) +} + +// Base64 Test that the obj is base64 encoded if type is string +func (v *Validation) Base64(obj interface{}, key string) *Result { + return v.apply(Base64{Match{Regexp: base64Pattern}, key}, obj) +} + +// Mobile Test that the obj is chinese mobile number if type is string +func (v *Validation) Mobile(obj interface{}, key string) *Result { + return v.apply(Mobile{Match{Regexp: mobilePattern}, key}, obj) +} + +// Tel Test that the obj is chinese telephone number if type is string +func (v *Validation) Tel(obj interface{}, key string) *Result { + return v.apply(Tel{Match{Regexp: telPattern}, key}, obj) +} + +// Phone Test that the obj is chinese mobile or telephone number if type is string +func (v *Validation) Phone(obj interface{}, key string) *Result { + return v.apply(Phone{Mobile{Match: Match{Regexp: mobilePattern}}, + Tel{Match: Match{Regexp: telPattern}}, key}, obj) +} + +// ZipCode Test that the obj is chinese zip code if type is string +func (v *Validation) ZipCode(obj interface{}, key string) *Result { + return v.apply(ZipCode{Match{Regexp: zipCodePattern}, key}, obj) +} + +func (v *Validation) apply(chk Validator, obj interface{}) *Result { + if nil == obj { + if chk.IsSatisfied(obj) { + return &Result{Ok: true} + } + } else if reflect.TypeOf(obj).Kind() == reflect.Ptr { + if reflect.ValueOf(obj).IsNil() { + if chk.IsSatisfied(nil) { + return &Result{Ok: true} + } + } else { + if chk.IsSatisfied(reflect.ValueOf(obj).Elem().Interface()) { + return &Result{Ok: true} + } + } + } else if chk.IsSatisfied(obj) { + return &Result{Ok: true} + } + + // Add the error to the validation context. + key := chk.GetKey() + Name := key + Field := "" + Label := "" + parts := strings.Split(key, ".") + if len(parts) == 3 { + Field = parts[0] + Name = parts[1] + Label = parts[2] + if len(Label) == 0 { + Label = Field + } + } + + err := &Error{ + Message: Label + " " + chk.DefaultMessage(), + Key: key, + Name: Name, + Field: Field, + Value: obj, + Tmpl: MessageTmpls[Name], + LimitValue: chk.GetLimitValue(), + } + v.setError(err) + + // Also return it in the result. + return &Result{ + Ok: false, + Error: err, + } +} + +// key must like aa.bb.cc or aa.bb. +// AddError adds independent error message for the provided key +func (v *Validation) AddError(key, message string) { + Name := key + Field := "" + + Label := "" + parts := strings.Split(key, ".") + if len(parts) == 3 { + Field = parts[0] + Name = parts[1] + Label = parts[2] + if len(Label) == 0 { + Label = Field + } + } + + err := &Error{ + Message: Label + " " + message, + Key: key, + Name: Name, + Field: Field, + } + v.setError(err) +} + +func (v *Validation) setError(err *Error) { + v.Errors = append(v.Errors, err) + if v.ErrorsMap == nil { + v.ErrorsMap = make(map[string][]*Error) + } + if _, ok := v.ErrorsMap[err.Field]; !ok { + v.ErrorsMap[err.Field] = []*Error{} + } + v.ErrorsMap[err.Field] = append(v.ErrorsMap[err.Field], err) +} + +// SetError Set error message for one field in ValidationError +func (v *Validation) SetError(fieldName string, errMsg string) *Error { + err := &Error{Key: fieldName, Field: fieldName, Tmpl: errMsg, Message: errMsg} + v.setError(err) + return err +} + +// Check Apply a group of validators to a field, in order, and return the +// ValidationResult from the first one that fails, or the last one that +// succeeds. +func (v *Validation) Check(obj interface{}, checks ...Validator) *Result { + var result *Result + for _, check := range checks { + result = v.apply(check, obj) + if !result.Ok { + return result + } + } + return result +} + +// Valid Validate a struct. +// the obj parameter must be a struct or a struct pointer +func (v *Validation) Valid(obj interface{}) (b bool, err error) { + objT := reflect.TypeOf(obj) + objV := reflect.ValueOf(obj) + switch { + case isStruct(objT): + case isStructPtr(objT): + objT = objT.Elem() + objV = objV.Elem() + default: + err = fmt.Errorf("%v must be a struct or a struct pointer", obj) + return + } + + for i := 0; i < objT.NumField(); i++ { + var vfs []ValidFunc + if vfs, err = getValidFuncs(objT.Field(i)); err != nil { + return + } + + var hasRequired bool + for _, vf := range vfs { + if vf.Name == "Required" { + hasRequired = true + } + + currentField := objV.Field(i).Interface() + if objV.Field(i).Kind() == reflect.Ptr { + if objV.Field(i).IsNil() { + currentField = "" + } else { + currentField = objV.Field(i).Elem().Interface() + } + } + + chk := Required{""}.IsSatisfied(currentField) + if !hasRequired && v.RequiredFirst && !chk { + if _, ok := CanSkipFuncs[vf.Name]; ok { + continue + } + } + + if _, err = funcs.Call(vf.Name, + mergeParam(v, objV.Field(i).Interface(), vf.Params)...); err != nil { + return + } + } + } + + if !v.HasErrors() { + if form, ok := obj.(ValidFormer); ok { + form.Valid(v) + } + } + + return !v.HasErrors(), nil +} + +// RecursiveValid Recursively validate a struct. +// Step1: Validate by v.Valid +// Step2: If pass on step1, then reflect obj's fields +// Step3: Do the Recursively validation to all struct or struct pointer fields +func (v *Validation) RecursiveValid(objc interface{}) (bool, error) { + //Step 1: validate obj itself firstly + // fails if objc is not struct + pass, err := v.Valid(objc) + if err != nil || !pass { + return pass, err // Stop recursive validation + } + // Step 2: Validate struct's struct fields + objT := reflect.TypeOf(objc) + objV := reflect.ValueOf(objc) + + if isStructPtr(objT) { + objT = objT.Elem() + objV = objV.Elem() + } + + for i := 0; i < objT.NumField(); i++ { + + t := objT.Field(i).Type + + // Recursive applies to struct or pointer to structs fields + if isStruct(t) || isStructPtr(t) { + // Step 3: do the recursive validation + // Only valid the Public field recursively + if objV.Field(i).CanInterface() { + pass, err = v.RecursiveValid(objV.Field(i).Interface()) + } + } + } + return pass, err +} + +func (v *Validation) CanSkipAlso(skipFunc string) { + if _, ok := CanSkipFuncs[skipFunc]; !ok { + CanSkipFuncs[skipFunc] = struct{}{} + } +} diff --git a/pkg/validation/validation_test.go b/pkg/validation/validation_test.go new file mode 100644 index 00000000..b4b5b1b6 --- /dev/null +++ b/pkg/validation/validation_test.go @@ -0,0 +1,609 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "regexp" + "testing" + "time" +) + +func TestRequired(t *testing.T) { + valid := Validation{} + + if valid.Required(nil, "nil").Ok { + t.Error("nil object should be false") + } + if !valid.Required(true, "bool").Ok { + t.Error("Bool value should always return true") + } + if !valid.Required(false, "bool").Ok { + t.Error("Bool value should always return true") + } + if valid.Required("", "string").Ok { + t.Error("\"'\" string should be false") + } + if valid.Required(" ", "string").Ok { + t.Error("\" \" string should be false") // For #2361 + } + if valid.Required("\n", "string").Ok { + t.Error("new line string should be false") // For #2361 + } + if !valid.Required("astaxie", "string").Ok { + t.Error("string should be true") + } + if valid.Required(0, "zero").Ok { + t.Error("Integer should not be equal 0") + } + if !valid.Required(1, "int").Ok { + t.Error("Integer except 0 should be true") + } + if !valid.Required(time.Now(), "time").Ok { + t.Error("time should be true") + } + if valid.Required([]string{}, "emptySlice").Ok { + t.Error("empty slice should be false") + } + if !valid.Required([]interface{}{"ok"}, "slice").Ok { + t.Error("slice should be true") + } +} + +func TestMin(t *testing.T) { + valid := Validation{} + + if valid.Min(-1, 0, "min0").Ok { + t.Error("-1 is less than the minimum value of 0 should be false") + } + if !valid.Min(1, 0, "min0").Ok { + t.Error("1 is greater or equal than the minimum value of 0 should be true") + } +} + +func TestMax(t *testing.T) { + valid := Validation{} + + if valid.Max(1, 0, "max0").Ok { + t.Error("1 is greater than the minimum value of 0 should be false") + } + if !valid.Max(-1, 0, "max0").Ok { + t.Error("-1 is less or equal than the maximum value of 0 should be true") + } +} + +func TestRange(t *testing.T) { + valid := Validation{} + + if valid.Range(-1, 0, 1, "range0_1").Ok { + t.Error("-1 is between 0 and 1 should be false") + } + if !valid.Range(1, 0, 1, "range0_1").Ok { + t.Error("1 is between 0 and 1 should be true") + } +} + +func TestMinSize(t *testing.T) { + valid := Validation{} + + if valid.MinSize("", 1, "minSize1").Ok { + t.Error("the length of \"\" is less than the minimum value of 1 should be false") + } + if !valid.MinSize("ok", 1, "minSize1").Ok { + t.Error("the length of \"ok\" is greater or equal than the minimum value of 1 should be true") + } + if valid.MinSize([]string{}, 1, "minSize1").Ok { + t.Error("the length of empty slice is less than the minimum value of 1 should be false") + } + if !valid.MinSize([]interface{}{"ok"}, 1, "minSize1").Ok { + t.Error("the length of [\"ok\"] is greater or equal than the minimum value of 1 should be true") + } +} + +func TestMaxSize(t *testing.T) { + valid := Validation{} + + if valid.MaxSize("ok", 1, "maxSize1").Ok { + t.Error("the length of \"ok\" is greater than the maximum value of 1 should be false") + } + if !valid.MaxSize("", 1, "maxSize1").Ok { + t.Error("the length of \"\" is less or equal than the maximum value of 1 should be true") + } + if valid.MaxSize([]interface{}{"ok", false}, 1, "maxSize1").Ok { + t.Error("the length of [\"ok\", false] is greater than the maximum value of 1 should be false") + } + if !valid.MaxSize([]string{}, 1, "maxSize1").Ok { + t.Error("the length of empty slice is less or equal than the maximum value of 1 should be true") + } +} + +func TestLength(t *testing.T) { + valid := Validation{} + + if valid.Length("", 1, "length1").Ok { + t.Error("the length of \"\" must equal 1 should be false") + } + if !valid.Length("1", 1, "length1").Ok { + t.Error("the length of \"1\" must equal 1 should be true") + } + if valid.Length([]string{}, 1, "length1").Ok { + t.Error("the length of empty slice must equal 1 should be false") + } + if !valid.Length([]interface{}{"ok"}, 1, "length1").Ok { + t.Error("the length of [\"ok\"] must equal 1 should be true") + } +} + +func TestAlpha(t *testing.T) { + valid := Validation{} + + if valid.Alpha("a,1-@ $", "alpha").Ok { + t.Error("\"a,1-@ $\" are valid alpha characters should be false") + } + if !valid.Alpha("abCD", "alpha").Ok { + t.Error("\"abCD\" are valid alpha characters should be true") + } +} + +func TestNumeric(t *testing.T) { + valid := Validation{} + + if valid.Numeric("a,1-@ $", "numeric").Ok { + t.Error("\"a,1-@ $\" are valid numeric characters should be false") + } + if !valid.Numeric("1234", "numeric").Ok { + t.Error("\"1234\" are valid numeric characters should be true") + } +} + +func TestAlphaNumeric(t *testing.T) { + valid := Validation{} + + if valid.AlphaNumeric("a,1-@ $", "alphaNumeric").Ok { + t.Error("\"a,1-@ $\" are valid alpha or numeric characters should be false") + } + if !valid.AlphaNumeric("1234aB", "alphaNumeric").Ok { + t.Error("\"1234aB\" are valid alpha or numeric characters should be true") + } +} + +func TestMatch(t *testing.T) { + valid := Validation{} + + if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false") + } + if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true") + } +} + +func TestNoMatch(t *testing.T) { + valid := Validation{} + + if valid.NoMatch("123@gmail", regexp.MustCompile(`[^\w\d]`), "nomatch").Ok { + t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false") + } + if !valid.NoMatch("123gmail", regexp.MustCompile(`[^\w\d]`), "match").Ok { + t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true") + } +} + +func TestAlphaDash(t *testing.T) { + valid := Validation{} + + if valid.AlphaDash("a,1-@ $", "alphaDash").Ok { + t.Error("\"a,1-@ $\" are valid alpha or numeric or dash(-_) characters should be false") + } + if !valid.AlphaDash("1234aB-_", "alphaDash").Ok { + t.Error("\"1234aB\" are valid alpha or numeric or dash(-_) characters should be true") + } +} + +func TestEmail(t *testing.T) { + valid := Validation{} + + if valid.Email("not@a email", "email").Ok { + t.Error("\"not@a email\" is a valid email address should be false") + } + if !valid.Email("suchuangji@gmail.com", "email").Ok { + t.Error("\"suchuangji@gmail.com\" is a valid email address should be true") + } + if valid.Email("@suchuangji@gmail.com", "email").Ok { + t.Error("\"@suchuangji@gmail.com\" is a valid email address should be false") + } + if valid.Email("suchuangji@gmail.com ok", "email").Ok { + t.Error("\"suchuangji@gmail.com ok\" is a valid email address should be false") + } +} + +func TestIP(t *testing.T) { + valid := Validation{} + + if valid.IP("11.255.255.256", "IP").Ok { + t.Error("\"11.255.255.256\" is a valid ip address should be false") + } + if !valid.IP("01.11.11.11", "IP").Ok { + t.Error("\"suchuangji@gmail.com\" is a valid ip address should be true") + } +} + +func TestBase64(t *testing.T) { + valid := Validation{} + + if valid.Base64("suchuangji@gmail.com", "base64").Ok { + t.Error("\"suchuangji@gmail.com\" are a valid base64 characters should be false") + } + if !valid.Base64("c3VjaHVhbmdqaUBnbWFpbC5jb20=", "base64").Ok { + t.Error("\"c3VjaHVhbmdqaUBnbWFpbC5jb20=\" are a valid base64 characters should be true") + } +} + +func TestMobile(t *testing.T) { + valid := Validation{} + + validMobiles := []string{ + "19800008888", + "18800008888", + "18000008888", + "8618300008888", + "+8614700008888", + "17300008888", + "+8617100008888", + "8617500008888", + "8617400008888", + "16200008888", + "16500008888", + "16600008888", + "16700008888", + "13300008888", + "14900008888", + "15300008888", + "17300008888", + "17700008888", + "18000008888", + "18900008888", + "19100008888", + "19900008888", + "19300008888", + "13000008888", + "13100008888", + "13200008888", + "14500008888", + "15500008888", + "15600008888", + "16600008888", + "17100008888", + "17500008888", + "17600008888", + "18500008888", + "18600008888", + "13400008888", + "13500008888", + "13600008888", + "13700008888", + "13800008888", + "13900008888", + "14700008888", + "15000008888", + "15100008888", + "15200008888", + "15800008888", + "15900008888", + "17200008888", + "17800008888", + "18200008888", + "18300008888", + "18400008888", + "18700008888", + "18800008888", + "19800008888", + } + + for _, m := range validMobiles { + if !valid.Mobile(m, "mobile").Ok { + t.Error(m + " is a valid mobile phone number should be true") + } + } +} + +func TestTel(t *testing.T) { + valid := Validation{} + + if valid.Tel("222-00008888", "telephone").Ok { + t.Error("\"222-00008888\" is a valid telephone number should be false") + } + if !valid.Tel("022-70008888", "telephone").Ok { + t.Error("\"022-70008888\" is a valid telephone number should be true") + } + if !valid.Tel("02270008888", "telephone").Ok { + t.Error("\"02270008888\" is a valid telephone number should be true") + } + if !valid.Tel("70008888", "telephone").Ok { + t.Error("\"70008888\" is a valid telephone number should be true") + } +} + +func TestPhone(t *testing.T) { + valid := Validation{} + + if valid.Phone("222-00008888", "phone").Ok { + t.Error("\"222-00008888\" is a valid phone number should be false") + } + if !valid.Mobile("+8614700008888", "phone").Ok { + t.Error("\"+8614700008888\" is a valid phone number should be true") + } + if !valid.Tel("02270008888", "phone").Ok { + t.Error("\"02270008888\" is a valid phone number should be true") + } +} + +func TestZipCode(t *testing.T) { + valid := Validation{} + + if valid.ZipCode("", "zipcode").Ok { + t.Error("\"00008888\" is a valid zipcode should be false") + } + if !valid.ZipCode("536000", "zipcode").Ok { + t.Error("\"536000\" is a valid zipcode should be true") + } +} + +func TestValid(t *testing.T) { + type user struct { + ID int + Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + valid := Validation{} + + u := user{Name: "test@/test/;com", Age: 40} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Error("validation should be passed") + } + + uptr := &user{Name: "test", Age: 40} + valid.Clear() + b, err = valid.Valid(uptr) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } + if len(valid.Errors) != 1 { + t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) + } + if valid.Errors[0].Key != "Name.Match" { + t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key) + } + + u = user{Name: "test@/test/;com", Age: 180} + valid.Clear() + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } + if len(valid.Errors) != 1 { + t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) + } + if valid.Errors[0].Key != "Age.Range." { + t.Errorf("Message key should be `Age.Range` but got %s", valid.Errors[0].Key) + } +} + +func TestRecursiveValid(t *testing.T) { + type User struct { + ID int + Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + + type AnonymouseUser struct { + ID2 int + Name2 string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age2 int `valid:"Required;Range(1, 140)"` + } + + type Account struct { + Password string `valid:"Required"` + U User + AnonymouseUser + } + valid := Validation{} + + u := Account{Password: "abc123_", U: User{}} + b, err := valid.RecursiveValid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } +} + +func TestSkipValid(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + + IP string `valid:"IP"` + ReqIP string `valid:"Required;IP"` + + Mobile string `valid:"Mobile"` + ReqMobile string `valid:"Required;Mobile"` + + Tel string `valid:"Tel"` + ReqTel string `valid:"Required;Tel"` + + Phone string `valid:"Phone"` + ReqPhone string `valid:"Required;Phone"` + + ZipCode string `valid:"ZipCode"` + ReqZipCode string `valid:"Required;ZipCode"` + } + + u := User{ + ReqEmail: "a@a.com", + ReqIP: "127.0.0.1", + ReqMobile: "18888888888", + ReqTel: "02088888888", + ReqPhone: "02088888888", + ReqZipCode: "510000", + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } +} + +func TestPointer(t *testing.T) { + type User struct { + ID int + + Email *string `valid:"Email"` + ReqEmail *string `valid:"Required;Email"` + } + + u := User{ + ReqEmail: nil, + Email: nil, + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + validEmail := "a@a.com" + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + invalidEmail := "a@a" + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } +} + +func TestCanSkipAlso(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + MatchRange int `valid:"Range(10, 20)"` + } + + u := User{ + ReqEmail: "a@a.com", + Email: "", + MatchRange: 0, + } + + valid := Validation{RequiredFirst: true} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + valid.CanSkipAlso("Range") + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + +} diff --git a/pkg/validation/validators.go b/pkg/validation/validators.go new file mode 100644 index 00000000..38b6f1aa --- /dev/null +++ b/pkg/validation/validators.go @@ -0,0 +1,738 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "fmt" + "github.com/astaxie/beego/logs" + "reflect" + "regexp" + "strings" + "sync" + "time" + "unicode/utf8" +) + +// CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty +var CanSkipFuncs = map[string]struct{}{ + "Email": {}, + "IP": {}, + "Mobile": {}, + "Tel": {}, + "Phone": {}, + "ZipCode": {}, +} + +// MessageTmpls store commond validate template +var MessageTmpls = map[string]string{ + "Required": "Can not be empty", + "Min": "Minimum is %d", + "Max": "Maximum is %d", + "Range": "Range is %d to %d", + "MinSize": "Minimum size is %d", + "MaxSize": "Maximum size is %d", + "Length": "Required length is %d", + "Alpha": "Must be valid alpha characters", + "Numeric": "Must be valid numeric characters", + "AlphaNumeric": "Must be valid alpha or numeric characters", + "Match": "Must match %s", + "NoMatch": "Must not match %s", + "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", + "Email": "Must be a valid email address", + "IP": "Must be a valid ip address", + "Base64": "Must be valid base64 characters", + "Mobile": "Must be valid mobile number", + "Tel": "Must be valid telephone number", + "Phone": "Must be valid telephone or mobile phone number", + "ZipCode": "Must be valid zipcode", +} + +var once sync.Once + +// SetDefaultMessage set default messages +// if not set, the default messages are +// "Required": "Can not be empty", +// "Min": "Minimum is %d", +// "Max": "Maximum is %d", +// "Range": "Range is %d to %d", +// "MinSize": "Minimum size is %d", +// "MaxSize": "Maximum size is %d", +// "Length": "Required length is %d", +// "Alpha": "Must be valid alpha characters", +// "Numeric": "Must be valid numeric characters", +// "AlphaNumeric": "Must be valid alpha or numeric characters", +// "Match": "Must match %s", +// "NoMatch": "Must not match %s", +// "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", +// "Email": "Must be a valid email address", +// "IP": "Must be a valid ip address", +// "Base64": "Must be valid base64 characters", +// "Mobile": "Must be valid mobile number", +// "Tel": "Must be valid telephone number", +// "Phone": "Must be valid telephone or mobile phone number", +// "ZipCode": "Must be valid zipcode", +func SetDefaultMessage(msg map[string]string) { + if len(msg) == 0 { + return + } + + once.Do(func() { + for name := range msg { + MessageTmpls[name] = msg[name] + } + }) + logs.Warn(`you must SetDefaultMessage at once`) +} + +// Validator interface +type Validator interface { + IsSatisfied(interface{}) bool + DefaultMessage() string + GetKey() string + GetLimitValue() interface{} +} + +// Required struct +type Required struct { + Key string +} + +// IsSatisfied judge whether obj has value +func (r Required) IsSatisfied(obj interface{}) bool { + if obj == nil { + return false + } + + if str, ok := obj.(string); ok { + return len(strings.TrimSpace(str)) > 0 + } + if _, ok := obj.(bool); ok { + return true + } + if i, ok := obj.(int); ok { + return i != 0 + } + if i, ok := obj.(uint); ok { + return i != 0 + } + if i, ok := obj.(int8); ok { + return i != 0 + } + if i, ok := obj.(uint8); ok { + return i != 0 + } + if i, ok := obj.(int16); ok { + return i != 0 + } + if i, ok := obj.(uint16); ok { + return i != 0 + } + if i, ok := obj.(uint32); ok { + return i != 0 + } + if i, ok := obj.(int32); ok { + return i != 0 + } + if i, ok := obj.(int64); ok { + return i != 0 + } + if i, ok := obj.(uint64); ok { + return i != 0 + } + if t, ok := obj.(time.Time); ok { + return !t.IsZero() + } + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Slice { + return v.Len() > 0 + } + return true +} + +// DefaultMessage return the default error message +func (r Required) DefaultMessage() string { + return MessageTmpls["Required"] +} + +// GetKey return the r.Key +func (r Required) GetKey() string { + return r.Key +} + +// GetLimitValue return nil now +func (r Required) GetLimitValue() interface{} { + return nil +} + +// Min check struct +type Min struct { + Min int + Key string +} + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (m Min) IsSatisfied(obj interface{}) bool { + var v int + switch obj.(type) { + case int64: + if wordsize == 32 { + return false + } + v = int(obj.(int64)) + case int: + v = obj.(int) + case int32: + v = int(obj.(int32)) + case int16: + v = int(obj.(int16)) + case int8: + v = int(obj.(int8)) + default: + return false + } + + return v >= m.Min +} + +// DefaultMessage return the default min error message +func (m Min) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Min"], m.Min) +} + +// GetKey return the m.Key +func (m Min) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value, Min +func (m Min) GetLimitValue() interface{} { + return m.Min +} + +// Max validate struct +type Max struct { + Max int + Key string +} + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (m Max) IsSatisfied(obj interface{}) bool { + var v int + switch obj.(type) { + case int64: + if wordsize == 32 { + return false + } + v = int(obj.(int64)) + case int: + v = obj.(int) + case int32: + v = int(obj.(int32)) + case int16: + v = int(obj.(int16)) + case int8: + v = int(obj.(int8)) + default: + return false + } + + return v <= m.Max +} + +// DefaultMessage return the default max error message +func (m Max) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Max"], m.Max) +} + +// GetKey return the m.Key +func (m Max) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value, Max +func (m Max) GetLimitValue() interface{} { + return m.Max +} + +// Range Requires an integer to be within Min, Max inclusive. +type Range struct { + Min + Max + Key string +} + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (r Range) IsSatisfied(obj interface{}) bool { + return r.Min.IsSatisfied(obj) && r.Max.IsSatisfied(obj) +} + +// DefaultMessage return the default Range error message +func (r Range) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Range"], r.Min.Min, r.Max.Max) +} + +// GetKey return the m.Key +func (r Range) GetKey() string { + return r.Key +} + +// GetLimitValue return the limit value, Max +func (r Range) GetLimitValue() interface{} { + return []int{r.Min.Min, r.Max.Max} +} + +// MinSize Requires an array or string to be at least a given length. +type MinSize struct { + Min int + Key string +} + +// IsSatisfied judge whether obj is valid +func (m MinSize) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + return utf8.RuneCountInString(str) >= m.Min + } + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Slice { + return v.Len() >= m.Min + } + return false +} + +// DefaultMessage return the default MinSize error message +func (m MinSize) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["MinSize"], m.Min) +} + +// GetKey return the m.Key +func (m MinSize) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value +func (m MinSize) GetLimitValue() interface{} { + return m.Min +} + +// MaxSize Requires an array or string to be at most a given length. +type MaxSize struct { + Max int + Key string +} + +// IsSatisfied judge whether obj is valid +func (m MaxSize) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + return utf8.RuneCountInString(str) <= m.Max + } + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Slice { + return v.Len() <= m.Max + } + return false +} + +// DefaultMessage return the default MaxSize error message +func (m MaxSize) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["MaxSize"], m.Max) +} + +// GetKey return the m.Key +func (m MaxSize) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value +func (m MaxSize) GetLimitValue() interface{} { + return m.Max +} + +// Length Requires an array or string to be exactly a given length. +type Length struct { + N int + Key string +} + +// IsSatisfied judge whether obj is valid +func (l Length) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + return utf8.RuneCountInString(str) == l.N + } + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Slice { + return v.Len() == l.N + } + return false +} + +// DefaultMessage return the default Length error message +func (l Length) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Length"], l.N) +} + +// GetKey return the m.Key +func (l Length) GetKey() string { + return l.Key +} + +// GetLimitValue return the limit value +func (l Length) GetLimitValue() interface{} { + return l.N +} + +// Alpha check the alpha +type Alpha struct { + Key string +} + +// IsSatisfied judge whether obj is valid +func (a Alpha) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + for _, v := range str { + if ('Z' < v || v < 'A') && ('z' < v || v < 'a') { + return false + } + } + return true + } + return false +} + +// DefaultMessage return the default Length error message +func (a Alpha) DefaultMessage() string { + return MessageTmpls["Alpha"] +} + +// GetKey return the m.Key +func (a Alpha) GetKey() string { + return a.Key +} + +// GetLimitValue return the limit value +func (a Alpha) GetLimitValue() interface{} { + return nil +} + +// Numeric check number +type Numeric struct { + Key string +} + +// IsSatisfied judge whether obj is valid +func (n Numeric) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + for _, v := range str { + if '9' < v || v < '0' { + return false + } + } + return true + } + return false +} + +// DefaultMessage return the default Length error message +func (n Numeric) DefaultMessage() string { + return MessageTmpls["Numeric"] +} + +// GetKey return the n.Key +func (n Numeric) GetKey() string { + return n.Key +} + +// GetLimitValue return the limit value +func (n Numeric) GetLimitValue() interface{} { + return nil +} + +// AlphaNumeric check alpha and number +type AlphaNumeric struct { + Key string +} + +// IsSatisfied judge whether obj is valid +func (a AlphaNumeric) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + for _, v := range str { + if ('Z' < v || v < 'A') && ('z' < v || v < 'a') && ('9' < v || v < '0') { + return false + } + } + return true + } + return false +} + +// DefaultMessage return the default Length error message +func (a AlphaNumeric) DefaultMessage() string { + return MessageTmpls["AlphaNumeric"] +} + +// GetKey return the a.Key +func (a AlphaNumeric) GetKey() string { + return a.Key +} + +// GetLimitValue return the limit value +func (a AlphaNumeric) GetLimitValue() interface{} { + return nil +} + +// Match Requires a string to match a given regex. +type Match struct { + Regexp *regexp.Regexp + Key string +} + +// IsSatisfied judge whether obj is valid +func (m Match) IsSatisfied(obj interface{}) bool { + return m.Regexp.MatchString(fmt.Sprintf("%v", obj)) +} + +// DefaultMessage return the default Match error message +func (m Match) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Match"], m.Regexp.String()) +} + +// GetKey return the m.Key +func (m Match) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value +func (m Match) GetLimitValue() interface{} { + return m.Regexp.String() +} + +// NoMatch Requires a string to not match a given regex. +type NoMatch struct { + Match + Key string +} + +// IsSatisfied judge whether obj is valid +func (n NoMatch) IsSatisfied(obj interface{}) bool { + return !n.Match.IsSatisfied(obj) +} + +// DefaultMessage return the default NoMatch error message +func (n NoMatch) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["NoMatch"], n.Regexp.String()) +} + +// GetKey return the n.Key +func (n NoMatch) GetKey() string { + return n.Key +} + +// GetLimitValue return the limit value +func (n NoMatch) GetLimitValue() interface{} { + return n.Regexp.String() +} + +var alphaDashPattern = regexp.MustCompile(`[^\d\w-_]`) + +// AlphaDash check not Alpha +type AlphaDash struct { + NoMatch + Key string +} + +// DefaultMessage return the default AlphaDash error message +func (a AlphaDash) DefaultMessage() string { + return MessageTmpls["AlphaDash"] +} + +// GetKey return the n.Key +func (a AlphaDash) GetKey() string { + return a.Key +} + +// GetLimitValue return the limit value +func (a AlphaDash) GetLimitValue() interface{} { + return nil +} + +var emailPattern = regexp.MustCompile(`^[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+(?:\.[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+)*@(?:[\w](?:[\w-]*[\w])?\.)+[a-zA-Z0-9](?:[\w-]*[\w])?$`) + +// Email check struct +type Email struct { + Match + Key string +} + +// DefaultMessage return the default Email error message +func (e Email) DefaultMessage() string { + return MessageTmpls["Email"] +} + +// GetKey return the n.Key +func (e Email) GetKey() string { + return e.Key +} + +// GetLimitValue return the limit value +func (e Email) GetLimitValue() interface{} { + return nil +} + +var ipPattern = regexp.MustCompile(`^((2[0-4]\d|25[0-5]|[01]?\d\d?)\.){3}(2[0-4]\d|25[0-5]|[01]?\d\d?)$`) + +// IP check struct +type IP struct { + Match + Key string +} + +// DefaultMessage return the default IP error message +func (i IP) DefaultMessage() string { + return MessageTmpls["IP"] +} + +// GetKey return the i.Key +func (i IP) GetKey() string { + return i.Key +} + +// GetLimitValue return the limit value +func (i IP) GetLimitValue() interface{} { + return nil +} + +var base64Pattern = regexp.MustCompile(`^(?:[A-Za-z0-99+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$`) + +// Base64 check struct +type Base64 struct { + Match + Key string +} + +// DefaultMessage return the default Base64 error message +func (b Base64) DefaultMessage() string { + return MessageTmpls["Base64"] +} + +// GetKey return the b.Key +func (b Base64) GetKey() string { + return b.Key +} + +// GetLimitValue return the limit value +func (b Base64) GetLimitValue() interface{} { + return nil +} + +// just for chinese mobile phone number +var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?1([356789][0-9]|4[579]|6[67]|7[0135678]|9[189])[0-9]{8}$`) + +// Mobile check struct +type Mobile struct { + Match + Key string +} + +// DefaultMessage return the default Mobile error message +func (m Mobile) DefaultMessage() string { + return MessageTmpls["Mobile"] +} + +// GetKey return the m.Key +func (m Mobile) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value +func (m Mobile) GetLimitValue() interface{} { + return nil +} + +// just for chinese telephone number +var telPattern = regexp.MustCompile(`^(0\d{2,3}(\-)?)?\d{7,8}$`) + +// Tel check telephone struct +type Tel struct { + Match + Key string +} + +// DefaultMessage return the default Tel error message +func (t Tel) DefaultMessage() string { + return MessageTmpls["Tel"] +} + +// GetKey return the t.Key +func (t Tel) GetKey() string { + return t.Key +} + +// GetLimitValue return the limit value +func (t Tel) GetLimitValue() interface{} { + return nil +} + +// Phone just for chinese telephone or mobile phone number +type Phone struct { + Mobile + Tel + Key string +} + +// IsSatisfied judge whether obj is valid +func (p Phone) IsSatisfied(obj interface{}) bool { + return p.Mobile.IsSatisfied(obj) || p.Tel.IsSatisfied(obj) +} + +// DefaultMessage return the default Phone error message +func (p Phone) DefaultMessage() string { + return MessageTmpls["Phone"] +} + +// GetKey return the p.Key +func (p Phone) GetKey() string { + return p.Key +} + +// GetLimitValue return the limit value +func (p Phone) GetLimitValue() interface{} { + return nil +} + +// just for chinese zipcode +var zipCodePattern = regexp.MustCompile(`^[1-9]\d{5}$`) + +// ZipCode check the zip struct +type ZipCode struct { + Match + Key string +} + +// DefaultMessage return the default Zip error message +func (z ZipCode) DefaultMessage() string { + return MessageTmpls["ZipCode"] +} + +// GetKey return the z.Key +func (z ZipCode) GetKey() string { + return z.Key +} + +// GetLimitValue return the limit value +func (z ZipCode) GetLimitValue() interface{} { + return nil +} From 30eb889a91f58189ac0b6d059031ee66e556d966 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 22 Jul 2020 23:00:06 +0800 Subject: [PATCH 037/207] Format code --- build_info.go | 8 +++--- cache/redis/redis.go | 2 +- config/yaml/yaml.go | 2 +- logs/accesslog.go | 2 +- logs/file.go | 30 +++++++++++----------- logs/file_test.go | 30 +++++++++++----------- metric/prometheus.go | 6 ++--- orm/cmd_utils.go | 6 ++--- orm/db_alias.go | 2 +- orm/orm_log.go | 2 +- pkg/build_info.go | 8 +++--- pkg/cache/redis/redis.go | 2 +- pkg/common/kv_test.go | 2 +- pkg/config/yaml/yaml.go | 2 +- pkg/logs/accesslog.go | 2 +- pkg/logs/file.go | 30 +++++++++++----------- pkg/logs/file_test.go | 30 +++++++++++----------- pkg/metric/prometheus.go | 6 ++--- pkg/orm/cmd_utils.go | 6 ++--- pkg/orm/db_alias.go | 6 ++--- pkg/orm/models_test.go | 6 ++--- pkg/orm/orm_log.go | 2 +- pkg/orm/types.go | 6 ++--- pkg/session/redis_cluster/redis_cluster.go | 17 ++++++------ pkg/session/sess_file_test.go | 5 ++-- pkg/staticfile.go | 2 +- pkg/templatefunc.go | 2 +- pkg/toolbox/task.go | 2 +- pkg/toolbox/task_test.go | 4 +-- pkg/validation/util.go | 2 +- session/redis_cluster/redis_cluster.go | 17 ++++++------ session/sess_file_test.go | 5 ++-- staticfile.go | 2 +- templatefunc.go | 2 +- toolbox/task.go | 2 +- toolbox/task_test.go | 4 +-- validation/util.go | 2 +- 37 files changed, 133 insertions(+), 133 deletions(-) diff --git a/build_info.go b/build_info.go index 6dc2835e..c31152ea 100644 --- a/build_info.go +++ b/build_info.go @@ -15,11 +15,11 @@ package beego var ( - BuildVersion string + BuildVersion string BuildGitRevision string - BuildStatus string - BuildTag string - BuildTime string + BuildStatus string + BuildTag string + BuildTime string GoVersion string diff --git a/cache/redis/redis.go b/cache/redis/redis.go index 56faf211..d8737b3c 100644 --- a/cache/redis/redis.go +++ b/cache/redis/redis.go @@ -57,7 +57,7 @@ type Cache struct { maxIdle int //the timeout to a value less than the redis server's timeout. - timeout time.Duration + timeout time.Duration } // NewRedisCache create new redis cache with default collection name. diff --git a/config/yaml/yaml.go b/config/yaml/yaml.go index 5def2da3..a5644c7b 100644 --- a/config/yaml/yaml.go +++ b/config/yaml/yaml.go @@ -296,7 +296,7 @@ func (c *ConfigContainer) getData(key string) (interface{}, error) { case map[string]interface{}: { tmpData = v.(map[string]interface{}) - if idx == len(keys) - 1 { + if idx == len(keys)-1 { return tmpData, nil } } diff --git a/logs/accesslog.go b/logs/accesslog.go index 3ff9e20f..9011b602 100644 --- a/logs/accesslog.go +++ b/logs/accesslog.go @@ -16,9 +16,9 @@ package logs import ( "bytes" - "strings" "encoding/json" "fmt" + "strings" "time" ) diff --git a/logs/file.go b/logs/file.go index 222db989..40a3572a 100644 --- a/logs/file.go +++ b/logs/file.go @@ -373,21 +373,21 @@ func (w *fileLogWriter) deleteOldLog() { if info == nil { return } - if w.Hourly { - if !info.IsDir() && info.ModTime().Add(1 * time.Hour * time.Duration(w.MaxHours)).Before(time.Now()) { - if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && - strings.HasSuffix(filepath.Base(path), w.suffix) { - os.Remove(path) - } - } - } else if w.Daily { - if !info.IsDir() && info.ModTime().Add(24 * time.Hour * time.Duration(w.MaxDays)).Before(time.Now()) { - if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && - strings.HasSuffix(filepath.Base(path), w.suffix) { - os.Remove(path) - } - } - } + if w.Hourly { + if !info.IsDir() && info.ModTime().Add(1*time.Hour*time.Duration(w.MaxHours)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } else if w.Daily { + if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } return }) } diff --git a/logs/file_test.go b/logs/file_test.go index e7c2ca9a..385eac43 100644 --- a/logs/file_test.go +++ b/logs/file_test.go @@ -186,7 +186,7 @@ func TestFileDailyRotate_06(t *testing.T) { //test file mode func TestFileHourlyRotate_01(t *testing.T) { log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) + log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) log.Debug("debug") log.Info("info") log.Notice("notice") @@ -237,7 +237,7 @@ func TestFileHourlyRotate_05(t *testing.T) { func TestFileHourlyRotate_06(t *testing.T) { //test file mode log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) + log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) log.Debug("debug") log.Info("info") log.Notice("notice") @@ -269,19 +269,19 @@ func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { RotatePerm: "0440", } - if daily { - fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) - fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) - fw.dailyOpenDate = fw.dailyOpenTime.Day() - } + if daily { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) + fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) + fw.dailyOpenDate = fw.dailyOpenTime.Day() + } - if hourly { - fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) - fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) - fw.hourlyOpenDate = fw.hourlyOpenTime.Day() - } + if hourly { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) + fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) + fw.hourlyOpenDate = fw.hourlyOpenTime.Day() + } - fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) + fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) for _, file := range []string{fn1, fn2} { _, err := os.Stat(file) @@ -328,8 +328,8 @@ func testFileDailyRotate(t *testing.T, fn1, fn2 string) { func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { fw := &fileLogWriter{ - Hourly: true, - MaxHours: 168, + Hourly: true, + MaxHours: 168, Rotate: true, Level: LevelTrace, Perm: "0660", diff --git a/metric/prometheus.go b/metric/prometheus.go index 7722240b..86e2c1b1 100644 --- a/metric/prometheus.go +++ b/metric/prometheus.go @@ -57,15 +57,15 @@ func registerBuildInfo() { Subsystem: "build_info", Help: "The building information", ConstLabels: map[string]string{ - "appname": beego.BConfig.AppName, + "appname": beego.BConfig.AppName, "build_version": beego.BuildVersion, "build_revision": beego.BuildGitRevision, "build_status": beego.BuildStatus, "build_tag": beego.BuildTag, - "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), + "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), "go_version": beego.GoVersion, "git_branch": beego.GitBranch, - "start_time": time.Now().Format("2006-01-02 15:04:05"), + "start_time": time.Now().Format("2006-01-02 15:04:05"), }, }, []string{}) diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index 61f17346..692a079f 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -197,9 +197,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex if strings.Contains(column, "%COL%") { column = strings.Replace(column, "%COL%", fi.column, -1) } - - if fi.description != "" && al.Driver!=DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) + + if fi.description != "" && al.Driver != DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) } columns = append(columns, column) diff --git a/orm/db_alias.go b/orm/db_alias.go index bf6c350c..fe6abeb5 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -424,7 +424,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } type stmtDecorator struct { - wg sync.WaitGroup + wg sync.WaitGroup stmt *sql.Stmt } diff --git a/orm/orm_log.go b/orm/orm_log.go index f107bb59..5bb3a24f 100644 --- a/orm/orm_log.go +++ b/orm/orm_log.go @@ -61,7 +61,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error con += " - " + err.Error() } logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `")) - if LogFunc != nil{ + if LogFunc != nil { LogFunc(logMap) } DebugLog.Println(con) diff --git a/pkg/build_info.go b/pkg/build_info.go index 6dc2835e..c31152ea 100644 --- a/pkg/build_info.go +++ b/pkg/build_info.go @@ -15,11 +15,11 @@ package beego var ( - BuildVersion string + BuildVersion string BuildGitRevision string - BuildStatus string - BuildTag string - BuildTime string + BuildStatus string + BuildTag string + BuildTime string GoVersion string diff --git a/pkg/cache/redis/redis.go b/pkg/cache/redis/redis.go index 56faf211..d8737b3c 100644 --- a/pkg/cache/redis/redis.go +++ b/pkg/cache/redis/redis.go @@ -57,7 +57,7 @@ type Cache struct { maxIdle int //the timeout to a value less than the redis server's timeout. - timeout time.Duration + timeout time.Duration } // NewRedisCache create new redis cache with default collection name. diff --git a/pkg/common/kv_test.go b/pkg/common/kv_test.go index ed7dc7ef..45adf5ff 100644 --- a/pkg/common/kv_test.go +++ b/pkg/common/kv_test.go @@ -23,7 +23,7 @@ import ( func TestKVs(t *testing.T) { key := "my-key" kvs := NewKVs(KV{ - Key: key, + Key: key, Value: 12, }) diff --git a/pkg/config/yaml/yaml.go b/pkg/config/yaml/yaml.go index 5def2da3..a5644c7b 100644 --- a/pkg/config/yaml/yaml.go +++ b/pkg/config/yaml/yaml.go @@ -296,7 +296,7 @@ func (c *ConfigContainer) getData(key string) (interface{}, error) { case map[string]interface{}: { tmpData = v.(map[string]interface{}) - if idx == len(keys) - 1 { + if idx == len(keys)-1 { return tmpData, nil } } diff --git a/pkg/logs/accesslog.go b/pkg/logs/accesslog.go index 3ff9e20f..9011b602 100644 --- a/pkg/logs/accesslog.go +++ b/pkg/logs/accesslog.go @@ -16,9 +16,9 @@ package logs import ( "bytes" - "strings" "encoding/json" "fmt" + "strings" "time" ) diff --git a/pkg/logs/file.go b/pkg/logs/file.go index 222db989..40a3572a 100644 --- a/pkg/logs/file.go +++ b/pkg/logs/file.go @@ -373,21 +373,21 @@ func (w *fileLogWriter) deleteOldLog() { if info == nil { return } - if w.Hourly { - if !info.IsDir() && info.ModTime().Add(1 * time.Hour * time.Duration(w.MaxHours)).Before(time.Now()) { - if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && - strings.HasSuffix(filepath.Base(path), w.suffix) { - os.Remove(path) - } - } - } else if w.Daily { - if !info.IsDir() && info.ModTime().Add(24 * time.Hour * time.Duration(w.MaxDays)).Before(time.Now()) { - if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && - strings.HasSuffix(filepath.Base(path), w.suffix) { - os.Remove(path) - } - } - } + if w.Hourly { + if !info.IsDir() && info.ModTime().Add(1*time.Hour*time.Duration(w.MaxHours)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } else if w.Daily { + if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } return }) } diff --git a/pkg/logs/file_test.go b/pkg/logs/file_test.go index e7c2ca9a..385eac43 100644 --- a/pkg/logs/file_test.go +++ b/pkg/logs/file_test.go @@ -186,7 +186,7 @@ func TestFileDailyRotate_06(t *testing.T) { //test file mode func TestFileHourlyRotate_01(t *testing.T) { log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) + log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) log.Debug("debug") log.Info("info") log.Notice("notice") @@ -237,7 +237,7 @@ func TestFileHourlyRotate_05(t *testing.T) { func TestFileHourlyRotate_06(t *testing.T) { //test file mode log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) + log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) log.Debug("debug") log.Info("info") log.Notice("notice") @@ -269,19 +269,19 @@ func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { RotatePerm: "0440", } - if daily { - fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) - fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) - fw.dailyOpenDate = fw.dailyOpenTime.Day() - } + if daily { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) + fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) + fw.dailyOpenDate = fw.dailyOpenTime.Day() + } - if hourly { - fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) - fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) - fw.hourlyOpenDate = fw.hourlyOpenTime.Day() - } + if hourly { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) + fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) + fw.hourlyOpenDate = fw.hourlyOpenTime.Day() + } - fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) + fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) for _, file := range []string{fn1, fn2} { _, err := os.Stat(file) @@ -328,8 +328,8 @@ func testFileDailyRotate(t *testing.T, fn1, fn2 string) { func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { fw := &fileLogWriter{ - Hourly: true, - MaxHours: 168, + Hourly: true, + MaxHours: 168, Rotate: true, Level: LevelTrace, Perm: "0660", diff --git a/pkg/metric/prometheus.go b/pkg/metric/prometheus.go index 7722240b..86e2c1b1 100644 --- a/pkg/metric/prometheus.go +++ b/pkg/metric/prometheus.go @@ -57,15 +57,15 @@ func registerBuildInfo() { Subsystem: "build_info", Help: "The building information", ConstLabels: map[string]string{ - "appname": beego.BConfig.AppName, + "appname": beego.BConfig.AppName, "build_version": beego.BuildVersion, "build_revision": beego.BuildGitRevision, "build_status": beego.BuildStatus, "build_tag": beego.BuildTag, - "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), + "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), "go_version": beego.GoVersion, "git_branch": beego.GitBranch, - "start_time": time.Now().Format("2006-01-02 15:04:05"), + "start_time": time.Now().Format("2006-01-02 15:04:05"), }, }, []string{}) diff --git a/pkg/orm/cmd_utils.go b/pkg/orm/cmd_utils.go index 61f17346..692a079f 100644 --- a/pkg/orm/cmd_utils.go +++ b/pkg/orm/cmd_utils.go @@ -197,9 +197,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex if strings.Contains(column, "%COL%") { column = strings.Replace(column, "%COL%", fi.column, -1) } - - if fi.description != "" && al.Driver!=DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) + + if fi.description != "" && al.Driver != DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) } columns = append(columns, column) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index 90c5de3c..a3f2a0b9 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -244,7 +244,7 @@ var _ dbQuerier = new(TxDB) var _ txEnder = new(TxDB) func (t *TxDB) Prepare(query string) (*sql.Stmt, error) { - return t.PrepareContext(context.Background(),query) + return t.PrepareContext(context.Background(), query) } func (t *TxDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { @@ -260,7 +260,7 @@ func (t *TxDB) ExecContext(ctx context.Context, query string, args ...interface{ } func (t *TxDB) Query(query string, args ...interface{}) (*sql.Rows, error) { - return t.QueryContext(context.Background(),query,args...) + return t.QueryContext(context.Background(), query, args...) } func (t *TxDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { @@ -268,7 +268,7 @@ func (t *TxDB) QueryContext(ctx context.Context, query string, args ...interface } func (t *TxDB) QueryRow(query string, args ...interface{}) *sql.Row { - return t.QueryRowContext(context.Background(),query,args...) + return t.QueryRowContext(context.Background(), query, args...) } func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index f14ee9cf..4c00050d 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -490,11 +490,11 @@ func init() { } err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, common.KV{ - Key:MaxIdleConnsKey, - Value:20, + Key: MaxIdleConnsKey, + Value: 20, }) - if err != nil{ + if err != nil { panic(fmt.Sprintf("can not register database: %v", err)) } diff --git a/pkg/orm/orm_log.go b/pkg/orm/orm_log.go index f107bb59..5bb3a24f 100644 --- a/pkg/orm/orm_log.go +++ b/pkg/orm/orm_log.go @@ -61,7 +61,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error con += " - " + err.Error() } logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `")) - if LogFunc != nil{ + if LogFunc != nil { LogFunc(logMap) } DebugLog.Println(con) diff --git a/pkg/orm/types.go b/pkg/orm/types.go index b7a38826..8255d93e 100644 --- a/pkg/orm/types.go +++ b/pkg/orm/types.go @@ -110,7 +110,7 @@ type DQL interface { // Like Read(), but with "FOR UPDATE" clause, useful in transaction. // Some databases are not support this feature. - ReadForUpdate( md interface{}, cols ...string) error + ReadForUpdate(md interface{}, cols ...string) error ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error // Try to read a row from the database, or insert one if it doesn't exist @@ -129,14 +129,14 @@ type DQL interface { // args[2] int offset default offset 0 // args[3] string order for example : "-Id" // make sure the relation is defined in model struct tags. - LoadRelated( md interface{}, name string, args ...interface{}) (int64, error) + LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) // create a models to models queryer // for example: // post := Post{Id: 4} // m2m := Ormer.QueryM2M(&post, "Tags") - QueryM2M( md interface{}, name string) QueryM2Mer + QueryM2M(md interface{}, name string) QueryM2Mer QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer // return a QuerySeter for table operations. diff --git a/pkg/session/redis_cluster/redis_cluster.go b/pkg/session/redis_cluster/redis_cluster.go index 2fe300df..262fa2e3 100644 --- a/pkg/session/redis_cluster/redis_cluster.go +++ b/pkg/session/redis_cluster/redis_cluster.go @@ -31,13 +31,14 @@ // // more docs: http://beego.me/docs/module/session.md package redis_cluster + import ( + "github.com/astaxie/beego/session" + rediss "github.com/go-redis/redis" "net/http" "strconv" "strings" "sync" - "github.com/astaxie/beego/session" - rediss "github.com/go-redis/redis" "time" ) @@ -101,7 +102,7 @@ func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { return } c := rs.p - c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime) * time.Second) + c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) } // Provider redis_cluster session provider @@ -146,10 +147,10 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } else { rp.dbNum = 0 } - + rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ Addrs: strings.Split(rp.savePath, ";"), - Password: rp.password, + Password: rp.password, PoolSize: rp.poolsize, }) return rp.poollist.Ping().Err() @@ -186,15 +187,15 @@ func (rp *Provider) SessionExist(sid string) bool { // SessionRegenerate generate new sid for redis_cluster session func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { c := rp.poollist - + if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { // oldsid doesn't exists, set the new sid directly // ignore error here, since if it return error // the existed value will be 0 - c.Set(sid, "", time.Duration(rp.maxlifetime) * time.Second) + c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second) } else { c.Rename(oldsid, sid) - c.Expire(sid, time.Duration(rp.maxlifetime) * time.Second) + c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) } return rp.SessionRead(sid) } diff --git a/pkg/session/sess_file_test.go b/pkg/session/sess_file_test.go index 0cf021db..021c43fc 100644 --- a/pkg/session/sess_file_test.go +++ b/pkg/session/sess_file_test.go @@ -369,8 +369,7 @@ func TestFileSessionStore_SessionRelease(t *testing.T) { t.Error(err) } - - s.Set(i,i) + s.Set(i, i) s.SessionRelease(nil) } @@ -384,4 +383,4 @@ func TestFileSessionStore_SessionRelease(t *testing.T) { t.Error() } } -} \ No newline at end of file +} diff --git a/pkg/staticfile.go b/pkg/staticfile.go index 84e9aa7b..e26776c5 100644 --- a/pkg/staticfile.go +++ b/pkg/staticfile.go @@ -202,7 +202,7 @@ func searchFile(ctx *context.Context) (string, os.FileInfo, error) { if !strings.Contains(requestPath, prefix) { continue } - if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { + if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { continue } filePath := path.Join(staticDir, requestPath[len(prefix):]) diff --git a/pkg/templatefunc.go b/pkg/templatefunc.go index ba1ec5eb..6f02b8d6 100644 --- a/pkg/templatefunc.go +++ b/pkg/templatefunc.go @@ -362,7 +362,7 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e value = value[:25] t, err = time.ParseInLocation(time.RFC3339, value, time.Local) } else if strings.HasSuffix(strings.ToUpper(value), "Z") { - t, err = time.ParseInLocation(time.RFC3339, value, time.Local) + t, err = time.ParseInLocation(time.RFC3339, value, time.Local) } else if len(value) >= 19 { if strings.Contains(value, "T") { value = value[:19] diff --git a/pkg/toolbox/task.go b/pkg/toolbox/task.go index c902fdfc..fb2c5f16 100644 --- a/pkg/toolbox/task.go +++ b/pkg/toolbox/task.go @@ -113,7 +113,7 @@ type Task struct { Next time.Time Errlist []*taskerr // like errtime:errinfo ErrLimit int // max length for the errlist, 0 stand for no limit - errCnt int // records the error count during the execution + errCnt int // records the error count during the execution } // NewTask add new task with name, time and func diff --git a/pkg/toolbox/task_test.go b/pkg/toolbox/task_test.go index 3a4cce2f..b63f4391 100644 --- a/pkg/toolbox/task_test.go +++ b/pkg/toolbox/task_test.go @@ -59,12 +59,12 @@ func TestSpec(t *testing.T) { func TestTask_Run(t *testing.T) { cnt := -1 task := func() error { - cnt ++ + cnt++ fmt.Printf("Hello, world! %d \n", cnt) return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) } tk := NewTask("taska", "0/30 * * * * *", task) - for i := 0; i < 200 ; i ++ { + for i := 0; i < 200; i++ { e := tk.Run() assert.NotNil(t, e) } diff --git a/pkg/validation/util.go b/pkg/validation/util.go index 82206f4f..918b206c 100644 --- a/pkg/validation/util.go +++ b/pkg/validation/util.go @@ -213,7 +213,7 @@ func parseFunc(vfunc, key string, label string) (v ValidFunc, err error) { return } - tParams, err := trim(name, key+"."+ name + "." + label, params) + tParams, err := trim(name, key+"."+name+"."+label, params) if err != nil { return } diff --git a/session/redis_cluster/redis_cluster.go b/session/redis_cluster/redis_cluster.go index 2fe300df..262fa2e3 100644 --- a/session/redis_cluster/redis_cluster.go +++ b/session/redis_cluster/redis_cluster.go @@ -31,13 +31,14 @@ // // more docs: http://beego.me/docs/module/session.md package redis_cluster + import ( + "github.com/astaxie/beego/session" + rediss "github.com/go-redis/redis" "net/http" "strconv" "strings" "sync" - "github.com/astaxie/beego/session" - rediss "github.com/go-redis/redis" "time" ) @@ -101,7 +102,7 @@ func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { return } c := rs.p - c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime) * time.Second) + c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) } // Provider redis_cluster session provider @@ -146,10 +147,10 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } else { rp.dbNum = 0 } - + rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ Addrs: strings.Split(rp.savePath, ";"), - Password: rp.password, + Password: rp.password, PoolSize: rp.poolsize, }) return rp.poollist.Ping().Err() @@ -186,15 +187,15 @@ func (rp *Provider) SessionExist(sid string) bool { // SessionRegenerate generate new sid for redis_cluster session func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { c := rp.poollist - + if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { // oldsid doesn't exists, set the new sid directly // ignore error here, since if it return error // the existed value will be 0 - c.Set(sid, "", time.Duration(rp.maxlifetime) * time.Second) + c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second) } else { c.Rename(oldsid, sid) - c.Expire(sid, time.Duration(rp.maxlifetime) * time.Second) + c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) } return rp.SessionRead(sid) } diff --git a/session/sess_file_test.go b/session/sess_file_test.go index 0cf021db..021c43fc 100644 --- a/session/sess_file_test.go +++ b/session/sess_file_test.go @@ -369,8 +369,7 @@ func TestFileSessionStore_SessionRelease(t *testing.T) { t.Error(err) } - - s.Set(i,i) + s.Set(i, i) s.SessionRelease(nil) } @@ -384,4 +383,4 @@ func TestFileSessionStore_SessionRelease(t *testing.T) { t.Error() } } -} \ No newline at end of file +} diff --git a/staticfile.go b/staticfile.go index 84e9aa7b..e26776c5 100644 --- a/staticfile.go +++ b/staticfile.go @@ -202,7 +202,7 @@ func searchFile(ctx *context.Context) (string, os.FileInfo, error) { if !strings.Contains(requestPath, prefix) { continue } - if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { + if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { continue } filePath := path.Join(staticDir, requestPath[len(prefix):]) diff --git a/templatefunc.go b/templatefunc.go index ba1ec5eb..6f02b8d6 100644 --- a/templatefunc.go +++ b/templatefunc.go @@ -362,7 +362,7 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e value = value[:25] t, err = time.ParseInLocation(time.RFC3339, value, time.Local) } else if strings.HasSuffix(strings.ToUpper(value), "Z") { - t, err = time.ParseInLocation(time.RFC3339, value, time.Local) + t, err = time.ParseInLocation(time.RFC3339, value, time.Local) } else if len(value) >= 19 { if strings.Contains(value, "T") { value = value[:19] diff --git a/toolbox/task.go b/toolbox/task.go index c902fdfc..fb2c5f16 100644 --- a/toolbox/task.go +++ b/toolbox/task.go @@ -113,7 +113,7 @@ type Task struct { Next time.Time Errlist []*taskerr // like errtime:errinfo ErrLimit int // max length for the errlist, 0 stand for no limit - errCnt int // records the error count during the execution + errCnt int // records the error count during the execution } // NewTask add new task with name, time and func diff --git a/toolbox/task_test.go b/toolbox/task_test.go index 3a4cce2f..b63f4391 100644 --- a/toolbox/task_test.go +++ b/toolbox/task_test.go @@ -59,12 +59,12 @@ func TestSpec(t *testing.T) { func TestTask_Run(t *testing.T) { cnt := -1 task := func() error { - cnt ++ + cnt++ fmt.Printf("Hello, world! %d \n", cnt) return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) } tk := NewTask("taska", "0/30 * * * * *", task) - for i := 0; i < 200 ; i ++ { + for i := 0; i < 200; i++ { e := tk.Run() assert.NotNil(t, e) } diff --git a/validation/util.go b/validation/util.go index 82206f4f..918b206c 100644 --- a/validation/util.go +++ b/validation/util.go @@ -213,7 +213,7 @@ func parseFunc(vfunc, key string, label string) (v ValidFunc, err error) { return } - tParams, err := trim(name, key+"."+ name + "." + label, params) + tParams, err := trim(name, key+"."+name+"."+label, params) if err != nil { return } From 79c2157ad47c392ee780f71f42535717444de08b Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 22 Jul 2020 15:34:55 +0000 Subject: [PATCH 038/207] Fix UT --- orm/models_test.go | 497 ++++ orm/orm_test.go | 2500 +++++++++++++++++ orm/utils_test.go | 70 + pkg/LICENSE | 13 + test.sh => scripts/test.sh | 2 +- .../test_docker_compose.yaml | 0 6 files changed, 3081 insertions(+), 1 deletion(-) create mode 100644 orm/models_test.go create mode 100644 orm/orm_test.go create mode 100644 orm/utils_test.go create mode 100644 pkg/LICENSE rename test.sh => scripts/test.sh (94%) rename test_docker_compose.yaml => scripts/test_docker_compose.yaml (100%) diff --git a/orm/models_test.go b/orm/models_test.go new file mode 100644 index 00000000..e3a635f2 --- /dev/null +++ b/orm/models_test.go @@ -0,0 +1,497 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "encoding/json" + "fmt" + "os" + "strings" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + // As tidb can't use go get, so disable the tidb testing now + // _ "github.com/pingcap/tidb" +) + +// A slice string field. +type SliceStringField []string + +func (e SliceStringField) Value() []string { + return []string(e) +} + +func (e *SliceStringField) Set(d []string) { + *e = SliceStringField(d) +} + +func (e *SliceStringField) Add(v string) { + *e = append(*e, v) +} + +func (e *SliceStringField) String() string { + return strings.Join(e.Value(), ",") +} + +func (e *SliceStringField) FieldType() int { + return TypeVarCharField +} + +func (e *SliceStringField) SetRaw(value interface{}) error { + switch d := value.(type) { + case []string: + e.Set(d) + case string: + if len(d) > 0 { + parts := strings.Split(d, ",") + v := make([]string, 0, len(parts)) + for _, p := range parts { + v = append(v, strings.TrimSpace(p)) + } + e.Set(v) + } + default: + return fmt.Errorf(" unknown value `%v`", value) + } + return nil +} + +func (e *SliceStringField) RawValue() interface{} { + return e.String() +} + +var _ Fielder = new(SliceStringField) + +// A json field. +type JSONFieldTest struct { + Name string + Data string +} + +func (e *JSONFieldTest) String() string { + data, _ := json.Marshal(e) + return string(data) +} + +func (e *JSONFieldTest) FieldType() int { + return TypeTextField +} + +func (e *JSONFieldTest) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + return json.Unmarshal([]byte(d), e) + default: + return fmt.Errorf(" unknown value `%v`", value) + } +} + +func (e *JSONFieldTest) RawValue() interface{} { + return e.String() +} + +var _ Fielder = new(JSONFieldTest) + +type Data struct { + ID int `orm:"column(id)"` + Boolean bool + Char string `orm:"size(50)"` + Text string `orm:"type(text)"` + JSON string `orm:"type(json);default({\"name\":\"json\"})"` + Jsonb string `orm:"type(jsonb)"` + Time time.Time `orm:"type(time)"` + Date time.Time `orm:"type(date)"` + DateTime time.Time `orm:"column(datetime)"` + Byte byte + Rune rune + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint uint + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Float32 float32 + Float64 float64 + Decimal float64 `orm:"digits(8);decimals(4)"` +} + +type DataNull struct { + ID int `orm:"column(id)"` + Boolean bool `orm:"null"` + Char string `orm:"null;size(50)"` + Text string `orm:"null;type(text)"` + JSON string `orm:"type(json);null"` + Jsonb string `orm:"type(jsonb);null"` + Time time.Time `orm:"null;type(time)"` + Date time.Time `orm:"null;type(date)"` + DateTime time.Time `orm:"null;column(datetime)"` + Byte byte `orm:"null"` + Rune rune `orm:"null"` + Int int `orm:"null"` + Int8 int8 `orm:"null"` + Int16 int16 `orm:"null"` + Int32 int32 `orm:"null"` + Int64 int64 `orm:"null"` + Uint uint `orm:"null"` + Uint8 uint8 `orm:"null"` + Uint16 uint16 `orm:"null"` + Uint32 uint32 `orm:"null"` + Uint64 uint64 `orm:"null"` + Float32 float32 `orm:"null"` + Float64 float64 `orm:"null"` + Decimal float64 `orm:"digits(8);decimals(4);null"` + NullString sql.NullString `orm:"null"` + NullBool sql.NullBool `orm:"null"` + NullFloat64 sql.NullFloat64 `orm:"null"` + NullInt64 sql.NullInt64 `orm:"null"` + BooleanPtr *bool `orm:"null"` + CharPtr *string `orm:"null;size(50)"` + TextPtr *string `orm:"null;type(text)"` + BytePtr *byte `orm:"null"` + RunePtr *rune `orm:"null"` + IntPtr *int `orm:"null"` + Int8Ptr *int8 `orm:"null"` + Int16Ptr *int16 `orm:"null"` + Int32Ptr *int32 `orm:"null"` + Int64Ptr *int64 `orm:"null"` + UintPtr *uint `orm:"null"` + Uint8Ptr *uint8 `orm:"null"` + Uint16Ptr *uint16 `orm:"null"` + Uint32Ptr *uint32 `orm:"null"` + Uint64Ptr *uint64 `orm:"null"` + Float32Ptr *float32 `orm:"null"` + Float64Ptr *float64 `orm:"null"` + DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` + TimePtr *time.Time `orm:"null;type(time)"` + DatePtr *time.Time `orm:"null;type(date)"` + DateTimePtr *time.Time `orm:"null"` +} + +type String string +type Boolean bool +type Byte byte +type Rune rune +type Int int +type Int8 int8 +type Int16 int16 +type Int32 int32 +type Int64 int64 +type Uint uint +type Uint8 uint8 +type Uint16 uint16 +type Uint32 uint32 +type Uint64 uint64 +type Float32 float64 +type Float64 float64 + +type DataCustom struct { + ID int `orm:"column(id)"` + Boolean Boolean + Char string `orm:"size(50)"` + Text string `orm:"type(text)"` + Byte Byte + Rune Rune + Int Int + Int8 Int8 + Int16 Int16 + Int32 Int32 + Int64 Int64 + Uint Uint + Uint8 Uint8 + Uint16 Uint16 + Uint32 Uint32 + Uint64 Uint64 + Float32 Float32 + Float64 Float64 + Decimal Float64 `orm:"digits(8);decimals(4)"` +} + +// only for mysql +type UserBig struct { + ID uint64 `orm:"column(id)"` + Name string +} + +type User struct { + ID int `orm:"column(id)"` + UserName string `orm:"size(30);unique"` + Email string `orm:"size(100)"` + Password string `orm:"size(100)"` + Status int16 `orm:"column(Status)"` + IsStaff bool + IsActive bool `orm:"default(true)"` + Created time.Time `orm:"auto_now_add;type(date)"` + Updated time.Time `orm:"auto_now"` + Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` + Posts []*Post `orm:"reverse(many)" json:"-"` + ShouldSkip string `orm:"-"` + Nums int + Langs SliceStringField `orm:"size(100)"` + Extra JSONFieldTest `orm:"type(text)"` + unexport bool `orm:"-"` + unexportBool bool +} + +func (u *User) TableIndex() [][]string { + return [][]string{ + {"Id", "UserName"}, + {"Id", "Created"}, + } +} + +func (u *User) TableUnique() [][]string { + return [][]string{ + {"UserName", "Email"}, + } +} + +func NewUser() *User { + obj := new(User) + return obj +} + +type Profile struct { + ID int `orm:"column(id)"` + Age int16 + Money float64 + User *User `orm:"reverse(one)" json:"-"` + BestPost *Post `orm:"rel(one);null"` +} + +func (u *Profile) TableName() string { + return "user_profile" +} + +func NewProfile() *Profile { + obj := new(Profile) + return obj +} + +type Post struct { + ID int `orm:"column(id)"` + User *User `orm:"rel(fk)"` + Title string `orm:"size(60)"` + Content string `orm:"type(text)"` + Created time.Time `orm:"auto_now_add"` + Updated time.Time `orm:"auto_now"` + Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.PostTags)"` +} + +func (u *Post) TableIndex() [][]string { + return [][]string{ + {"Id", "Created"}, + } +} + +func NewPost() *Post { + obj := new(Post) + return obj +} + +type Tag struct { + ID int `orm:"column(id)"` + Name string `orm:"size(30)"` + BestPost *Post `orm:"rel(one);null"` + Posts []*Post `orm:"reverse(many)" json:"-"` +} + +func NewTag() *Tag { + obj := new(Tag) + return obj +} + +type PostTags struct { + ID int `orm:"column(id)"` + Post *Post `orm:"rel(fk)"` + Tag *Tag `orm:"rel(fk)"` +} + +func (m *PostTags) TableName() string { + return "prefix_post_tags" +} + +type Comment struct { + ID int `orm:"column(id)"` + Post *Post `orm:"rel(fk);column(post)"` + Content string `orm:"type(text)"` + Parent *Comment `orm:"null;rel(fk)"` + Created time.Time `orm:"auto_now_add"` +} + +func NewComment() *Comment { + obj := new(Comment) + return obj +} + +type Group struct { + ID int `orm:"column(gid);size(32)"` + Name string + Permissions []*Permission `orm:"reverse(many)" json:"-"` +} + +type Permission struct { + ID int `orm:"column(id)"` + Name string + Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"` +} + +type GroupPermissions struct { + ID int `orm:"column(id)"` + Group *Group `orm:"rel(fk)"` + Permission *Permission `orm:"rel(fk)"` +} + +type ModelID struct { + ID int64 +} + +type ModelBase struct { + ModelID + + Created time.Time `orm:"auto_now_add;type(datetime)"` + Updated time.Time `orm:"auto_now;type(datetime)"` +} + +type InLine struct { + // Common Fields + ModelBase + + // Other Fields + Name string `orm:"unique"` + Email string +} + +func NewInLine() *InLine { + return new(InLine) +} + +type InLineOneToOne struct { + // Common Fields + ModelBase + + Note string + InLine *InLine `orm:"rel(fk);column(inline)"` +} + +func NewInLineOneToOne() *InLineOneToOne { + return new(InLineOneToOne) +} + +type IntegerPk struct { + ID int64 `orm:"pk"` + Value string +} + +type UintPk struct { + ID uint32 `orm:"pk"` + Name string +} + +type PtrPk struct { + ID *IntegerPk `orm:"pk;rel(one)"` + Positive bool +} + +var DBARGS = struct { + Driver string + Source string + Debug string +}{ + os.Getenv("ORM_DRIVER"), + os.Getenv("ORM_SOURCE"), + os.Getenv("ORM_DEBUG"), +} + +var ( + IsMysql = DBARGS.Driver == "mysql" + IsSqlite = DBARGS.Driver == "sqlite3" + IsPostgres = DBARGS.Driver == "postgres" + IsTidb = DBARGS.Driver == "tidb" +) + +var ( + dORM Ormer + dDbBaser dbBaser +) + +var ( + helpinfo = `need driver and source! + + Default DB Drivers. + + driver: url + mysql: https://github.com/go-sql-driver/mysql + sqlite3: https://github.com/mattn/go-sqlite3 + postgres: https://github.com/lib/pq + tidb: https://github.com/pingcap/tidb + + usage: + + go get -u github.com/astaxie/beego/orm + go get -u github.com/go-sql-driver/mysql + go get -u github.com/mattn/go-sqlite3 + go get -u github.com/lib/pq + go get -u github.com/pingcap/tidb + + #### MySQL + mysql -u root -e 'create database orm_test;' + export ORM_DRIVER=mysql + export ORM_SOURCE="root:@/orm_test?charset=utf8" + go test -v github.com/astaxie/beego/orm + + + #### Sqlite3 + export ORM_DRIVER=sqlite3 + export ORM_SOURCE='file:memory_test?mode=memory' + go test -v github.com/astaxie/beego/orm + + + #### PostgreSQL + psql -c 'create database orm_test;' -U postgres + export ORM_DRIVER=postgres + export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" + go test -v github.com/astaxie/beego/orm + + #### TiDB + export ORM_DRIVER=tidb + export ORM_SOURCE='memory://test/test' + go test -v github.com/astaxie/beego/orm + + ` +) + +func init() { + Debug, _ = StrTo(DBARGS.Debug).Bool() + + if DBARGS.Driver == "" || DBARGS.Source == "" { + fmt.Println(helpinfo) + os.Exit(2) + } + + RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) + + alias := getDbAlias("default") + if alias.Driver == DRMySQL { + alias.Engine = "INNODB" + } + +} diff --git a/orm/orm_test.go b/orm/orm_test.go new file mode 100644 index 00000000..eac7b33a --- /dev/null +++ b/orm/orm_test.go @@ -0,0 +1,2500 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +package orm + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "io/ioutil" + "math" + "os" + "path/filepath" + "reflect" + "runtime" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var _ = os.PathSeparator + +var ( + testDate = formatDate + " -0700" + testDateTime = formatDateTime + " -0700" + testTime = formatTime + " -0700" +) + +type argAny []interface{} + +// get interface by index from interface slice +func (a argAny) Get(i int, args ...interface{}) (r interface{}) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) { + if len(args) == 0 { + return false, fmt.Errorf("miss args") + } + b := args[0] + arg := argAny(args) + + switch v := a.(type) { + case reflect.Kind: + ok = reflect.ValueOf(b).Kind() == v + case time.Time: + if v2, vo := b.(time.Time); vo { + if arg.Get(1) != nil { + format := ToStr(arg.Get(1)) + a = v.Format(format) + b = v2.Format(format) + ok = a == b + } else { + err = fmt.Errorf("compare datetime miss format") + goto wrongArg + } + } + default: + ok = ToStr(a) == ToStr(b) + } + ok = is && ok || !is && !ok + if !ok { + if is { + err = fmt.Errorf("expected: `%v`, get `%v`", b, a) + } else { + err = fmt.Errorf("expected: `%v`, get `%v`", b, a) + } + } + +wrongArg: + if err != nil { + return false, err + } + + return true, nil +} + +func AssertIs(a interface{}, args ...interface{}) error { + if ok, err := ValuesCompare(true, a, args...); !ok { + return err + } + return nil +} + +func AssertNot(a interface{}, args ...interface{}) error { + if ok, err := ValuesCompare(false, a, args...); !ok { + return err + } + return nil +} + +func getCaller(skip int) string { + pc, file, line, _ := runtime.Caller(skip) + fun := runtime.FuncForPC(pc) + _, fn := filepath.Split(file) + data, err := ioutil.ReadFile(file) + var codes []string + if err == nil { + lines := bytes.Split(data, []byte{'\n'}) + n := 10 + for i := 0; i < n; i++ { + o := line - n + if o < 0 { + continue + } + cur := o + i + 1 + flag := " " + if cur == line { + flag = ">>" + } + code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1)) + if code != "" { + codes = append(codes, code) + } + } + } + funName := fun.Name() + if i := strings.LastIndex(funName, "."); i > -1 { + funName = funName[i+1:] + } + return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) +} + +// Deprecated: Using stretchr/testify/assert +func throwFail(t *testing.T, err error, args ...interface{}) { + if err != nil { + con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) + if len(args) > 0 { + parts := make([]string, 0, len(args)) + for _, arg := range args { + parts = append(parts, fmt.Sprintf("%v", arg)) + } + con += " " + strings.Join(parts, ", ") + } + t.Error(con) + t.Fail() + } +} + +func throwFailNow(t *testing.T, err error, args ...interface{}) { + if err != nil { + con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) + if len(args) > 0 { + parts := make([]string, 0, len(args)) + for _, arg := range args { + parts = append(parts, fmt.Sprintf("%v", arg)) + } + con += " " + strings.Join(parts, ", ") + } + t.Error(con) + t.FailNow() + } +} + +func TestGetDB(t *testing.T) { + if db, err := GetDB(); err != nil { + throwFailNow(t, err) + } else { + err = db.Ping() + throwFailNow(t, err) + } +} + +func TestSyncDb(t *testing.T) { + RegisterModel(new(Data), new(DataNull), new(DataCustom)) + RegisterModel(new(User)) + RegisterModel(new(Profile)) + RegisterModel(new(Post)) + RegisterModel(new(Tag)) + RegisterModel(new(Comment)) + RegisterModel(new(UserBig)) + RegisterModel(new(PostTags)) + RegisterModel(new(Group)) + RegisterModel(new(Permission)) + RegisterModel(new(GroupPermissions)) + RegisterModel(new(InLine)) + RegisterModel(new(InLineOneToOne)) + RegisterModel(new(IntegerPk)) + RegisterModel(new(UintPk)) + RegisterModel(new(PtrPk)) + + err := RunSyncdb("default", true, Debug) + throwFail(t, err) + + modelCache.clean() +} + +func TestRegisterModels(t *testing.T) { + RegisterModel(new(Data), new(DataNull), new(DataCustom)) + RegisterModel(new(User)) + RegisterModel(new(Profile)) + RegisterModel(new(Post)) + RegisterModel(new(Tag)) + RegisterModel(new(Comment)) + RegisterModel(new(UserBig)) + RegisterModel(new(PostTags)) + RegisterModel(new(Group)) + RegisterModel(new(Permission)) + RegisterModel(new(GroupPermissions)) + RegisterModel(new(InLine)) + RegisterModel(new(InLineOneToOne)) + RegisterModel(new(IntegerPk)) + RegisterModel(new(UintPk)) + RegisterModel(new(PtrPk)) + + BootStrap() + + dORM = NewOrm() + dDbBaser = getDbAlias("default").DbBaser +} + +func TestModelSyntax(t *testing.T) { + user := &User{} + ind := reflect.ValueOf(user).Elem() + fn := getFullName(ind.Type()) + mi, ok := modelCache.getByFullName(fn) + throwFail(t, AssertIs(ok, true)) + + mi, ok = modelCache.get("user") + throwFail(t, AssertIs(ok, true)) + if ok { + throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) + } +} + +var DataValues = map[string]interface{}{ + "Boolean": true, + "Char": "char", + "Text": "text", + "JSON": `{"name":"json"}`, + "Jsonb": `{"name": "jsonb"}`, + "Time": time.Now(), + "Date": time.Now(), + "DateTime": time.Now(), + "Byte": byte(1<<8 - 1), + "Rune": rune(1<<31 - 1), + "Int": int(1<<31 - 1), + "Int8": int8(1<<7 - 1), + "Int16": int16(1<<15 - 1), + "Int32": int32(1<<31 - 1), + "Int64": int64(1<<63 - 1), + "Uint": uint(1<<32 - 1), + "Uint8": uint8(1<<8 - 1), + "Uint16": uint16(1<<16 - 1), + "Uint32": uint32(1<<32 - 1), + "Uint64": uint64(1<<63 - 1), // uint64 values with high bit set are not supported + "Float32": float32(100.1234), + "Float64": float64(100.1234), + "Decimal": float64(100.1234), +} + +func TestDataTypes(t *testing.T) { + d := Data{} + ind := reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + if name == "JSON" { + continue + } + e := ind.FieldByName(name) + e.Set(reflect.ValueOf(value)) + } + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + d = Data{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + ind = reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } +} + +func TestNullDataTypes(t *testing.T) { + d := DataNull{} + + if IsPostgres { + // can removed when this fixed + // https://github.com/lib/pq/pull/125 + d.DateTime = time.Now() + } + + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + data := `{"ok":1,"data":{"arr":[1,2],"msg":"gopher"}}` + d = DataNull{ID: 1, JSON: data} + num, err := dORM.Update(&d) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + d = DataNull{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + throwFail(t, AssertIs(d.JSON, data)) + + throwFail(t, AssertIs(d.NullBool.Valid, false)) + throwFail(t, AssertIs(d.NullString.Valid, false)) + throwFail(t, AssertIs(d.NullInt64.Valid, false)) + throwFail(t, AssertIs(d.NullFloat64.Valid, false)) + + throwFail(t, AssertIs(d.BooleanPtr, nil)) + throwFail(t, AssertIs(d.CharPtr, nil)) + throwFail(t, AssertIs(d.TextPtr, nil)) + throwFail(t, AssertIs(d.BytePtr, nil)) + throwFail(t, AssertIs(d.RunePtr, nil)) + throwFail(t, AssertIs(d.IntPtr, nil)) + throwFail(t, AssertIs(d.Int8Ptr, nil)) + throwFail(t, AssertIs(d.Int16Ptr, nil)) + throwFail(t, AssertIs(d.Int32Ptr, nil)) + throwFail(t, AssertIs(d.Int64Ptr, nil)) + throwFail(t, AssertIs(d.UintPtr, nil)) + throwFail(t, AssertIs(d.Uint8Ptr, nil)) + throwFail(t, AssertIs(d.Uint16Ptr, nil)) + throwFail(t, AssertIs(d.Uint32Ptr, nil)) + throwFail(t, AssertIs(d.Uint64Ptr, nil)) + throwFail(t, AssertIs(d.Float32Ptr, nil)) + throwFail(t, AssertIs(d.Float64Ptr, nil)) + throwFail(t, AssertIs(d.DecimalPtr, nil)) + throwFail(t, AssertIs(d.TimePtr, nil)) + throwFail(t, AssertIs(d.DatePtr, nil)) + throwFail(t, AssertIs(d.DateTimePtr, nil)) + + _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() + throwFail(t, err) + + d = DataNull{ID: 2} + err = dORM.Read(&d) + throwFail(t, err) + + booleanPtr := true + charPtr := string("test") + textPtr := string("test") + bytePtr := byte('t') + runePtr := rune('t') + intPtr := int(42) + int8Ptr := int8(42) + int16Ptr := int16(42) + int32Ptr := int32(42) + int64Ptr := int64(42) + uintPtr := uint(42) + uint8Ptr := uint8(42) + uint16Ptr := uint16(42) + uint32Ptr := uint32(42) + uint64Ptr := uint64(42) + float32Ptr := float32(42.0) + float64Ptr := float64(42.0) + decimalPtr := float64(42.0) + timePtr := time.Now() + datePtr := time.Now() + dateTimePtr := time.Now() + + d = DataNull{ + DateTime: time.Now(), + NullString: sql.NullString{String: "test", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + BooleanPtr: &booleanPtr, + CharPtr: &charPtr, + TextPtr: &textPtr, + BytePtr: &bytePtr, + RunePtr: &runePtr, + IntPtr: &intPtr, + Int8Ptr: &int8Ptr, + Int16Ptr: &int16Ptr, + Int32Ptr: &int32Ptr, + Int64Ptr: &int64Ptr, + UintPtr: &uintPtr, + Uint8Ptr: &uint8Ptr, + Uint16Ptr: &uint16Ptr, + Uint32Ptr: &uint32Ptr, + Uint64Ptr: &uint64Ptr, + Float32Ptr: &float32Ptr, + Float64Ptr: &float64Ptr, + DecimalPtr: &decimalPtr, + TimePtr: &timePtr, + DatePtr: &datePtr, + DateTimePtr: &dateTimePtr, + } + + id, err = dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + d = DataNull{ID: 3} + err = dORM.Read(&d) + throwFail(t, err) + + throwFail(t, AssertIs(d.NullBool.Valid, true)) + throwFail(t, AssertIs(d.NullBool.Bool, true)) + + throwFail(t, AssertIs(d.NullString.Valid, true)) + throwFail(t, AssertIs(d.NullString.String, "test")) + + throwFail(t, AssertIs(d.NullInt64.Valid, true)) + throwFail(t, AssertIs(d.NullInt64.Int64, 42)) + + throwFail(t, AssertIs(d.NullFloat64.Valid, true)) + throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) + + throwFail(t, AssertIs(*d.BooleanPtr, booleanPtr)) + throwFail(t, AssertIs(*d.CharPtr, charPtr)) + throwFail(t, AssertIs(*d.TextPtr, textPtr)) + throwFail(t, AssertIs(*d.BytePtr, bytePtr)) + throwFail(t, AssertIs(*d.RunePtr, runePtr)) + throwFail(t, AssertIs(*d.IntPtr, intPtr)) + throwFail(t, AssertIs(*d.Int8Ptr, int8Ptr)) + throwFail(t, AssertIs(*d.Int16Ptr, int16Ptr)) + throwFail(t, AssertIs(*d.Int32Ptr, int32Ptr)) + throwFail(t, AssertIs(*d.Int64Ptr, int64Ptr)) + throwFail(t, AssertIs(*d.UintPtr, uintPtr)) + throwFail(t, AssertIs(*d.Uint8Ptr, uint8Ptr)) + throwFail(t, AssertIs(*d.Uint16Ptr, uint16Ptr)) + throwFail(t, AssertIs(*d.Uint32Ptr, uint32Ptr)) + throwFail(t, AssertIs(*d.Uint64Ptr, uint64Ptr)) + throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) + throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) + throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) + + // in mysql, there are some precision problem, (*d.TimePtr).UTC() != timePtr.UTC() + assert.True(t, (*d.TimePtr).UTC().Sub(timePtr.UTC()) <= time.Second) + assert.True(t, (*d.DatePtr).UTC().Sub(datePtr.UTC()) <= time.Second) + assert.True(t, (*d.DateTimePtr).UTC().Sub(dateTimePtr.UTC()) <= time.Second) + + // test support for pointer fields using RawSeter.QueryRows() + var dnList []*DataNull + Q := dDbBaser.TableQuote() + num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + equal := reflect.DeepEqual(*dnList[0], d) + throwFailNow(t, AssertIs(equal, true)) +} + +func TestDataCustomTypes(t *testing.T) { + d := DataCustom{} + ind := reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + if !e.IsValid() { + continue + } + e.Set(reflect.ValueOf(value).Convert(e.Type())) + } + + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + d = DataCustom{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + ind = reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + if !e.IsValid() { + continue + } + vu := e.Interface() + value = reflect.ValueOf(value).Convert(e.Type()).Interface() + throwFail(t, AssertIs(vu == value, true), value, vu) + } +} + +func TestCRUD(t *testing.T) { + profile := NewProfile() + profile.Age = 30 + profile.Money = 1234.12 + id, err := dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + user := NewUser() + user.UserName = "slene" + user.Email = "vslene@gmail.com" + user.Password = "pass" + user.Status = 3 + user.IsStaff = true + user.IsActive = true + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + u := &User{ID: user.ID} + err = dORM.Read(u) + throwFail(t, err) + + throwFail(t, AssertIs(u.UserName, "slene")) + throwFail(t, AssertIs(u.Email, "vslene@gmail.com")) + throwFail(t, AssertIs(u.Password, "pass")) + throwFail(t, AssertIs(u.Status, 3)) + throwFail(t, AssertIs(u.IsStaff, true)) + throwFail(t, AssertIs(u.IsActive, true)) + + assert.True(t, u.Created.In(DefaultTimeLoc).Sub(user.Created.In(DefaultTimeLoc)) <= time.Second) + assert.True(t, u.Updated.In(DefaultTimeLoc).Sub(user.Updated.In(DefaultTimeLoc)) <= time.Second) + + user.UserName = "astaxie" + user.Profile = profile + num, err := dORM.Update(user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFailNow(t, err) + throwFail(t, AssertIs(u.UserName, "astaxie")) + throwFail(t, AssertIs(u.Profile.ID, profile.ID)) + + u = &User{UserName: "astaxie", Password: "pass"} + err = dORM.Read(u, "UserName") + throwFailNow(t, err) + throwFailNow(t, AssertIs(id, 1)) + + u.UserName = "QQ" + u.Password = "111" + num, err = dORM.Update(u, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFailNow(t, err) + throwFail(t, AssertIs(u.UserName, "QQ")) + throwFail(t, AssertIs(u.Password, "pass")) + + num, err = dORM.Delete(profile) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFail(t, err) + throwFail(t, AssertIs(true, u.Profile == nil)) + + num, err = dORM.Delete(user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: 100} + err = dORM.Read(u) + throwFail(t, AssertIs(err, ErrNoRows)) + + ub := UserBig{} + ub.Name = "name" + id, err = dORM.Insert(&ub) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + ub = UserBig{ID: 1} + err = dORM.Read(&ub) + throwFail(t, err) + throwFail(t, AssertIs(ub.Name, "name")) + + num, err = dORM.Delete(&ub, "name") + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestInsertTestData(t *testing.T) { + var users []*User + + profile := NewProfile() + profile.Age = 28 + profile.Money = 1234.12 + + id, err := dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + user := NewUser() + user.UserName = "slene" + user.Email = "vslene@gmail.com" + user.Password = "pass" + user.Status = 1 + user.IsStaff = false + user.IsActive = true + user.Profile = profile + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + profile = NewProfile() + profile.Age = 30 + profile.Money = 4321.09 + + id, err = dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + user = NewUser() + user.UserName = "astaxie" + user.Email = "astaxie@gmail.com" + user.Password = "password" + user.Status = 2 + user.IsStaff = true + user.IsActive = false + user.Profile = profile + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + user = NewUser() + user.UserName = "nobody" + user.Email = "nobody@gmail.com" + user.Password = "nobody" + user.Status = 3 + user.IsStaff = false + user.IsActive = false + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 4)) + + tags := []*Tag{ + {Name: "golang", BestPost: &Post{ID: 2}}, + {Name: "example"}, + {Name: "format"}, + {Name: "c++"}, + } + + posts := []*Post{ + {User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory result—Java programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand. +This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, all of which you should read first.`}, + {User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`}, + {User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide. +With Go we take an unusual approach and let the machine take care of most formatting issues. The gofmt program (also available as go fmt, which operates at the package level rather than source file level) reads a Go program and emits the source in a standard style of indentation and vertical alignment, retaining and if necessary reformatting comments. If you want to know how to handle some new layout situation, run gofmt; if the answer doesn't seem right, rearrange your program (or file a bug about gofmt), don't work around it.`}, + {User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code. +The program—and web server—godoc processes Go source files to extract documentation about the contents of the package. Comments that appear before top-level declarations, with no intervening newlines, are extracted along with the declaration to serve as explanatory text for the item. The nature and style of these comments determines the quality of the documentation godoc produces.`}, + } + + comments := []*Comment{ + {Post: posts[0], Content: "a comment"}, + {Post: posts[1], Content: "yes"}, + {Post: posts[1]}, + {Post: posts[1]}, + {Post: posts[2]}, + {Post: posts[2]}, + } + + for _, tag := range tags { + id, err := dORM.Insert(tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + for _, post := range posts { + id, err := dORM.Insert(post) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num := len(post.Tags) + if num > 0 { + nums, err := dORM.QueryM2M(post, "tags").Add(post.Tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(nums, num)) + } + } + + for _, comment := range comments { + id, err := dORM.Insert(comment) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + permissions := []*Permission{ + {Name: "writePosts"}, + {Name: "readComments"}, + {Name: "readPosts"}, + } + + groups := []*Group{ + { + Name: "admins", + Permissions: []*Permission{permissions[0], permissions[1], permissions[2]}, + }, + { + Name: "users", + Permissions: []*Permission{permissions[1], permissions[2]}, + }, + } + + for _, permission := range permissions { + id, err := dORM.Insert(permission) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + for _, group := range groups { + _, err := dORM.Insert(group) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num := len(group.Permissions) + if num > 0 { + nums, err := dORM.QueryM2M(group, "permissions").Add(group.Permissions) + throwFailNow(t, err) + throwFailNow(t, AssertIs(nums, num)) + } + } + +} + +func TestCustomField(t *testing.T) { + user := User{ID: 2} + err := dORM.Read(&user) + throwFailNow(t, err) + + user.Langs = append(user.Langs, "zh-CN", "en-US") + user.Extra.Name = "beego" + user.Extra.Data = "orm" + _, err = dORM.Update(&user, "Langs", "Extra") + throwFailNow(t, err) + + user = User{ID: 2} + err = dORM.Read(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(len(user.Langs), 2)) + throwFailNow(t, AssertIs(user.Langs[0], "zh-CN")) + throwFailNow(t, AssertIs(user.Langs[1], "en-US")) + + throwFailNow(t, AssertIs(user.Extra.Name, "beego")) + throwFailNow(t, AssertIs(user.Extra.Data, "orm")) +} + +func TestExpr(t *testing.T) { + user := &User{} + qs := dORM.QueryTable(user) + qs = dORM.QueryTable((*User)(nil)) + qs = dORM.QueryTable("User") + qs = dORM.QueryTable("user") + num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("created", time.Now()).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + // num, err = qs.Filter("created", time.Now().Format(format_Date)).Count() + // throwFail(t, err) + // throwFail(t, AssertIs(num, 3)) +} + +func TestOperators(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.Filter("user_name", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__exact", String("slene")).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__exact", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__iexact", "Slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__contains", "e").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + var shouldNum int + + if IsSqlite || IsTidb { + shouldNum = 2 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__contains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__icontains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("user_name__icontains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__gt", 1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__gte", 1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + num, err = qs.Filter("status__lt", Uint(3)).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__lte", Int(3)).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + num, err = qs.Filter("user_name__startswith", "s").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + if IsSqlite || IsTidb { + shouldNum = 1 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__startswith", "S").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__istartswith", "S").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__endswith", "e").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + if IsSqlite || IsTidb { + shouldNum = 2 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__endswith", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__iendswith", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("profile__isnull", true).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("status__in", 1, 2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__in", []int{1, 2}).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + n1, n2 := 1, 2 + num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("id__between", 2, 3).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("id__between", []int{2, 3}).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.FilterRaw("user_name", "= 'slene'").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.FilterRaw("status", "IN (1, 2)").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.FilterRaw("profile_id", "IN (SELECT id FROM user_profile WHERE age=30)").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestSetCond(t *testing.T) { + cond := NewCondition() + cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) + + qs := dORM.QueryTable("user") + num, err := qs.SetCond(cond1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + cond2 := cond.AndCond(cond1).OrCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + cond3 := cond.AndNotCond(cond.And("status__in", 1)) + num, err = qs.SetCond(cond3).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + cond4 := cond.And("user_name", "slene").OrNotCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond4).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + cond5 := cond.Raw("user_name", "= 'slene'").OrNotCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond5).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) +} + +func TestLimit(t *testing.T) { + var posts []*Post + qs := dORM.QueryTable("post") + num, err := qs.Limit(1).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Limit(-1).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 4)) + + num, err = qs.Limit(-1, 2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Limit(0, 2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) +} + +func TestOffset(t *testing.T) { + var posts []*Post + qs := dORM.QueryTable("post") + num, err := qs.Limit(1).Offset(2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Offset(2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) +} + +func TestOrderBy(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderBy("status").Filter("user_name", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestAll(t *testing.T) { + var users []*User + qs := dORM.QueryTable("user") + num, err := qs.OrderBy("Id").All(&users) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + + throwFail(t, AssertIs(users[0].UserName, "slene")) + throwFail(t, AssertIs(users[1].UserName, "astaxie")) + throwFail(t, AssertIs(users[2].UserName, "nobody")) + + var users2 []User + qs = dORM.QueryTable("user") + num, err = qs.OrderBy("Id").All(&users2) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + + throwFailNow(t, AssertIs(users2[0].UserName, "slene")) + throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) + + qs = dORM.QueryTable("user") + num, err = qs.OrderBy("Id").RelatedSel().All(&users2, "UserName") + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(len(users2), 3)) + throwFailNow(t, AssertIs(users2[0].UserName, "slene")) + throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) + throwFailNow(t, AssertIs(users2[0].ID, 0)) + throwFailNow(t, AssertIs(users2[1].ID, 0)) + throwFailNow(t, AssertIs(users2[2].ID, 0)) + throwFailNow(t, AssertIs(users2[0].Profile == nil, false)) + throwFailNow(t, AssertIs(users2[1].Profile == nil, false)) + throwFailNow(t, AssertIs(users2[2].Profile == nil, true)) + + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "nothing").All(&users) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + var users3 []*User + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "nothing").All(&users3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + throwFailNow(t, AssertIs(users3 == nil, false)) +} + +func TestOne(t *testing.T) { + var user User + qs := dORM.QueryTable("user") + err := qs.One(&user) + throwFail(t, err) + + user = User{} + err = qs.OrderBy("Id").Limit(1).One(&user) + throwFailNow(t, err) + throwFail(t, AssertIs(user.UserName, "slene")) + throwFail(t, AssertNot(err, ErrMultiRows)) + + user = User{} + err = qs.OrderBy("-Id").Limit(100).One(&user) + throwFailNow(t, err) + throwFail(t, AssertIs(user.UserName, "nobody")) + throwFail(t, AssertNot(err, ErrMultiRows)) + + err = qs.Filter("user_name", "nothing").One(&user) + throwFail(t, AssertIs(err, ErrNoRows)) + +} + +func TestValues(t *testing.T) { + var maps []Params + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("Id").Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[2]["Profile"], nil)) + } + + num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[0]["Profile__Age"], 28)) + throwFail(t, AssertIs(maps[2]["Profile__Age"], nil)) + } + + num, err = qs.Filter("UserName", "slene").Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestValuesList(t *testing.T) { + var list []ParamsList + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("Id").ValuesList(&list) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0][1], "slene")) + throwFail(t, AssertIs(list[2][9], nil)) + } + + num, err = qs.OrderBy("Id").ValuesList(&list, "UserName", "Profile__Age") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0][0], "slene")) + throwFail(t, AssertIs(list[0][1], 28)) + throwFail(t, AssertIs(list[2][1], nil)) + } +} + +func TestValuesFlat(t *testing.T) { + var list ParamsList + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("id").ValuesFlat(&list, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0], "slene")) + throwFail(t, AssertIs(list[1], "astaxie")) + throwFail(t, AssertIs(list[2], "nobody")) + } +} + +func TestRelatedSel(t *testing.T) { + if IsTidb { + // Skip it. TiDB does not support relation now. + return + } + qs := dORM.QueryTable("user") + num, err := qs.Filter("profile__age", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("profile__age__gt", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("profile__user__profile__age__gt", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + var user User + err = qs.Filter("user_name", "slene").RelatedSel("profile").One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertNot(user.Profile, nil)) + if user.Profile != nil { + throwFail(t, AssertIs(user.Profile.Age, 28)) + } + + err = qs.Filter("user_name", "slene").RelatedSel().One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertNot(user.Profile, nil)) + if user.Profile != nil { + throwFail(t, AssertIs(user.Profile.Age, 28)) + } + + err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(user.Profile, nil)) + + qs = dORM.QueryTable("user_profile") + num, err = qs.Filter("user__username", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + var posts []*Post + qs = dORM.QueryTable("post") + num, err = qs.RelatedSel().All(&posts) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 4)) + + throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) + throwFailNow(t, AssertIs(posts[1].User.UserName, "astaxie")) + throwFailNow(t, AssertIs(posts[2].User.UserName, "astaxie")) + throwFailNow(t, AssertIs(posts[3].User.UserName, "nobody")) +} + +func TestReverseQuery(t *testing.T) { + var profile Profile + err := dORM.QueryTable("user_profile").Filter("User", 3).One(&profile) + throwFailNow(t, err) + throwFailNow(t, AssertIs(profile.Age, 30)) + + profile = Profile{} + err = dORM.QueryTable("user_profile").Filter("User__UserName", "astaxie").One(&profile) + throwFailNow(t, err) + throwFailNow(t, AssertIs(profile.Age, 30)) + + var user User + err = dORM.QueryTable("user").Filter("Posts__Title", "Examples").One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + + user = User{} + err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").Limit(1).One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + + user = User{} + err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").RelatedSel().Limit(1).One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + + var posts []*Post + num, err := dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) + + posts = []*Post{} + num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").Filter("User__UserName", "slene").All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) + + posts = []*Post{} + num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang"). + Filter("User__UserName", "slene").RelatedSel().All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(posts[0].User == nil, false)) + throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) + + var tags []*Tag + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction").All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + tags = []*Tag{} + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). + Filter("BestPost__User__UserName", "astaxie").All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + tags = []*Tag{} + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). + Filter("BestPost__User__UserName", "astaxie").RelatedSel().All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + throwFailNow(t, AssertIs(tags[0].BestPost == nil, false)) + throwFailNow(t, AssertIs(tags[0].BestPost.Title, "Examples")) + throwFailNow(t, AssertIs(tags[0].BestPost.User == nil, false)) + throwFailNow(t, AssertIs(tags[0].BestPost.User.UserName, "astaxie")) +} + +func TestLoadRelated(t *testing.T) { + // load reverse foreign key + user := User{ID: 3} + + err := dORM.Read(&user) + throwFailNow(t, err) + + num, err := dORM.LoadRelated(&user, "Posts") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) + + num, err = dORM.LoadRelated(&user, "Posts", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&user, "Posts", true, 1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(user.Posts), 1)) + + num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) + + num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(user.Posts), 1)) + throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) + + // load reverse one to one + profile := Profile{ID: 3} + profile.BestPost = &Post{ID: 2} + num, err = dORM.Update(&profile, "BestPost") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + err = dORM.Read(&profile) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&profile, "User") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(profile.User == nil, false)) + throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&profile, "User", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(profile.User == nil, false)) + throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) + throwFailNow(t, AssertIs(profile.User.Profile.Age, profile.Age)) + + // load rel one to one + err = dORM.Read(&user) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&user, "Profile") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + + num, err = dORM.LoadRelated(&user, "Profile", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false)) + throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples")) + + post := Post{ID: 2} + + // load rel foreign key + err = dORM.Read(&post) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&post, "User") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(post.User == nil, false)) + throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&post, "User", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(post.User == nil, false)) + throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) + throwFailNow(t, AssertIs(post.User.Profile == nil, false)) + throwFailNow(t, AssertIs(post.User.Profile.Age, 30)) + + // load rel m2m + post = Post{ID: 2} + + err = dORM.Read(&post) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&post, "Tags") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(post.Tags), 2)) + throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) + + num, err = dORM.LoadRelated(&post, "Tags", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(post.Tags), 2)) + throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) + throwFailNow(t, AssertIs(post.Tags[0].BestPost == nil, false)) + throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie")) + + // load reverse m2m + tag := Tag{ID: 1} + + err = dORM.Read(&tag) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&tag, "Posts") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) + throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) + throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) + + num, err = dORM.LoadRelated(&tag, "Posts", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) + throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) + throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene")) +} + +func TestQueryM2M(t *testing.T) { + post := Post{ID: 4} + m2m := dORM.QueryM2M(&post, "Tags") + + tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}} + tag2 := &Tag{Name: "TestTag3"} + tag3 := []interface{}{&Tag{Name: "TestTag4"}} + + tags := []interface{}{tag1[0], tag1[1], tag2, tag3[0]} + + for _, tag := range tags { + _, err := dORM.Insert(tag) + throwFailNow(t, err) + } + + num, err := m2m.Add(tag1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Add(tag2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Add(tag3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 5)) + + num, err = m2m.Remove(tag3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 4)) + + exist := m2m.Exist(tag2) + throwFailNow(t, AssertIs(exist, true)) + + num, err = m2m.Remove(tag2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + exist = m2m.Exist(tag2) + throwFailNow(t, AssertIs(exist, false)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + num, err = m2m.Clear() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + tag := Tag{Name: "test"} + _, err = dORM.Insert(&tag) + throwFailNow(t, err) + + m2m = dORM.QueryM2M(&tag, "Posts") + + post1 := []*Post{{Title: "TestPost1"}, {Title: "TestPost2"}} + post2 := &Post{Title: "TestPost3"} + post3 := []interface{}{&Post{Title: "TestPost4"}} + + posts := []interface{}{post1[0], post1[1], post2, post3[0]} + + for _, post := range posts { + p := post.(*Post) + p.User = &User{ID: 1} + _, err := dORM.Insert(post) + throwFailNow(t, err) + } + + num, err = m2m.Add(post1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Add(post2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Add(post3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 4)) + + num, err = m2m.Remove(post3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + exist = m2m.Exist(post2) + throwFailNow(t, AssertIs(exist, true)) + + num, err = m2m.Remove(post2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + exist = m2m.Exist(post2) + throwFailNow(t, AssertIs(exist, false)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Clear() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + num, err = dORM.Delete(&tag) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) +} + +func TestQueryRelate(t *testing.T) { + // post := &Post{Id: 2} + + // qs := dORM.QueryRelate(post, "Tags") + // num, err := qs.Count() + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) + + // var tags []*Tag + // num, err = qs.All(&tags) + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) + // throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + // num, err = dORM.QueryTable("Tag").Filter("Posts__Post", 2).Count() + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) +} + +func TestPkManyRelated(t *testing.T) { + permission := &Permission{Name: "readPosts"} + err := dORM.Read(permission, "Name") + throwFailNow(t, err) + + var groups []*Group + qs := dORM.QueryTable("Group") + num, err := qs.Filter("Permissions__Permission", permission.ID).All(&groups) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) +} + +func TestPrepareInsert(t *testing.T) { + qs := dORM.QueryTable("user") + i, err := qs.PrepareInsert() + throwFailNow(t, err) + + var user User + user.UserName = "testing1" + num, err := i.Insert(&user) + throwFail(t, err) + throwFail(t, AssertIs(num > 0, true)) + + user.UserName = "testing2" + num, err = i.Insert(&user) + throwFail(t, err) + throwFail(t, AssertIs(num > 0, true)) + + num, err = qs.Filter("user_name__in", "testing1", "testing2").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + err = i.Close() + throwFail(t, err) + err = i.Close() + throwFail(t, AssertIs(err, ErrStmtClosed)) +} + +func TestRawExec(t *testing.T) { + Q := dDbBaser.TableQuote() + + query := fmt.Sprintf("UPDATE %suser%s SET %suser_name%s = ? WHERE %suser_name%s = ?", Q, Q, Q, Q, Q, Q) + res, err := dORM.Raw(query, "testing", "slene").Exec() + throwFail(t, err) + num, err := res.RowsAffected() + throwFail(t, AssertIs(num, 1), err) + + res, err = dORM.Raw(query, "slene", "testing").Exec() + throwFail(t, err) + num, err = res.RowsAffected() + throwFail(t, AssertIs(num, 1), err) +} + +func TestRawQueryRow(t *testing.T) { + var ( + Boolean bool + Char string + Text string + Time time.Time + Date time.Time + DateTime time.Time + Byte byte + Rune rune + Int int + Int8 int + Int16 int16 + Int32 int32 + Int64 int64 + Uint uint + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Float32 float32 + Float64 float64 + Decimal float64 + ) + + dataValues := make(map[string]interface{}, len(DataValues)) + + for k, v := range DataValues { + dataValues[strings.ToLower(k)] = v + } + + Q := dDbBaser.TableQuote() + + cols := []string{ + "id", "boolean", "char", "text", "time", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", + "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", + } + sep := fmt.Sprintf("%s, %s", Q, Q) + query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q) + var id int + values := []interface{}{ + &id, &Boolean, &Char, &Text, &Time, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32, + &Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal, + } + err := dORM.Raw(query, 1).QueryRow(values...) + throwFailNow(t, err) + for i, col := range cols { + vu := values[i] + v := reflect.ValueOf(vu).Elem().Interface() + switch col { + case "id": + throwFail(t, AssertIs(id, 1)) + case "time": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testTime)) + case "date": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testDate)) + case "datetime": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testDateTime)) + default: + throwFail(t, AssertIs(v, dataValues[col])) + } + } + + var ( + uid int + status *int + pid *int + ) + + cols = []string{ + "id", "Status", "profile_id", + } + query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q) + err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid) + throwFail(t, err) + throwFail(t, AssertIs(uid, 4)) + throwFail(t, AssertIs(*status, 3)) + throwFail(t, AssertIs(pid, nil)) + + // test for sql.Null* fields + nData := &DataNull{ + NullString: sql.NullString{String: "test sql.null", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + } + newId, err := dORM.Insert(nData) + throwFailNow(t, err) + + var nd *DataNull + query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) + err = dORM.Raw(query, newId).QueryRow(&nd) + throwFailNow(t, err) + + throwFailNow(t, AssertNot(nd, nil)) + throwFail(t, AssertIs(nd.NullBool.Valid, true)) + throwFail(t, AssertIs(nd.NullBool.Bool, true)) + throwFail(t, AssertIs(nd.NullString.Valid, true)) + throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) + throwFail(t, AssertIs(nd.NullInt64.Valid, true)) + throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) + throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) + throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) +} + +// user_profile table +type userProfile struct { + User + Age int + Money float64 +} + +func TestQueryRows(t *testing.T) { + Q := dDbBaser.TableQuote() + + var datas []*Data + + query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) + num, err := dORM.Raw(query).QueryRows(&datas) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(datas), 1)) + + ind := reflect.Indirect(reflect.ValueOf(datas[0])) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } + + var datas2 []Data + + query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) + num, err = dORM.Raw(query).QueryRows(&datas2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(datas2), 1)) + + ind = reflect.Indirect(reflect.ValueOf(datas2[0])) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } + + var ids []int + var usernames []string + query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q) + num, err = dORM.Raw(query).QueryRows(&ids, &usernames) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(len(ids), 3)) + throwFailNow(t, AssertIs(ids[0], 2)) + throwFailNow(t, AssertIs(usernames[0], "slene")) + throwFailNow(t, AssertIs(ids[1], 3)) + throwFailNow(t, AssertIs(usernames[1], "astaxie")) + throwFailNow(t, AssertIs(ids[2], 4)) + throwFailNow(t, AssertIs(usernames[2], "nobody")) + + // test query rows by nested struct + var l []userProfile + query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q) + num, err = dORM.Raw(query).QueryRows(&l) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(l), 2)) + throwFailNow(t, AssertIs(l[0].UserName, "slene")) + throwFailNow(t, AssertIs(l[0].Age, 28)) + throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(l[1].Age, 30)) + + // test for sql.Null* fields + nData := &DataNull{ + NullString: sql.NullString{String: "test sql.null", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + } + newId, err := dORM.Insert(nData) + throwFailNow(t, err) + + var nDataList []*DataNull + query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) + num, err = dORM.Raw(query, newId).QueryRows(&nDataList) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + nd := nDataList[0] + throwFailNow(t, AssertNot(nd, nil)) + throwFail(t, AssertIs(nd.NullBool.Valid, true)) + throwFail(t, AssertIs(nd.NullBool.Bool, true)) + throwFail(t, AssertIs(nd.NullString.Valid, true)) + throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) + throwFail(t, AssertIs(nd.NullInt64.Valid, true)) + throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) + throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) + throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) +} + +func TestRawValues(t *testing.T) { + Q := dDbBaser.TableQuote() + + var maps []Params + query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sStatus%s = ?", Q, Q, Q, Q, Q, Q) + num, err := dORM.Raw(query, 1).Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + if num == 1 { + throwFail(t, AssertIs(maps[0]["user_name"], "slene")) + } + + var lists []ParamsList + num, err = dORM.Raw(query, 1).ValuesList(&lists) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + if num == 1 { + throwFail(t, AssertIs(lists[0][0], "slene")) + } + + query = fmt.Sprintf("SELECT %sprofile_id%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q) + var list ParamsList + num, err = dORM.Raw(query).ValuesFlat(&list) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0], "2")) + throwFail(t, AssertIs(list[1], "3")) + throwFail(t, AssertIs(list[2], nil)) + } +} + +func TestRawPrepare(t *testing.T) { + switch { + case IsMysql || IsSqlite: + + pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() + throwFail(t, err) + if pre != nil { + r, err := pre.Exec("name1") + throwFail(t, err) + + tid, err := r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(tid > 0, true)) + + r, err = pre.Exec("name2") + throwFail(t, err) + + id, err := r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id, tid+1)) + + r, err = pre.Exec("name3") + throwFail(t, err) + + id, err = r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id, tid+2)) + + err = pre.Close() + throwFail(t, err) + + res, err := dORM.Raw("DELETE FROM tag WHERE name IN (?, ?, ?)", []string{"name1", "name2", "name3"}).Exec() + throwFail(t, err) + + num, err := res.RowsAffected() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + } + + case IsPostgres: + + pre, err := dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() + throwFail(t, err) + if pre != nil { + _, err := pre.Exec("name1") + throwFail(t, err) + + _, err = pre.Exec("name2") + throwFail(t, err) + + _, err = pre.Exec("name3") + throwFail(t, err) + + err = pre.Close() + throwFail(t, err) + + res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec() + throwFail(t, err) + + if err == nil { + num, err := res.RowsAffected() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + } + } + } +} + +func TestUpdate(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.Filter("user_name", "slene").Filter("is_staff", false).Update(Params{ + "is_staff": true, + "is_active": true, + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + // with join + num, err = qs.Filter("user_name", "slene").Filter("profile__age", 28).Filter("is_staff", true).Update(Params{ + "is_staff": false, + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColAdd, 100), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColMinus, 50), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColMultiply, 3), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColExcept, 5), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + user := User{UserName: "slene"} + err = dORM.Read(&user, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(user.Nums, 30)) +} + +func TestDelete(t *testing.T) { + qs := dORM.QueryTable("user_profile") + num, err := qs.Filter("user__user_name", "slene").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 6)) + + qs = dORM.QueryTable("post") + num, err = qs.Filter("Id", 3).Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 4)) + + qs = dORM.QueryTable("comment") + num, err = qs.Filter("Post__User", 3).Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestTransaction(t *testing.T) { + // this test worked when database support transaction + + o := NewOrm() + err := o.Begin() + throwFail(t, err) + + var names = []string{"1", "2", "3"} + + var tag Tag + tag.Name = names[0] + id, err := o.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + switch { + case IsMysql || IsSqlite: + res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() + throwFail(t, err) + if err == nil { + id, err = res.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + } + + err = o.Rollback() + throwFail(t, err) + + num, err = o.QueryTable("tag").Filter("name__in", names).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + err = o.Begin() + throwFail(t, err) + + tag.Name = "commit" + id, err = o.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + o.Commit() + throwFail(t, err) + + num, err = o.QueryTable("tag").Filter("name", "commit").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + +} + +func TestTransactionIsolationLevel(t *testing.T) { + // this test worked when database support transaction isolation level + if IsSqlite { + return + } + + o1 := NewOrm() + o2 := NewOrm() + + // start two transaction with isolation level repeatable read + err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + throwFail(t, err) + err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + throwFail(t, err) + + // o1 insert tag + var tag Tag + tag.Name = "test-transaction" + id, err := o1.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + // o2 query tag table, no result + num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + // o1 commit + o1.Commit() + + // o2 query tag table, still no result + num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + // o2 commit and query tag table, get the result + o2.Commit() + num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = o1.QueryTable("tag").Filter("name", "test-transaction").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestBeginTxWithContextCanceled(t *testing.T) { + o := NewOrm() + ctx, cancel := context.WithCancel(context.Background()) + o.BeginTx(ctx, nil) + id, err := o.Insert(&Tag{Name: "test-context"}) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + // cancel the context before commit to make it error + cancel() + err = o.Commit() + throwFail(t, AssertIs(err, context.Canceled)) +} + +func TestReadOrCreate(t *testing.T) { + u := &User{ + UserName: "Kyle", + Email: "kylemcc@gmail.com", + Password: "other_pass", + Status: 7, + IsStaff: false, + IsActive: true, + } + + created, pk, err := dORM.ReadOrCreate(u, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(created, true)) + throwFail(t, AssertIs(u.ID, pk)) + throwFail(t, AssertIs(u.UserName, "Kyle")) + throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com")) + throwFail(t, AssertIs(u.Password, "other_pass")) + throwFail(t, AssertIs(u.Status, 7)) + throwFail(t, AssertIs(u.IsStaff, false)) + throwFail(t, AssertIs(u.IsActive, true)) + throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime)) + + nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"} + created, pk, err = dORM.ReadOrCreate(nu, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(nu.ID, u.ID)) + throwFail(t, AssertIs(pk, u.ID)) + throwFail(t, AssertIs(nu.UserName, u.UserName)) + throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above + throwFail(t, AssertIs(nu.Password, u.Password)) + throwFail(t, AssertIs(nu.Status, u.Status)) + throwFail(t, AssertIs(nu.IsStaff, u.IsStaff)) + throwFail(t, AssertIs(nu.IsActive, u.IsActive)) + + dORM.Delete(u) +} + +func TestInLine(t *testing.T) { + name := "inline" + email := "hello@go.com" + inline := NewInLine() + inline.Name = name + inline.Email = email + + id, err := dORM.Insert(inline) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + il := NewInLine() + il.ID = 1 + err = dORM.Read(il) + throwFail(t, err) + + throwFail(t, AssertIs(il.Name, name)) + throwFail(t, AssertIs(il.Email, email)) + throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) +} + +func TestInLineOneToOne(t *testing.T) { + name := "121" + email := "121@go.com" + inline := NewInLine() + inline.Name = name + inline.Email = email + + id, err := dORM.Insert(inline) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + note := "one2one" + il121 := NewInLineOneToOne() + il121.Note = note + il121.InLine = inline + _, err = dORM.Insert(il121) + throwFail(t, err) + throwFail(t, AssertIs(il121.ID, 1)) + + il := NewInLineOneToOne() + err = dORM.QueryTable(il).Filter("Id", 1).RelatedSel().One(il) + + throwFail(t, err) + throwFail(t, AssertIs(il.Note, note)) + throwFail(t, AssertIs(il.InLine.ID, id)) + throwFail(t, AssertIs(il.InLine.Name, name)) + throwFail(t, AssertIs(il.InLine.Email, email)) + + rinline := NewInLine() + err = dORM.QueryTable(rinline).Filter("InLineOneToOne__Id", 1).One(rinline) + + throwFail(t, err) + throwFail(t, AssertIs(rinline.ID, id)) + throwFail(t, AssertIs(rinline.Name, name)) + throwFail(t, AssertIs(rinline.Email, email)) +} + +func TestIntegerPk(t *testing.T) { + its := []IntegerPk{ + {ID: math.MinInt64, Value: "-"}, + {ID: 0, Value: "0"}, + {ID: math.MaxInt64, Value: "+"}, + } + + num, err := dORM.InsertMulti(len(its), its) + throwFail(t, err) + throwFail(t, AssertIs(num, len(its))) + + for _, intPk := range its { + out := IntegerPk{ID: intPk.ID} + err = dORM.Read(&out) + throwFail(t, err) + throwFail(t, AssertIs(out.Value, intPk.Value)) + } + + num, err = dORM.InsertMulti(1, []*IntegerPk{{ + ID: 1, Value: "ok", + }}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestInsertAuto(t *testing.T) { + u := &User{ + UserName: "autoPre", + Email: "autoPre@gmail.com", + } + + id, err := dORM.Insert(u) + throwFail(t, err) + + id += 100 + su := &User{ + ID: int(id), + UserName: "auto", + Email: "auto@gmail.com", + } + + nid, err := dORM.Insert(su) + throwFail(t, err) + throwFail(t, AssertIs(nid, id)) + + users := []User{ + {ID: int(id + 100), UserName: "auto_100"}, + {ID: int(id + 110), UserName: "auto_110"}, + {ID: int(id + 120), UserName: "auto_120"}, + } + num, err := dORM.InsertMulti(100, users) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + u = &User{ + UserName: "auto_121", + } + + nid, err = dORM.Insert(u) + throwFail(t, err) + throwFail(t, AssertIs(nid, id+120+1)) +} + +func TestUintPk(t *testing.T) { + name := "go" + u := &UintPk{ + ID: 8, + Name: name, + } + + created, _, err := dORM.ReadOrCreate(u, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, true)) + throwFail(t, AssertIs(u.Name, name)) + + nu := &UintPk{ID: 8} + created, pk, err := dORM.ReadOrCreate(nu, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(nu.ID, u.ID)) + throwFail(t, AssertIs(pk, u.ID)) + throwFail(t, AssertIs(nu.Name, name)) + + dORM.Delete(u) +} + +func TestPtrPk(t *testing.T) { + parent := &IntegerPk{ID: 10, Value: "10"} + + id, _ := dORM.Insert(parent) + if !IsMysql { + // MySql does not support last_insert_id in this case: see #2382 + throwFail(t, AssertIs(id, 10)) + } + + ptr := PtrPk{ID: parent, Positive: true} + num, err := dORM.InsertMulti(2, []PtrPk{ptr}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(ptr.ID, parent)) + + nptr := &PtrPk{ID: parent} + created, pk, err := dORM.ReadOrCreate(nptr, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(pk, 10)) + throwFail(t, AssertIs(nptr.ID, parent)) + throwFail(t, AssertIs(nptr.Positive, true)) + + nptr = &PtrPk{Positive: true} + created, pk, err = dORM.ReadOrCreate(nptr, "Positive") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(pk, 10)) + throwFail(t, AssertIs(nptr.ID, parent)) + + nptr.Positive = false + num, err = dORM.Update(nptr) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(nptr.ID, parent)) + throwFail(t, AssertIs(nptr.Positive, false)) + + num, err = dORM.Delete(nptr) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestSnake(t *testing.T) { + cases := map[string]string{ + "i": "i", + "I": "i", + "iD": "i_d", + "ID": "i_d", + "NO": "n_o", + "NOO": "n_o_o", + "NOOooOOoo": "n_o_ooo_o_ooo", + "OrderNO": "order_n_o", + "tagName": "tag_name", + "tag_Name": "tag__name", + "tag_name": "tag_name", + "_tag_name": "_tag_name", + "tag_666name": "tag_666name", + "tag_666Name": "tag_666_name", + } + for name, want := range cases { + got := snakeString(name) + throwFail(t, AssertIs(got, want)) + } +} + +func TestIgnoreCaseTag(t *testing.T) { + type testTagModel struct { + ID int `orm:"pk"` + NOO string `orm:"column(n)"` + Name01 string `orm:"NULL"` + Name02 string `orm:"COLUMN(Name)"` + Name03 string `orm:"Column(name)"` + } + modelCache.clean() + RegisterModel(&testTagModel{}) + info, ok := modelCache.get("test_tag_model") + throwFail(t, AssertIs(ok, true)) + throwFail(t, AssertNot(info, nil)) + if t == nil { + return + } + throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n")) + throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true)) + throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) + throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) +} + +func TestInsertOrUpdate(t *testing.T) { + RegisterModel(new(User)) + user := User{UserName: "unique_username133", Status: 1, Password: "o"} + user1 := User{UserName: "unique_username133", Status: 2, Password: "o"} + user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"} + dORM.Insert(&user) + test := User{UserName: "unique_username133"} + fmt.Println(dORM.Driver().Name()) + if dORM.Driver().Name() == "sqlite3" { + fmt.Println("sqlite3 is nonsupport") + return + } + // test1 + _, err := dORM.InsertOrUpdate(&user1, "user_name") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user1.Status, test.Status)) + } + // test2 + _, err = dORM.InsertOrUpdate(&user2, "user_name") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user2.Status, test.Status)) + throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) + } + + // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + if IsPostgres { + return + } + // test3 + + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user2.Status+1, test.Status)) + } + // test4 - + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status)) + } + // test5 * + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status)) + } + // test6 / + _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) + } +} diff --git a/orm/utils_test.go b/orm/utils_test.go new file mode 100644 index 00000000..7d94cada --- /dev/null +++ b/orm/utils_test.go @@ -0,0 +1,70 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" +) + +func TestCamelString(t *testing.T) { + snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} + + answer := make(map[string]string) + for i, v := range snake { + answer[v] = camel[i] + } + + for _, v := range snake { + res := camelString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeString(t *testing.T) { + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeStringWithAcronym(t *testing.T) { + camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeStringWithAcronym(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} diff --git a/pkg/LICENSE b/pkg/LICENSE new file mode 100644 index 00000000..5dbd4243 --- /dev/null +++ b/pkg/LICENSE @@ -0,0 +1,13 @@ +Copyright 2014 astaxie + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. \ No newline at end of file diff --git a/test.sh b/scripts/test.sh similarity index 94% rename from test.sh rename to scripts/test.sh index 78928fea..d626d24b 100644 --- a/test.sh +++ b/scripts/test.sh @@ -6,7 +6,7 @@ export ORM_DRIVER=mysql export TZ=UTC export ORM_SOURCE="beego:test@tcp(localhost:13306)/orm_test?charset=utf8" -go test ./... +go test ../... # clear all container docker-compose -f test_docker_compose.yaml down diff --git a/test_docker_compose.yaml b/scripts/test_docker_compose.yaml similarity index 100% rename from test_docker_compose.yaml rename to scripts/test_docker_compose.yaml From cfff0f3b46e18d1af7f7b4e9a2a9dad41e451f86 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Sat, 25 Jul 2020 00:00:34 +0800 Subject: [PATCH 039/207] fix memory leak of request context --- context/input.go | 32 ++++++++++++++++++++++++++++++-- context/output.go | 6 ++++++ router.go | 4 ++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/context/input.go b/context/input.go index 7b522c36..da066a86 100644 --- a/context/input.go +++ b/context/input.go @@ -323,8 +323,36 @@ func (input *BeegoInput) SetParam(key, val string) { // This function is used to clear parameters so they may be reset between filter // passes. func (input *BeegoInput) ResetParams() { - input.pnames = input.pnames[:0] - input.pvalues = input.pvalues[:0] + if len(input.pnames) > 0 { + input.pnames = input.pnames[:0] + } + if len(input.pvalues) > 0 { + input.pvalues = input.pvalues[:0] + } +} + +// ResetData: reset data +func (input *BeegoInput) ResetData() { + input.dataLock.Lock() + if input.data != nil { + input.data = nil + } + input.dataLock.Unlock() +} + +// ResetBody: reset body +func (input *BeegoInput) ResetBody() { + if len(input.RequestBody) > 0 { + input.RequestBody = []byte{} + } +} + +// Clear: clear all data in input +func (input *BeegoInput) Clear() { + input.ResetParams() + input.ResetData() + input.ResetBody() + } // Query returns input data item string by a given string. diff --git a/context/output.go b/context/output.go index 238dcf45..eaa75720 100644 --- a/context/output.go +++ b/context/output.go @@ -50,9 +50,15 @@ func NewOutput() *BeegoOutput { // Reset init BeegoOutput func (output *BeegoOutput) Reset(ctx *Context) { output.Context = ctx + output.Clear() +} + +// Clear: clear all data in output +func (output *BeegoOutput) Clear() { output.Status = 0 } + // Header sets response header item string via given key. func (output *BeegoOutput) Header(key, val string) { output.Context.ResponseWriter.Header().Set(key, val) diff --git a/router.go b/router.go index a993a1af..af0a7ceb 100644 --- a/router.go +++ b/router.go @@ -319,6 +319,10 @@ func (p *ControllerRegister) GetContext() *beecontext.Context { // GiveBackContext put the ctx into pool so that it could be reuse func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { + // clear input cached data + ctx.Input.Clear() + // clear output cached data + ctx.Output.Clear() p.pool.Put(ctx) } From 16d71893cdd4e914d740223af2bd983e28a1f96c Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Sun, 26 Jul 2020 19:40:13 +0800 Subject: [PATCH 040/207] orm.rawPrepare support FlatParams --- pkg/orm/orm_raw.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/orm/orm_raw.go b/pkg/orm/orm_raw.go index 5e05eded..2f214f93 100644 --- a/pkg/orm/orm_raw.go +++ b/pkg/orm/orm_raw.go @@ -32,7 +32,8 @@ func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) { if o.closed { return nil, ErrStmtClosed } - return o.stmt.Exec(args...) + flatParams := getFlatParams(nil, args, o.rs.orm.alias.TZ) + return o.stmt.Exec(flatParams...) } func (o *rawPrepare) Close() error { From 2386c9c80d1d38f44f23c134c748810a9f1add0a Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Sun, 26 Jul 2020 22:37:42 +0800 Subject: [PATCH 041/207] delete useless if-stmt --- context/input.go | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/context/input.go b/context/input.go index da066a86..fb01648f 100644 --- a/context/input.go +++ b/context/input.go @@ -323,28 +323,20 @@ func (input *BeegoInput) SetParam(key, val string) { // This function is used to clear parameters so they may be reset between filter // passes. func (input *BeegoInput) ResetParams() { - if len(input.pnames) > 0 { - input.pnames = input.pnames[:0] - } - if len(input.pvalues) > 0 { - input.pvalues = input.pvalues[:0] - } + input.pnames = input.pnames[:0] + input.pvalues = input.pvalues[:0] } // ResetData: reset data func (input *BeegoInput) ResetData() { input.dataLock.Lock() - if input.data != nil { - input.data = nil - } + input.data = nil input.dataLock.Unlock() } // ResetBody: reset body func (input *BeegoInput) ResetBody() { - if len(input.RequestBody) > 0 { - input.RequestBody = []byte{} - } + input.RequestBody = []byte{} } // Clear: clear all data in input From 2e7fb81348635b6cba4658d5813f321d5f6cb3ce Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 27 Jul 2020 21:19:34 +0800 Subject: [PATCH 042/207] deprecated orm.go and add NewOrmUsingDB method --- orm/db_alias.go | 23 ++++++++++++++++++++++- pkg/orm/db_alias_test.go | 10 ++++++++++ pkg/orm/orm.go | 10 ++++++---- 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/orm/db_alias.go b/orm/db_alias.go index fe6abeb5..d3dbc595 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -41,12 +41,14 @@ const ( type driver string // get type constant int of current driver.. +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (d driver) Type() DriverType { a, _ := dataBaseCache.get(string(d)) return a.Driver } // get name of current driver +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (d driver) Name() string { return string(d) } @@ -111,15 +113,19 @@ type DB struct { stmtDecorators *lru.Cache } +// Begin start a transaction +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (d *DB) Begin() (*sql.Tx, error) { return d.DB.Begin() } +// BeginTx start a transaction with context and those options +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { return d.DB.BeginTx(ctx, opts) } -//su must call release to release *sql.Stmt after using +// su must call release to release *sql.Stmt after using func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.RLock() c, ok := d.stmtDecorators.Get(query) @@ -151,14 +157,17 @@ func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { return sd, nil } +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (d *DB) Prepare(query string) (*sql.Stmt, error) { return d.DB.Prepare(query) } +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { return d.DB.PrepareContext(ctx, query) } +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { sd, err := d.getStmtDecorator(query) if err != nil { @@ -169,6 +178,7 @@ func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { return stmt.Exec(args...) } +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { sd, err := d.getStmtDecorator(query) if err != nil { @@ -179,6 +189,7 @@ func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) return stmt.ExecContext(ctx, args...) } +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { sd, err := d.getStmtDecorator(query) if err != nil { @@ -189,6 +200,7 @@ func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { return stmt.Query(args...) } +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { sd, err := d.getStmtDecorator(query) if err != nil { @@ -199,6 +211,7 @@ func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{} return stmt.QueryContext(ctx, args...) } +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { sd, err := d.getStmtDecorator(query) if err != nil { @@ -210,6 +223,7 @@ func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { } +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { sd, err := d.getStmtDecorator(query) if err != nil { @@ -319,12 +333,14 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { } // AddAliasWthDB add a aliasName for the drivename +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { _, err := addAliasWthDB(aliasName, driverName, db) return err } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { var ( err error @@ -368,6 +384,7 @@ end: } // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func RegisterDriver(driverName string, typ DriverType) error { if t, ok := drivers[driverName]; !ok { drivers[driverName] = typ @@ -380,6 +397,7 @@ func RegisterDriver(driverName string, typ DriverType) error { } // SetDataBaseTZ Change the database default used timezone +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func SetDataBaseTZ(aliasName string, tz *time.Location) error { if al, ok := dataBaseCache.get(aliasName); ok { al.TZ = tz @@ -390,6 +408,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { } // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func SetMaxIdleConns(aliasName string, maxIdleConns int) { al := getDbAlias(aliasName) al.MaxIdleConns = maxIdleConns @@ -397,6 +416,7 @@ func SetMaxIdleConns(aliasName string, maxIdleConns int) { } // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func SetMaxOpenConns(aliasName string, maxOpenConns int) { al := getDbAlias(aliasName) al.MaxOpenConns = maxOpenConns @@ -409,6 +429,7 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) { // GetDB Get *sql.DB from registered database by db alias name. // Use "default" as alias name if you not set. +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func GetDB(aliasNames ...string) (*sql.DB, error) { var name string if len(aliasNames) > 0 { diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index a0cdcd44..81b623c8 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -42,3 +42,13 @@ func TestRegisterDataBase(t *testing.T) { assert.Equal(t, al.MaxOpenConns, 300) assert.Equal(t, al.ConnMaxLifetime, time.Minute) } + +func TestDBCache(t *testing.T) { + dataBaseCache.add("test1", &alias{}) + dataBaseCache.add("default", &alias{}) + al := dataBaseCache.getDefault() + assert.NotNil(t, al) + al, ok := dataBaseCache.get("test1") + assert.NotNil(t, al) + assert.True(t, ok) +} diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 8ef761f4..07084577 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -591,10 +591,13 @@ func (t *txOrm) Rollback() error { // NewOrm create new orm func NewOrm() Ormer { BootStrap() // execute only once + return NewOrmUsingDB(`default`) +} +// NewOrm create new orm with the name +func NewOrmUsingDB(aliasName string) Ormer { o := new(orm) - name := `default` - if al, ok := dataBaseCache.get(name); ok { + if al, ok := dataBaseCache.get(aliasName); ok { o.alias = al if Debug { o.db = newDbQueryLog(al, al.DB) @@ -602,9 +605,8 @@ func NewOrm() Ormer { o.db = al.DB } } else { - panic(fmt.Errorf(" unknown db alias name `%s`", name)) + panic(fmt.Errorf(" unknown db alias name `%s`", aliasName)) } - return o } From 21f281655d25c9add1aac0bde05bed1b13e25b06 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 27 Jul 2020 21:22:40 +0800 Subject: [PATCH 043/207] remove QueryRelated and QueryRelatedCtx --- pkg/orm/orm.go | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 07084577..3b94ab6c 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -305,6 +305,7 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri func (o *ormBase) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { return o.LoadRelatedWithCtx(context.Background(), md, name, args...) } + func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { _, fi, ind, qseter := o.queryRelated(md, name) @@ -366,21 +367,6 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s return nums, err } -// return a QuerySeter for related models to md model. -// it can do all, update, delete in QuerySeter. -// example: -// qs := orm.QueryRelated(post,"Tag") -// qs.All(&[]*Tag{}) -// -func (o *ormBase) QueryRelated(md interface{}, name string) QuerySeter { - return o.QueryRelatedWithCtx(context.Background(), md, name) -} -func (o *ormBase) QueryRelatedWithCtx(ctx context.Context, md interface{}, name string) QuerySeter { - // is this api needed ? - _, _, _, qs := o.queryRelated(md, name) - return qs -} - // get QuerySeter for related models to md model func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { mi, ind := o.getMiInd(md, true) From 756df9385ff7d6dabed53110931d4d8d3f8e5d09 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 28 Jul 2020 12:57:19 +0800 Subject: [PATCH 044/207] make stmt cache size configurable --- pkg/orm/constant.go | 1 + pkg/orm/db_alias.go | 105 ++++++++++++++++++++++----------------- pkg/orm/db_alias_test.go | 53 ++++++++++++++++++++ pkg/orm/orm.go | 35 +++++++++++-- 4 files changed, 144 insertions(+), 50 deletions(-) diff --git a/pkg/orm/constant.go b/pkg/orm/constant.go index 14f40a7b..54550492 100644 --- a/pkg/orm/constant.go +++ b/pkg/orm/constant.go @@ -18,4 +18,5 @@ const ( MaxIdleConnsKey = "MaxIdleConns" MaxOpenConnsKey = "MaxOpenConns" ConnMaxLifetimeKey = "ConnMaxLifetime" + MaxStmtCacheSize = "MaxStmtCacheSize" ) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index a3f2a0b9..a9961649 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -109,8 +109,9 @@ func (ac *_dbCache) getDefault() (al *alias) { type DB struct { *sync.RWMutex - DB *sql.DB - stmtDecorators *lru.Cache + DB *sql.DB + stmtDecorators *lru.Cache + stmtDecoratorsLimit int } var _ dbQuerier = new(DB) @@ -165,16 +166,14 @@ func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error } func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.Exec(args...) + return d.ExecContext(context.Background(), query, args...) } func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + if d.stmtDecorators == nil { + return d.DB.ExecContext(ctx, query, args...) + } + sd, err := d.getStmtDecorator(query) if err != nil { return nil, err @@ -185,16 +184,14 @@ func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) } func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.Query(args...) + return d.QueryContext(context.Background(), query, args...) } func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + if d.stmtDecorators == nil { + return d.DB.QueryContext(ctx, query, args...) + } + sd, err := d.getStmtDecorator(query) if err != nil { return nil, err @@ -205,24 +202,21 @@ func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{} } func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { - sd, err := d.getStmtDecorator(query) - if err != nil { - panic(err) - } - stmt := sd.getStmt() - defer sd.release() - return stmt.QueryRow(args...) - + return d.QueryRowContext(context.Background(), query, args...) } func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + if d.stmtDecorators == nil { + return d.DB.QueryRowContext(ctx, query, args...) + } + sd, err := d.getStmtDecorator(query) if err != nil { panic(err) } stmt := sd.getStmt() defer sd.release() - return stmt.QueryRowContext(ctx, args) + return stmt.QueryRowContext(ctx, args...) } type TxDB struct { @@ -345,14 +339,31 @@ func detectTZ(al *alias) { } } -func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { +func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) (*alias, error) { + kvs := common.NewKVs(params...) + + var stmtCache *lru.Cache + var stmtCacheSize int + + maxStmtCacheSize := kvs.GetValueOr(MaxStmtCacheSize, 0).(int) + if maxStmtCacheSize > 0 { + _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) + if errC != nil { + return nil, errC + } else { + stmtCache = _stmtCache + stmtCacheSize = maxStmtCacheSize + } + } + al := new(alias) al.Name = aliasName al.DriverName = driverName al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: stmtCache, + stmtDecoratorsLimit: stmtCacheSize, } if dr, ok := drivers[driverName]; ok { @@ -371,12 +382,22 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) } + detectTZ(al) + + kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { + SetMaxIdleConns(al.Name, value.(int)) + }).IfContains(MaxOpenConnsKey, func(value interface{}) { + SetMaxOpenConns(al.Name, value.(int)) + }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { + SetConnMaxLifetime(al.Name, value.(time.Duration)) + }) + return al, nil } // AddAliasWthDB add a aliasName for the drivename -func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { - _, err := addAliasWthDB(aliasName, driverName, db) +func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) error { + _, err := addAliasWthDB(aliasName, driverName, db, params...) return err } @@ -388,7 +409,6 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...common al *alias ) - kvs := common.NewKVs(params...) db, err = sql.Open(driverName, dataSource) if err != nil { @@ -396,23 +416,13 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...common goto end } - al, err = addAliasWthDB(aliasName, driverName, db) + al, err = addAliasWthDB(aliasName, driverName, db, params...) if err != nil { goto end } al.DataSource = dataSource - detectTZ(al) - - kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { - SetMaxIdleConns(al.Name, value.(int)) - }).IfContains(MaxOpenConnsKey, func(value interface{}) { - SetMaxOpenConns(al.Name, value.(int)) - }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { - SetConnMaxLifetime(al.Name, value.(time.Duration)) - }) - end: if err != nil { if db != nil { @@ -517,9 +527,12 @@ func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { } } -func newStmtDecoratorLruWithEvict() *lru.Cache { - cache, _ := lru.NewWithEvict(1000, func(key interface{}, value interface{}) { +func newStmtDecoratorLruWithEvict(cacheSize int) (*lru.Cache, error) { + cache, err := lru.NewWithEvict(cacheSize, func(key interface{}, value interface{}) { value.(*stmtDecorator).destroy() }) - return cache + if err != nil { + return nil, err + } + return cache, nil } diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index a0cdcd44..85cdd82f 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -42,3 +42,56 @@ func TestRegisterDataBase(t *testing.T) { assert.Equal(t, al.MaxOpenConns, 300) assert.Equal(t, al.ConnMaxLifetime, time.Minute) } + +func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxStmtCacheSize, + Value: -1, + }) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 0) +} + +func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxStmtCacheSize, + Value: 0, + }) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 0) +} + +func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxStmtCacheSize, + Value: 1, + }) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 1) +} + +func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxStmtCacheSize, + Value: 841, + }) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 841) +} + diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 8ef761f4..441fcfc0 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -58,6 +58,8 @@ import ( "database/sql" "errors" "fmt" + "github.com/astaxie/beego/pkg/common" + lru "github.com/hashicorp/golang-lru" "os" "reflect" "sync" @@ -609,7 +611,7 @@ func NewOrm() Ormer { } // NewOrmWithDB create a new ormer object with specify *sql.DB for query -func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { +func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...common.KV) (Ormer, error) { var al *alias if dr, ok := drivers[driverName]; ok { @@ -620,16 +622,41 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { return nil, fmt.Errorf("driver name `%s` have not registered", driverName) } + kvs := common.NewKVs(params...) + + var stmtCache *lru.Cache + var stmtCacheSize int + + maxStmtCacheSize := kvs.GetValueOr(MaxStmtCacheSize, 0).(int) + if maxStmtCacheSize > 0 { + _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) + if errC != nil { + return nil, errC + } else { + stmtCache = _stmtCache + stmtCacheSize = maxStmtCacheSize + } + } + al.Name = aliasName al.DriverName = driverName al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: stmtCache, + stmtDecoratorsLimit: stmtCacheSize, } detectTZ(al) + kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { + SetMaxIdleConns(al.Name, value.(int)) + }).IfContains(MaxOpenConnsKey, func(value interface{}) { + SetMaxOpenConns(al.Name, value.(int)) + }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { + SetConnMaxLifetime(al.Name, value.(time.Duration)) + }) + o := new(orm) o.alias = al From 54ef4766002facfb8342695aafc44857a0a3652b Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 28 Jul 2020 06:28:51 +0000 Subject: [PATCH 045/207] add tag interfaces and remove log.go --- log.go | 127 ------------------------------------ orm/orm.go | 23 +++++++ orm/types.go | 1 + pkg/log.go | 127 ------------------------------------ pkg/orm/model_utils_test.go | 62 ++++++++++++++++++ pkg/orm/models_info_m.go | 2 +- pkg/orm/orm_test.go | 48 +++++--------- pkg/orm/types.go | 55 +++++++++++++++- scripts/test.sh | 6 +- 9 files changed, 160 insertions(+), 291 deletions(-) delete mode 100644 log.go delete mode 100644 pkg/log.go create mode 100644 pkg/orm/model_utils_test.go diff --git a/log.go b/log.go deleted file mode 100644 index cc4c0f81..00000000 --- a/log.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "strings" - - "github.com/astaxie/beego/logs" -) - -// Log levels to control the logging output. -// Deprecated: use github.com/astaxie/beego/logs instead. -const ( - LevelEmergency = iota - LevelAlert - LevelCritical - LevelError - LevelWarning - LevelNotice - LevelInformational - LevelDebug -) - -// BeeLogger references the used application logger. -// Deprecated: use github.com/astaxie/beego/logs instead. -var BeeLogger = logs.GetBeeLogger() - -// SetLevel sets the global log level used by the simple logger. -// Deprecated: use github.com/astaxie/beego/logs instead. -func SetLevel(l int) { - logs.SetLevel(l) -} - -// SetLogFuncCall set the CallDepth, default is 3 -// Deprecated: use github.com/astaxie/beego/logs instead. -func SetLogFuncCall(b bool) { - logs.SetLogFuncCall(b) -} - -// SetLogger sets a new logger. -// Deprecated: use github.com/astaxie/beego/logs instead. -func SetLogger(adaptername string, config string) error { - return logs.SetLogger(adaptername, config) -} - -// Emergency logs a message at emergency level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Emergency(v ...interface{}) { - logs.Emergency(generateFmtStr(len(v)), v...) -} - -// Alert logs a message at alert level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Alert(v ...interface{}) { - logs.Alert(generateFmtStr(len(v)), v...) -} - -// Critical logs a message at critical level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Critical(v ...interface{}) { - logs.Critical(generateFmtStr(len(v)), v...) -} - -// Error logs a message at error level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Error(v ...interface{}) { - logs.Error(generateFmtStr(len(v)), v...) -} - -// Warning logs a message at warning level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Warning(v ...interface{}) { - logs.Warning(generateFmtStr(len(v)), v...) -} - -// Warn compatibility alias for Warning() -// Deprecated: use github.com/astaxie/beego/logs instead. -func Warn(v ...interface{}) { - logs.Warn(generateFmtStr(len(v)), v...) -} - -// Notice logs a message at notice level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Notice(v ...interface{}) { - logs.Notice(generateFmtStr(len(v)), v...) -} - -// Informational logs a message at info level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Informational(v ...interface{}) { - logs.Informational(generateFmtStr(len(v)), v...) -} - -// Info compatibility alias for Warning() -// Deprecated: use github.com/astaxie/beego/logs instead. -func Info(v ...interface{}) { - logs.Info(generateFmtStr(len(v)), v...) -} - -// Debug logs a message at debug level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Debug(v ...interface{}) { - logs.Debug(generateFmtStr(len(v)), v...) -} - -// Trace logs a message at trace level. -// compatibility alias for Warning() -// Deprecated: use github.com/astaxie/beego/logs instead. -func Trace(v ...interface{}) { - logs.Trace(generateFmtStr(len(v)), v...) -} - -func generateFmtStr(n int) string { - return strings.Repeat("%v ", n) -} diff --git a/orm/orm.go b/orm/orm.go index 0551b1cd..c7566b9a 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -124,18 +124,21 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { } // read data to model +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) Read(md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) } // read data to model, like Read(), but use "SELECT FOR UPDATE" form +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) } // Try to read a row from the database, or insert one if it doesn't exist +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { cols = append([]string{col1}, cols...) mi, ind := o.getMiInd(md, true) @@ -159,6 +162,7 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i } // insert model data to database +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) Insert(md interface{}) (int64, error) { mi, ind := o.getMiInd(md, true) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) @@ -183,6 +187,7 @@ func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { } // insert some models to database +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { var cnt int64 @@ -218,6 +223,7 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { } // InsertOrUpdate data to database +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { mi, ind := o.getMiInd(md, true) id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) @@ -232,6 +238,7 @@ func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64 // update model to database. // cols set the columns those want to update. +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) Update(md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) @@ -239,6 +246,7 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) { // delete model in database // cols shows the delete conditions values read from. default is pk +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) @@ -252,6 +260,7 @@ func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { } // create a models to models queryer +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -274,6 +283,7 @@ func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { // for _,tag := range post.Tags{...} // // make sure the relation is defined in model struct tags. +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { _, fi, ind, qseter := o.queryRelated(md, name) @@ -341,6 +351,7 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int // qs := orm.QueryRelated(post,"Tag") // qs.All(&[]*Tag{}) // +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { // is this api needed ? _, _, _, qs := o.queryRelated(md, name) @@ -423,6 +434,7 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { // return a QuerySeter for table operations. // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { var name string if table, ok := ptrStructOrTableName.(string); ok { @@ -443,6 +455,8 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { } // switch to another registered database driver by given name. +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 +// Using NewOrmUsingDB(name) func (o *orm) Using(name string) error { if o.isTx { panic(fmt.Errorf(" transaction has been start, cannot change db")) @@ -461,10 +475,12 @@ func (o *orm) Using(name string) error { } // begin transaction +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) Begin() error { return o.BeginTx(context.Background(), nil) } +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error { if o.isTx { return ErrTxHasBegan @@ -484,6 +500,7 @@ func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error { } // commit transaction +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) Commit() error { if !o.isTx { return ErrTxDone @@ -499,6 +516,7 @@ func (o *orm) Commit() error { } // rollback transaction +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) Rollback() error { if !o.isTx { return ErrTxDone @@ -514,16 +532,19 @@ func (o *orm) Rollback() error { } // return a raw query seter for raw sql string. +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) Raw(query string, args ...interface{}) RawSeter { return newRawSet(o, query, args) } // return current using database Driver +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) Driver() Driver { return driver(o.alias.Name) } // return sql.DBStats for current database +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func (o *orm) DBStats() *sql.DBStats { if o.alias != nil && o.alias.DB != nil { stats := o.alias.DB.DB.Stats() @@ -533,6 +554,7 @@ func (o *orm) DBStats() *sql.DBStats { } // NewOrm create new orm +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func NewOrm() Ormer { BootStrap() // execute only once @@ -545,6 +567,7 @@ func NewOrm() Ormer { } // NewOrmWithDB create a new ormer object with specify *sql.DB for query +// Deprecated: using pkg/orm. We will remove this method in v2.1.0 func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { var al *alias diff --git a/orm/types.go b/orm/types.go index 2fd10774..75af7149 100644 --- a/orm/types.go +++ b/orm/types.go @@ -22,6 +22,7 @@ import ( ) // Driver define database driver + type Driver interface { Name() string Type() DriverType diff --git a/pkg/log.go b/pkg/log.go deleted file mode 100644 index cc4c0f81..00000000 --- a/pkg/log.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "strings" - - "github.com/astaxie/beego/logs" -) - -// Log levels to control the logging output. -// Deprecated: use github.com/astaxie/beego/logs instead. -const ( - LevelEmergency = iota - LevelAlert - LevelCritical - LevelError - LevelWarning - LevelNotice - LevelInformational - LevelDebug -) - -// BeeLogger references the used application logger. -// Deprecated: use github.com/astaxie/beego/logs instead. -var BeeLogger = logs.GetBeeLogger() - -// SetLevel sets the global log level used by the simple logger. -// Deprecated: use github.com/astaxie/beego/logs instead. -func SetLevel(l int) { - logs.SetLevel(l) -} - -// SetLogFuncCall set the CallDepth, default is 3 -// Deprecated: use github.com/astaxie/beego/logs instead. -func SetLogFuncCall(b bool) { - logs.SetLogFuncCall(b) -} - -// SetLogger sets a new logger. -// Deprecated: use github.com/astaxie/beego/logs instead. -func SetLogger(adaptername string, config string) error { - return logs.SetLogger(adaptername, config) -} - -// Emergency logs a message at emergency level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Emergency(v ...interface{}) { - logs.Emergency(generateFmtStr(len(v)), v...) -} - -// Alert logs a message at alert level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Alert(v ...interface{}) { - logs.Alert(generateFmtStr(len(v)), v...) -} - -// Critical logs a message at critical level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Critical(v ...interface{}) { - logs.Critical(generateFmtStr(len(v)), v...) -} - -// Error logs a message at error level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Error(v ...interface{}) { - logs.Error(generateFmtStr(len(v)), v...) -} - -// Warning logs a message at warning level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Warning(v ...interface{}) { - logs.Warning(generateFmtStr(len(v)), v...) -} - -// Warn compatibility alias for Warning() -// Deprecated: use github.com/astaxie/beego/logs instead. -func Warn(v ...interface{}) { - logs.Warn(generateFmtStr(len(v)), v...) -} - -// Notice logs a message at notice level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Notice(v ...interface{}) { - logs.Notice(generateFmtStr(len(v)), v...) -} - -// Informational logs a message at info level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Informational(v ...interface{}) { - logs.Informational(generateFmtStr(len(v)), v...) -} - -// Info compatibility alias for Warning() -// Deprecated: use github.com/astaxie/beego/logs instead. -func Info(v ...interface{}) { - logs.Info(generateFmtStr(len(v)), v...) -} - -// Debug logs a message at debug level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Debug(v ...interface{}) { - logs.Debug(generateFmtStr(len(v)), v...) -} - -// Trace logs a message at trace level. -// compatibility alias for Warning() -// Deprecated: use github.com/astaxie/beego/logs instead. -func Trace(v ...interface{}) { - logs.Trace(generateFmtStr(len(v)), v...) -} - -func generateFmtStr(n int) string { - return strings.Repeat("%v ", n) -} diff --git a/pkg/orm/model_utils_test.go b/pkg/orm/model_utils_test.go new file mode 100644 index 00000000..ea38d90a --- /dev/null +++ b/pkg/orm/model_utils_test.go @@ -0,0 +1,62 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type Interface struct { + Id int + Name string + + Index1 string + Index2 string + + Unique1 string + Unique2 string +} + +func (i *Interface) TableIndex() [][]string { + return [][]string{{"index1"}, {"index2"}} +} + +func (i *Interface) TableUnique() [][]string { + return [][]string{{"unique1"}, {"unique2"}} +} + +func (i *Interface) TableName() string { + return "INTERFACE_" +} + +func (i *Interface) TableEngine() string { + return "innodb" +} + +func TestDbBase_GetTables(t *testing.T) { + RegisterModel(&Interface{}) + mi, ok := modelCache.get("INTERFACE_") + assert.True(t, ok) + assert.NotNil(t, mi) + + engine := getTableEngine(mi.addrField) + assert.Equal(t, "innodb", engine) + uniques := getTableUnique(mi.addrField) + assert.Equal(t, [][]string{{"unique1"}, {"unique2"}}, uniques) + indexes := getTableIndex(mi.addrField) + assert.Equal(t, [][]string{{"index1"}, {"index2"}}, indexes) +} diff --git a/pkg/orm/models_info_m.go b/pkg/orm/models_info_m.go index a4d733b6..c450239c 100644 --- a/pkg/orm/models_info_m.go +++ b/pkg/orm/models_info_m.go @@ -29,7 +29,7 @@ type modelInfo struct { model interface{} fields *fields manual bool - addrField reflect.Value //store the original struct value + addrField reflect.Value // store the original struct value uniques []string isThrough bool } diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index 54ecc0fd..e3dafecd 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -297,16 +297,13 @@ func TestDataTypes(t *testing.T) { vu := e.Interface() switch name { case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + assert.True(t, vu.(time.Time).In(DefaultTimeLoc).Sub(value.(time.Time).In(DefaultTimeLoc)) <= time.Second) + break + default: + assert.Equal(t, value, vu) } - throwFail(t, AssertIs(vu == value, true), value, vu) } } @@ -1662,18 +1659,14 @@ func TestRawQueryRow(t *testing.T) { switch col { case "id": throwFail(t, AssertIs(id, 1)) + break case "time": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testTime)) case "date": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testDate)) case "datetime": v = v.(time.Time).In(DefaultTimeLoc) value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testDateTime)) + assert.True(t, v.(time.Time).Sub(value) <= time.Second) + break default: throwFail(t, AssertIs(v, dataValues[col])) } @@ -1746,16 +1739,13 @@ func TestQueryRows(t *testing.T) { vu := e.Interface() switch name { case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + assert.True(t, vu.(time.Time).In(DefaultTimeLoc).Sub(value.(time.Time).In(DefaultTimeLoc)) <= time.Second) + break + default: + assert.Equal(t, value, vu) } - throwFail(t, AssertIs(vu == value, true), value, vu) } var datas2 []Data @@ -1773,16 +1763,14 @@ func TestQueryRows(t *testing.T) { vu := e.Interface() switch name { case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + assert.True(t, vu.(time.Time).In(DefaultTimeLoc).Sub(value.(time.Time).In(DefaultTimeLoc)) <= time.Second) + break + default: + assert.Equal(t, value, vu) } - throwFail(t, AssertIs(vu == value, true), value, vu) + } var ids []int @@ -2193,8 +2181,8 @@ func TestInLine(t *testing.T) { throwFail(t, AssertIs(il.Name, name)) throwFail(t, AssertIs(il.Email, email)) - throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) + assert.True(t, il.Created.In(DefaultTimeLoc).Sub(inline.Created.In(DefaultTimeLoc)) <= time.Second) + assert.True(t, il.Updated.In(DefaultTimeLoc).Sub(inline.Updated.In(DefaultTimeLoc)) <= time.Second) } func TestInLineOneToOne(t *testing.T) { diff --git a/pkg/orm/types.go b/pkg/orm/types.go index 8255d93e..cb0f97cc 100644 --- a/pkg/orm/types.go +++ b/pkg/orm/types.go @@ -21,6 +21,58 @@ import ( "time" ) +// TableNaming is usually used by model +// when you custom your table name, please implement this interfaces +// for example: +// type User struct { +// ... +// } +// func (u *User) TableName() string { +// return "USER_TABLE" +// } +type TableNameI interface { + TableName() string +} + +// TableEngineI is usually used by model +// when you want to use specific engine, like myisam, you can implement this interface +// for example: +// type User struct { +// ... +// } +// func (u *User) TableEngine() string { +// return "myisam" +// } +type TableEngineI interface { + TableEngine() string +} + +// TableIndexI is usually used by model +// when you want to create indexes, you can implement this interface +// for example: +// type User struct { +// ... +// } +// func (u *User) TableIndex() [][]string { +// return [][]string{{"Name"}} +// } +type TableIndexI interface { + TableIndex() [][]string +} + +// TableUniqueI is usually used by model +// when you want to create unique indexes, you can implement this interface +// for example: +// type User struct { +// ... +// } +// func (u *User) TableUnique() [][]string { +// return [][]string{{"Email"}} +// } +type TableUniqueI interface { + TableUnique() [][]string +} + // Driver define database driver type Driver interface { Name() string @@ -145,9 +197,6 @@ type DQL interface { QueryTable(ptrStructOrTableName interface{}) QuerySeter QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter - // switch to another registered database driver by given name. - // Using(name string) error - DBStats() *sql.DBStats } diff --git a/scripts/test.sh b/scripts/test.sh index d626d24b..473a7066 100644 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -1,14 +1,14 @@ #!/bin/bash -docker-compose -f test_docker_compose.yaml up -d +docker-compose -f "$(pwd)/scripts/test_docker_compose.yaml" up -d export ORM_DRIVER=mysql export TZ=UTC export ORM_SOURCE="beego:test@tcp(localhost:13306)/orm_test?charset=utf8" -go test ../... +go test "$(pwd)/..." # clear all container -docker-compose -f test_docker_compose.yaml down +docker-compose -f "$(pwd)/scripts/test_docker_compose.yaml" down From e8facd28f544bb4edbe2a533ca62e6878e025958 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 28 Jul 2020 17:37:36 +0800 Subject: [PATCH 046/207] wrap kv --- pkg/common/kv.go | 23 +++++++++++--- pkg/common/kv_test.go | 2 +- pkg/orm/constant.go | 21 ------------ pkg/orm/db_alias.go | 11 ++++--- pkg/orm/db_alias_test.go | 16 +++------- pkg/orm/db_hints.go | 68 +++++++++++++++++++++++++++++++++++++++ pkg/orm/db_hints_test.go | 69 ++++++++++++++++++++++++++++++++++++++++ pkg/orm/models_test.go | 6 +--- 8 files changed, 168 insertions(+), 48 deletions(-) delete mode 100644 pkg/orm/constant.go create mode 100644 pkg/orm/db_hints.go create mode 100644 pkg/orm/db_hints_test.go diff --git a/pkg/common/kv.go b/pkg/common/kv.go index 508e6b5c..86a50132 100644 --- a/pkg/common/kv.go +++ b/pkg/common/kv.go @@ -14,14 +14,29 @@ package common -// KV is common structure to store key-value data. +type KV interface { + GetKey() interface{} + GetValue() interface{} +} + +// SimpleKV is common structure to store key-value data. // when you need something like Pair, you can use this -type KV struct { +type SimpleKV struct { Key interface{} Value interface{} } -// KVs will store KV collection as map +var _ KV = new(SimpleKV) + +func (s *SimpleKV) GetKey() interface{} { + return s.Key +} + +func (s *SimpleKV) GetValue() interface{} { + return s.Value +} + +// KVs will store SimpleKV collection as map type KVs struct { kvs map[interface{}]interface{} } @@ -63,7 +78,7 @@ func NewKVs(kvs ...KV) *KVs { kvs: make(map[interface{}]interface{}, len(kvs)), } for _, kv := range kvs { - res.kvs[kv.Key] = kv.Value + res.kvs[kv.GetKey()] = kv.GetValue() } return res } diff --git a/pkg/common/kv_test.go b/pkg/common/kv_test.go index 45adf5ff..275c6753 100644 --- a/pkg/common/kv_test.go +++ b/pkg/common/kv_test.go @@ -22,7 +22,7 @@ import ( func TestKVs(t *testing.T) { key := "my-key" - kvs := NewKVs(KV{ + kvs := NewKVs(&SimpleKV{ Key: key, Value: 12, }) diff --git a/pkg/orm/constant.go b/pkg/orm/constant.go deleted file mode 100644 index 14f40a7b..00000000 --- a/pkg/orm/constant.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2020 beego-dev -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -const ( - MaxIdleConnsKey = "MaxIdleConns" - MaxOpenConnsKey = "MaxOpenConns" - ConnMaxLifetimeKey = "ConnMaxLifetime" -) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index a3f2a0b9..53c668ae 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" "sync" "time" @@ -381,14 +382,14 @@ func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -func RegisterDataBase(aliasName, driverName, dataSource string, params ...common.KV) error { +func RegisterDataBase(aliasName, driverName, dataSource string, hints ...common.KV) error { var ( err error db *sql.DB al *alias ) - kvs := common.NewKVs(params...) + kvs := common.NewKVs(hints...) db, err = sql.Open(driverName, dataSource) if err != nil { @@ -405,11 +406,11 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...common detectTZ(al) - kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { + kvs.IfContains(maxIdleConnectionsKey, func(value interface{}) { SetMaxIdleConns(al.Name, value.(int)) - }).IfContains(MaxOpenConnsKey, func(value interface{}) { + }).IfContains(maxOpenConnectionsKey, func(value interface{}) { SetMaxOpenConns(al.Name, value.(int)) - }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { + }).IfContains(connMaxLifetimeKey, func(value interface{}) { SetConnMaxLifetime(al.Name, value.(time.Duration)) }) diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index a0cdcd44..28dd2e6b 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -19,21 +19,13 @@ import ( "time" "github.com/stretchr/testify/assert" - - "github.com/astaxie/beego/pkg/common" ) func TestRegisterDataBase(t *testing.T) { - err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxIdleConnsKey, - Value: 20, - }, common.KV{ - Key: MaxOpenConnsKey, - Value: 300, - }, common.KV{ - Key: ConnMaxLifetimeKey, - Value: time.Minute, - }) + err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, + MaxIdleConnections(20), + MaxOpenConnections(300), + ConnMaxLifetime(time.Minute)) assert.Nil(t, err) al := getDbAlias("test-params") diff --git a/pkg/orm/db_hints.go b/pkg/orm/db_hints.go new file mode 100644 index 00000000..8900d599 --- /dev/null +++ b/pkg/orm/db_hints.go @@ -0,0 +1,68 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/common" + "time" +) + +type Hint struct { + key interface{} + value interface{} +} + +var _ common.KV = new(Hint) + +// GetKey return key +func (s *Hint) GetKey() interface{} { + return s.key +} + +// GetValue return value +func (s *Hint) GetValue() interface{} { + return s.value +} + +const ( + maxIdleConnectionsKey = "MaxIdleConnections" + maxOpenConnectionsKey = "MaxOpenConnections" + connMaxLifetimeKey = "ConnMaxLifetime" +) + +var _ common.KV = new(Hint) + +// MaxIdleConnections return a hint about MaxIdleConnections +func MaxIdleConnections(v int) *Hint { + return NewHint(maxIdleConnectionsKey, v) +} + +// MaxOpenConnections return a hint about MaxOpenConnections +func MaxOpenConnections(v int) *Hint { + return NewHint(maxOpenConnectionsKey, v) +} + +// ConnMaxLifetime return a hint about ConnMaxLifetime +func ConnMaxLifetime(v time.Duration) *Hint { + return NewHint(connMaxLifetimeKey, v) +} + +// NewHint return a hint +func NewHint(key interface{}, value interface{}) *Hint { + return &Hint{ + key: key, + value: value, + } +} diff --git a/pkg/orm/db_hints_test.go b/pkg/orm/db_hints_test.go new file mode 100644 index 00000000..9b62a730 --- /dev/null +++ b/pkg/orm/db_hints_test.go @@ -0,0 +1,69 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestNewHint_time(t *testing.T) { + key := "qweqwe" + value := time.Second + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestNewHint_int(t *testing.T) { + key := "qweqwe" + value := 281230 + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestNewHint_float(t *testing.T) { + key := "qweqwe" + value := 21.2459753 + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestMaxOpenConnections(t *testing.T) { + i := 887423 + hint := MaxOpenConnections(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), maxOpenConnectionsKey) +} + +func TestConnMaxLifetime(t *testing.T) { + i := time.Hour + hint := ConnMaxLifetime(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), connMaxLifetimeKey) +} + +func TestMaxIdleConnections(t *testing.T) { + i := 42316 + hint := MaxIdleConnections(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), maxIdleConnectionsKey) +} diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index 4c00050d..ae166dc7 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -28,7 +28,6 @@ import ( // As tidb can't use go get, so disable the tidb testing now // _ "github.com/pingcap/tidb" - "github.com/astaxie/beego/pkg/common" ) // A slice string field. @@ -489,10 +488,7 @@ func init() { os.Exit(2) } - err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxIdleConnsKey, - Value: 20, - }) + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, MaxIdleConnections(20)) if err != nil { panic(fmt.Sprintf("can not register database: %v", err)) From e87de70c6deb0eb13f5b48596511240da9e2551e Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Wed, 29 Jul 2020 00:45:41 +0800 Subject: [PATCH 047/207] adapt wrapping kv --- pkg/orm/db_alias.go | 52 ++++++++++++++++++++++++++-------------- pkg/orm/db_alias_test.go | 20 ++++------------ pkg/orm/db_hints.go | 6 +++++ pkg/orm/db_hints_test.go | 7 ++++++ pkg/orm/orm.go | 48 +++---------------------------------- 5 files changed, 54 insertions(+), 79 deletions(-) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index 336ec54b..5f1e3ea3 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -18,7 +18,6 @@ import ( "context" "database/sql" "fmt" - "github.com/astaxie/beego/pkg/orm/hints" "sync" "time" @@ -341,12 +340,30 @@ func detectTZ(al *alias) { } func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) (*alias, error) { + existErr := fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) + if _, ok := dataBaseCache.get(aliasName); ok { + return nil, existErr + } + + al, err := newAliasWithDb(aliasName, driverName, db, params...) + if err != nil { + return nil, err + } + + if !dataBaseCache.add(aliasName, al) { + return nil, existErr + } + + return al, nil +} + +func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.KV)(*alias, error){ kvs := common.NewKVs(params...) var stmtCache *lru.Cache var stmtCacheSize int - maxStmtCacheSize := kvs.GetValueOr(MaxStmtCacheSize, 0).(int) + maxStmtCacheSize := kvs.GetValueOr(maxStmtCacheSizeKey, 0).(int) if maxStmtCacheSize > 0 { _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) if errC != nil { @@ -379,18 +396,20 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) } - if !dataBaseCache.add(aliasName, al) { - return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) - } - detectTZ(al) - kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { - SetMaxIdleConns(al.Name, value.(int)) - }).IfContains(MaxOpenConnsKey, func(value interface{}) { - SetMaxOpenConns(al.Name, value.(int)) - }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { - SetConnMaxLifetime(al.Name, value.(time.Duration)) + kvs.IfContains(maxIdleConnectionsKey, func(value interface{}) { + if m, ok := value.(int); ok { + SetMaxIdleConns(al, m) + } + }).IfContains(maxOpenConnectionsKey, func(value interface{}) { + if m, ok := value.(int); ok { + SetMaxOpenConns(al, m) + } + }).IfContains(connMaxLifetimeKey, func(value interface{}) { + if m, ok := value.(time.Duration); ok { + SetConnMaxLifetime(al, m) + } }) return al, nil @@ -458,21 +477,18 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { } // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name -func SetMaxIdleConns(aliasName string, maxIdleConns int) { - al := getDbAlias(aliasName) +func SetMaxIdleConns(al *alias, maxIdleConns int) { al.MaxIdleConns = maxIdleConns al.DB.DB.SetMaxIdleConns(maxIdleConns) } // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name -func SetMaxOpenConns(aliasName string, maxOpenConns int) { - al := getDbAlias(aliasName) +func SetMaxOpenConns(al *alias, maxOpenConns int) { al.MaxOpenConns = maxOpenConns al.DB.DB.SetMaxOpenConns(maxOpenConns) } -func SetConnMaxLifetime(aliasName string, lifeTime time.Duration) { - al := getDbAlias(aliasName) +func SetConnMaxLifetime(al *alias, lifeTime time.Duration) { al.ConnMaxLifetime = lifeTime al.DB.DB.SetConnMaxLifetime(lifeTime) } diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index c8b4aad1..ebf93a86 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -37,10 +37,7 @@ func TestRegisterDataBase(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxStmtCacheSize, - Value: -1, - }) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(-1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -50,10 +47,7 @@ func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxStmtCacheSize, - Value: 0, - }) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(0)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -63,10 +57,7 @@ func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxStmtCacheSize, - Value: 1, - }) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -76,10 +67,7 @@ func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxStmtCacheSize, - Value: 841, - }) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(841)) assert.Nil(t, err) al := getDbAlias(aliasName) diff --git a/pkg/orm/db_hints.go b/pkg/orm/db_hints.go index 8900d599..551c7357 100644 --- a/pkg/orm/db_hints.go +++ b/pkg/orm/db_hints.go @@ -40,6 +40,7 @@ const ( maxIdleConnectionsKey = "MaxIdleConnections" maxOpenConnectionsKey = "MaxOpenConnections" connMaxLifetimeKey = "ConnMaxLifetime" + maxStmtCacheSizeKey = "MaxStmtCacheSize" ) var _ common.KV = new(Hint) @@ -59,6 +60,11 @@ func ConnMaxLifetime(v time.Duration) *Hint { return NewHint(connMaxLifetimeKey, v) } +// MaxStmtCacheSize return a hint about MaxStmtCacheSize +func MaxStmtCacheSize(v int) *Hint { + return NewHint(maxStmtCacheSizeKey, v) +} + // NewHint return a hint func NewHint(key interface{}, value interface{}) *Hint { return &Hint{ diff --git a/pkg/orm/db_hints_test.go b/pkg/orm/db_hints_test.go index 9b62a730..13f8ccde 100644 --- a/pkg/orm/db_hints_test.go +++ b/pkg/orm/db_hints_test.go @@ -67,3 +67,10 @@ func TestMaxIdleConnections(t *testing.T) { assert.Equal(t, hint.GetValue(), i) assert.Equal(t, hint.GetKey(), maxIdleConnectionsKey) } + +func TestMaxStmtCacheSize(t *testing.T) { + i := 94157 + hint := MaxStmtCacheSize(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), maxStmtCacheSizeKey) +} \ No newline at end of file diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 441fcfc0..b2f1e693 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -59,10 +59,8 @@ import ( "errors" "fmt" "github.com/astaxie/beego/pkg/common" - lru "github.com/hashicorp/golang-lru" "os" "reflect" - "sync" "time" "github.com/astaxie/beego/logs" @@ -612,51 +610,11 @@ func NewOrm() Ormer { // NewOrmWithDB create a new ormer object with specify *sql.DB for query func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...common.KV) (Ormer, error) { - var al *alias - - if dr, ok := drivers[driverName]; ok { - al = new(alias) - al.DbBaser = dbBasers[dr] - al.Driver = dr - } else { - return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + al, err := newAliasWithDb(aliasName, driverName, db, params...) + if err != nil { + return nil, err } - kvs := common.NewKVs(params...) - - var stmtCache *lru.Cache - var stmtCacheSize int - - maxStmtCacheSize := kvs.GetValueOr(MaxStmtCacheSize, 0).(int) - if maxStmtCacheSize > 0 { - _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) - if errC != nil { - return nil, errC - } else { - stmtCache = _stmtCache - stmtCacheSize = maxStmtCacheSize - } - } - - al.Name = aliasName - al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: stmtCache, - stmtDecoratorsLimit: stmtCacheSize, - } - - detectTZ(al) - - kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { - SetMaxIdleConns(al.Name, value.(int)) - }).IfContains(MaxOpenConnsKey, func(value interface{}) { - SetMaxOpenConns(al.Name, value.(int)) - }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { - SetConnMaxLifetime(al.Name, value.(time.Duration)) - }) - o := new(orm) o.alias = al From 15f04b8da467a5dc4015e58ce569eea725cd892d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=B1=AA=E8=B4=B5?= Date: Wed, 29 Jul 2020 21:57:16 +0800 Subject: [PATCH 048/207] add env BEEGO_CONFIG_PATH --- config.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/config.go b/config.go index b6c9a99c..92aa3bbd 100644 --- a/config.go +++ b/config.go @@ -150,6 +150,9 @@ func init() { filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf" } appConfigPath = filepath.Join(WorkPath, "conf", filename) + if configPath := os.Getenv("BEEGO_CONFIG_PATH"); configPath != "" { + appConfigPath = configPath + } if !utils.FileExists(appConfigPath) { appConfigPath = filepath.Join(AppPath, "conf", filename) if !utils.FileExists(appConfigPath) { From aa06a104932a82fb6ce630303f78d738c74ffba4 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 29 Jul 2020 21:56:19 +0800 Subject: [PATCH 049/207] uing pkg module --- grace/grace.go | 9 +++++ grace/server.go | 6 ++++ pkg/admin.go | 8 ++--- pkg/admin_test.go | 2 +- pkg/app.go | 9 ++--- pkg/cache/memcache/memcache.go | 3 +- pkg/cache/memcache/memcache_test.go | 6 ++-- pkg/cache/redis/redis.go | 6 ++-- pkg/cache/redis/redis_test.go | 7 ++-- pkg/cache/ssdb/ssdb.go | 2 +- pkg/cache/ssdb/ssdb_test.go | 2 +- pkg/config.go | 10 +++--- pkg/config/env/env.go | 2 +- pkg/config/{ => ini}/ini.go | 14 ++++---- pkg/config/{ => ini}/ini_test.go | 8 +++-- pkg/config/{ => json}/json.go | 14 ++++---- pkg/config/{ => json}/json_test.go | 8 +++-- pkg/config/xml/xml.go | 2 +- pkg/config/xml/xml_test.go | 2 +- pkg/config/yaml/yaml.go | 2 +- pkg/config/yaml/yaml_test.go | 2 +- pkg/config_test.go | 8 ++--- pkg/context/context.go | 2 +- pkg/context/input.go | 2 +- pkg/context/param/conv.go | 4 +-- pkg/controller.go | 6 ++-- pkg/controller_test.go | 2 +- pkg/doc.go | 2 +- pkg/error.go | 4 +-- pkg/filter.go | 2 +- pkg/filter_test.go | 2 +- pkg/hooks.go | 6 ++-- pkg/log.go | 34 +++++++++---------- pkg/logs/alils/alils.go | 2 +- pkg/logs/es/es.go | 2 +- pkg/metric/prometheus.go | 4 +-- pkg/metric/prometheus_test.go | 2 +- pkg/migration/ddl.go | 2 +- pkg/migration/migration.go | 4 +-- pkg/namespace.go | 28 +++++++-------- pkg/namespace_test.go | 2 +- pkg/orm/orm.go | 2 +- pkg/parser.go | 10 +++--- pkg/plugins/apiauth/apiauth.go | 4 +-- pkg/plugins/auth/basic.go | 4 +-- pkg/plugins/authz/authz.go | 4 +-- pkg/plugins/authz/authz_test.go | 6 ++-- pkg/plugins/cors/cors.go | 4 +-- pkg/plugins/cors/cors_test.go | 4 +-- pkg/policy.go | 2 +- pkg/router.go | 10 +++--- pkg/router_test.go | 4 +-- pkg/session/couchbase/sess_couchbase.go | 2 +- pkg/session/ledis/ledis_session.go | 2 +- pkg/session/memcache/sess_memcache.go | 2 +- pkg/session/mysql/sess_mysql.go | 2 +- pkg/session/postgres/sess_postgresql.go | 2 +- pkg/session/redis/sess_redis.go | 2 +- pkg/session/redis_cluster/redis_cluster.go | 2 +- .../redis_sentinel/sess_redis_sentinel.go | 2 +- .../sess_redis_sentinel_test.go | 2 +- pkg/session/sess_utils.go | 2 +- pkg/session/ssdb/sess_ssdb.go | 2 +- pkg/staticfile.go | 4 +-- pkg/template.go | 4 +-- pkg/template_test.go | 2 +- pkg/testing/assertions.go | 15 -------- pkg/testing/client.go | 4 +-- pkg/tree.go | 4 +-- pkg/tree_test.go | 2 +- pkg/utils/captcha/captcha.go | 10 +++--- pkg/utils/captcha/image_test.go | 2 +- pkg/utils/pagination/controller.go | 2 +- pkg/utils/pagination/doc.go | 2 +- pkg/validation/validators.go | 2 +- 75 files changed, 192 insertions(+), 181 deletions(-) rename pkg/config/{ => ini}/ini.go (97%) rename pkg/config/{ => ini}/ini_test.go (95%) rename pkg/config/{ => json}/json.go (95%) rename pkg/config/{ => json}/json_test.go (97%) delete mode 100644 pkg/testing/assertions.go diff --git a/grace/grace.go b/grace/grace.go index fb0cb7bb..39d067fd 100644 --- a/grace/grace.go +++ b/grace/grace.go @@ -54,16 +54,22 @@ import ( const ( // PreSignal is the position to add filter before signal + // Deprecated: using pkg/grace, we will delete this in v2.1.0 PreSignal = iota // PostSignal is the position to add filter after signal + // Deprecated: using pkg/grace, we will delete this in v2.1.0 PostSignal // StateInit represent the application inited + // Deprecated: using pkg/grace, we will delete this in v2.1.0 StateInit // StateRunning represent the application is running + // Deprecated: using pkg/grace, we will delete this in v2.1.0 StateRunning // StateShuttingDown represent the application is shutting down + // Deprecated: using pkg/grace, we will delete this in v2.1.0 StateShuttingDown // StateTerminate represent the application is killed + // Deprecated: using pkg/grace, we will delete this in v2.1.0 StateTerminate ) @@ -106,6 +112,7 @@ func init() { } // NewServer returns a new graceServer. +// Deprecated: using pkg/grace, we will delete this in v2.1.0 func NewServer(addr string, handler http.Handler) (srv *Server) { regLock.Lock() defer regLock.Unlock() @@ -154,12 +161,14 @@ func NewServer(addr string, handler http.Handler) (srv *Server) { } // ListenAndServe refer http.ListenAndServe +// Deprecated: using pkg/grace, we will delete this in v2.1.0 func ListenAndServe(addr string, handler http.Handler) error { server := NewServer(addr, handler) return server.ListenAndServe() } // ListenAndServeTLS refer http.ListenAndServeTLS +// Deprecated: using pkg/grace, we will delete this in v2.1.0 func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { server := NewServer(addr, handler) return server.ListenAndServeTLS(certFile, keyFile) diff --git a/grace/server.go b/grace/server.go index 008a6171..cd659f82 100644 --- a/grace/server.go +++ b/grace/server.go @@ -18,6 +18,7 @@ import ( ) // Server embedded http.Server +// Deprecated: using pkg/grace, we will delete this in v2.1.0 type Server struct { *http.Server ln net.Listener @@ -32,6 +33,7 @@ type Server struct { // Serve accepts incoming connections on the Listener l, // creating a new service goroutine for each. // The service goroutines read requests and then call srv.Handler to reply to them. +// Deprecated: using pkg/grace, we will delete this in v2.1.0 func (srv *Server) Serve() (err error) { srv.state = StateRunning defer func() { srv.state = StateTerminate }() @@ -55,6 +57,7 @@ func (srv *Server) Serve() (err error) { // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve // to handle requests on incoming connections. If srv.Addr is blank, ":http" is // used. +// Deprecated: using pkg/grace, we will delete this in v2.1.0 func (srv *Server) ListenAndServe() (err error) { addr := srv.Addr if addr == "" { @@ -94,6 +97,7 @@ func (srv *Server) ListenAndServe() (err error) { // CA's certificate. // // If srv.Addr is blank, ":https" is used. +// Deprecated: using pkg/grace, we will delete this in v2.1.0 func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { addr := srv.Addr if addr == "" { @@ -140,6 +144,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { // ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls // Serve to handle requests on incoming mutual TLS connections. +// Deprecated: using pkg/grace, we will delete this in v2.1.0 func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { addr := srv.Addr if addr == "" { @@ -340,6 +345,7 @@ func (srv *Server) fork() (err error) { } // RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. +// Deprecated: using pkg/grace, we will delete this in v2.1.0 func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) { if ppFlag != PreSignal && ppFlag != PostSignal { err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal") diff --git a/pkg/admin.go b/pkg/admin.go index db52647e..4d8b256f 100644 --- a/pkg/admin.go +++ b/pkg/admin.go @@ -27,10 +27,10 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/astaxie/beego/grace" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/toolbox" - "github.com/astaxie/beego/utils" + "github.com/astaxie/beego/pkg/grace" + "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/toolbox" + "github.com/astaxie/beego/pkg/utils" ) // BeeAdminApp is the default adminApp used by admin module. diff --git a/pkg/admin_test.go b/pkg/admin_test.go index 3f3612e4..e7eae771 100644 --- a/pkg/admin_test.go +++ b/pkg/admin_test.go @@ -10,7 +10,7 @@ import ( "strings" "testing" - "github.com/astaxie/beego/toolbox" + "github.com/astaxie/beego/pkg/toolbox" ) type SampleDatabaseCheck struct { diff --git a/pkg/app.go b/pkg/app.go index f3fe6f7b..eb672b1f 100644 --- a/pkg/app.go +++ b/pkg/app.go @@ -27,10 +27,11 @@ import ( "strings" "time" - "github.com/astaxie/beego/grace" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" "golang.org/x/crypto/acme/autocert" + + "github.com/astaxie/beego/pkg/grace" + "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/utils" ) var ( @@ -348,7 +349,7 @@ func findAndRemoveSingleTree(entryPointTree *Tree) { // func (b *BankAccount)Mapping(){ // b.Mapping("ShowAccount" , b.ShowAccount) // b.Mapping("ModifyAccount", b.ModifyAccount) -//} +// } // // //@router /account/:id [get] // func (b *BankAccount) ShowAccount(){ diff --git a/pkg/cache/memcache/memcache.go b/pkg/cache/memcache/memcache.go index 19116bfa..b08596eb 100644 --- a/pkg/cache/memcache/memcache.go +++ b/pkg/cache/memcache/memcache.go @@ -35,8 +35,9 @@ import ( "strings" "time" - "github.com/astaxie/beego/cache" "github.com/bradfitz/gomemcache/memcache" + + "github.com/astaxie/beego/pkg/cache" ) // Cache Memcache adapter. diff --git a/pkg/cache/memcache/memcache_test.go b/pkg/cache/memcache/memcache_test.go index d9129b69..b7dad8fc 100644 --- a/pkg/cache/memcache/memcache_test.go +++ b/pkg/cache/memcache/memcache_test.go @@ -21,7 +21,7 @@ import ( "testing" "time" - "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/pkg/cache" ) func TestMemcacheCache(t *testing.T) { @@ -70,7 +70,7 @@ func TestMemcacheCache(t *testing.T) { t.Error("delete err") } - //test string + // test string if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { t.Error("set Error", err) } @@ -82,7 +82,7 @@ func TestMemcacheCache(t *testing.T) { t.Error("get err") } - //test GetMulti + // test GetMulti if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { t.Error("set Error", err) } diff --git a/pkg/cache/redis/redis.go b/pkg/cache/redis/redis.go index d8737b3c..a5fec591 100644 --- a/pkg/cache/redis/redis.go +++ b/pkg/cache/redis/redis.go @@ -34,12 +34,12 @@ import ( "errors" "fmt" "strconv" + "strings" "time" "github.com/gomodule/redigo/redis" - "github.com/astaxie/beego/cache" - "strings" + "github.com/astaxie/beego/pkg/cache" ) var ( @@ -56,7 +56,7 @@ type Cache struct { password string maxIdle int - //the timeout to a value less than the redis server's timeout. + // the timeout to a value less than the redis server's timeout. timeout time.Duration } diff --git a/pkg/cache/redis/redis_test.go b/pkg/cache/redis/redis_test.go index 60a19180..8de331ab 100644 --- a/pkg/cache/redis/redis_test.go +++ b/pkg/cache/redis/redis_test.go @@ -19,9 +19,10 @@ import ( "testing" "time" - "github.com/astaxie/beego/cache" "github.com/gomodule/redigo/redis" "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/cache" ) func TestRedisCache(t *testing.T) { @@ -70,7 +71,7 @@ func TestRedisCache(t *testing.T) { t.Error("delete err") } - //test string + // test string if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { t.Error("set Error", err) } @@ -82,7 +83,7 @@ func TestRedisCache(t *testing.T) { t.Error("get err") } - //test GetMulti + // test GetMulti if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { t.Error("set Error", err) } diff --git a/pkg/cache/ssdb/ssdb.go b/pkg/cache/ssdb/ssdb.go index fa2ce04b..62a63c60 100644 --- a/pkg/cache/ssdb/ssdb.go +++ b/pkg/cache/ssdb/ssdb.go @@ -9,7 +9,7 @@ import ( "github.com/ssdb/gossdb/ssdb" - "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/pkg/cache" ) // Cache SSDB adapter diff --git a/pkg/cache/ssdb/ssdb_test.go b/pkg/cache/ssdb/ssdb_test.go index dd474960..7390ea94 100644 --- a/pkg/cache/ssdb/ssdb_test.go +++ b/pkg/cache/ssdb/ssdb_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/pkg/cache" ) func TestSsdbcacheCache(t *testing.T) { diff --git a/pkg/config.go b/pkg/config.go index b6c9a99c..2a5dec25 100644 --- a/pkg/config.go +++ b/pkg/config.go @@ -22,11 +22,11 @@ import ( "runtime" "strings" - "github.com/astaxie/beego/config" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/session" - "github.com/astaxie/beego/utils" + "github.com/astaxie/beego/pkg/config" + "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/utils" ) // Config is the main struct for BConfig diff --git a/pkg/config/env/env.go b/pkg/config/env/env.go index 34f094fe..7c729780 100644 --- a/pkg/config/env/env.go +++ b/pkg/config/env/env.go @@ -21,7 +21,7 @@ import ( "os" "strings" - "github.com/astaxie/beego/utils" + "github.com/astaxie/beego/pkg/utils" ) var env *utils.BeeMap diff --git a/pkg/config/ini.go b/pkg/config/ini/ini.go similarity index 97% rename from pkg/config/ini.go rename to pkg/config/ini/ini.go index 002e5e05..a3c6462d 100644 --- a/pkg/config/ini.go +++ b/pkg/config/ini/ini.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package config +package ini import ( "bufio" @@ -26,6 +26,8 @@ import ( "strconv" "strings" "sync" + + "github.com/astaxie/beego/pkg/config" ) var ( @@ -45,7 +47,7 @@ type IniConfig struct { } // Parse creates a new Config and parses the file configuration from the named file. -func (ini *IniConfig) Parse(name string) (Configer, error) { +func (ini *IniConfig) Parse(name string) (config.Configer, error) { return ini.parseFile(name) } @@ -195,7 +197,7 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e val = bytes.Trim(val, `"`) } - cfg.data[section][key] = ExpandValueEnv(string(val)) + cfg.data[section][key] = config.ExpandValueEnv(string(val)) if comment.Len() > 0 { cfg.keyComment[section+"."+key] = comment.String() comment.Reset() @@ -208,7 +210,7 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e // ParseData parse ini the data // When include other.conf,other.conf is either absolute directory // or under beego in default temporary directory(/tmp/beego[-username]). -func (ini *IniConfig) ParseData(data []byte) (Configer, error) { +func (ini *IniConfig) ParseData(data []byte) (config.Configer, error) { dir := "beego" currentUser, err := user.Current() if err == nil { @@ -233,7 +235,7 @@ type IniConfigContainer struct { // Bool returns the boolean value for a given key. func (c *IniConfigContainer) Bool(key string) (bool, error) { - return ParseBool(c.getdata(key)) + return config.ParseBool(c.getdata(key)) } // DefaultBool returns the boolean value for a given key. @@ -500,5 +502,5 @@ func (c *IniConfigContainer) getdata(key string) string { } func init() { - Register("ini", &IniConfig{}) + config.Register("ini", &IniConfig{}) } diff --git a/pkg/config/ini_test.go b/pkg/config/ini/ini_test.go similarity index 95% rename from pkg/config/ini_test.go rename to pkg/config/ini/ini_test.go index ffcdb294..70f1091d 100644 --- a/pkg/config/ini_test.go +++ b/pkg/config/ini/ini_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package config +package ini import ( "fmt" @@ -20,6 +20,8 @@ import ( "os" "strings" "testing" + + "github.com/astaxie/beego/pkg/config" ) func TestIni(t *testing.T) { @@ -92,7 +94,7 @@ password = ${GOPATH} } f.Close() defer os.Remove("testini.conf") - iniconf, err := NewConfig("ini", "testini.conf") + iniconf, err := config.NewConfig("ini", "testini.conf") if err != nil { t.Fatal(err) } @@ -165,7 +167,7 @@ httpport=8080 name=mysql ` ) - cfg, err := NewConfigData("ini", []byte(inicontext)) + cfg, err := config.NewConfigData("ini", []byte(inicontext)) if err != nil { t.Fatal(err) } diff --git a/pkg/config/json.go b/pkg/config/json/json.go similarity index 95% rename from pkg/config/json.go rename to pkg/config/json/json.go index c4ef25cd..49bd38ff 100644 --- a/pkg/config/json.go +++ b/pkg/config/json/json.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package config +package json import ( "encoding/json" @@ -23,6 +23,8 @@ import ( "strconv" "strings" "sync" + + "github.com/astaxie/beego/pkg/config" ) // JSONConfig is a json config parser and implements Config interface. @@ -30,7 +32,7 @@ type JSONConfig struct { } // Parse returns a ConfigContainer with parsed json config map. -func (js *JSONConfig) Parse(filename string) (Configer, error) { +func (js *JSONConfig) Parse(filename string) (config.Configer, error) { file, err := os.Open(filename) if err != nil { return nil, err @@ -45,7 +47,7 @@ func (js *JSONConfig) Parse(filename string) (Configer, error) { } // ParseData returns a ConfigContainer with json string -func (js *JSONConfig) ParseData(data []byte) (Configer, error) { +func (js *JSONConfig) ParseData(data []byte) (config.Configer, error) { x := &JSONConfigContainer{ data: make(map[string]interface{}), } @@ -59,7 +61,7 @@ func (js *JSONConfig) ParseData(data []byte) (Configer, error) { x.data["rootArray"] = wrappingArray } - x.data = ExpandValueEnvForMap(x.data) + x.data = config.ExpandValueEnvForMap(x.data) return x, nil } @@ -75,7 +77,7 @@ type JSONConfigContainer struct { func (c *JSONConfigContainer) Bool(key string) (bool, error) { val := c.getData(key) if val != nil { - return ParseBool(val) + return config.ParseBool(val) } return false, fmt.Errorf("not exist key: %q", key) } @@ -265,5 +267,5 @@ func (c *JSONConfigContainer) getData(key string) interface{} { } func init() { - Register("json", &JSONConfig{}) + config.Register("json", &JSONConfig{}) } diff --git a/pkg/config/json_test.go b/pkg/config/json/json_test.go similarity index 97% rename from pkg/config/json_test.go rename to pkg/config/json/json_test.go index 16f42409..da87986f 100644 --- a/pkg/config/json_test.go +++ b/pkg/config/json/json_test.go @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package config +package json import ( "fmt" "os" "testing" + + "github.com/astaxie/beego/pkg/config" ) func TestJsonStartsWithArray(t *testing.T) { @@ -43,7 +45,7 @@ func TestJsonStartsWithArray(t *testing.T) { } f.Close() defer os.Remove("testjsonWithArray.conf") - jsonconf, err := NewConfig("json", "testjsonWithArray.conf") + jsonconf, err := config.NewConfig("json", "testjsonWithArray.conf") if err != nil { t.Fatal(err) } @@ -143,7 +145,7 @@ func TestJson(t *testing.T) { } f.Close() defer os.Remove("testjson.conf") - jsonconf, err := NewConfig("json", "testjson.conf") + jsonconf, err := config.NewConfig("json", "testjson.conf") if err != nil { t.Fatal(err) } diff --git a/pkg/config/xml/xml.go b/pkg/config/xml/xml.go index 494242d3..b1cce5c8 100644 --- a/pkg/config/xml/xml.go +++ b/pkg/config/xml/xml.go @@ -39,7 +39,7 @@ import ( "strings" "sync" - "github.com/astaxie/beego/config" + "github.com/astaxie/beego/pkg/config" "github.com/beego/x2j" ) diff --git a/pkg/config/xml/xml_test.go b/pkg/config/xml/xml_test.go index 346c866e..b7828933 100644 --- a/pkg/config/xml/xml_test.go +++ b/pkg/config/xml/xml_test.go @@ -19,7 +19,7 @@ import ( "os" "testing" - "github.com/astaxie/beego/config" + "github.com/astaxie/beego/pkg/config" ) func TestXML(t *testing.T) { diff --git a/pkg/config/yaml/yaml.go b/pkg/config/yaml/yaml.go index a5644c7b..3dcb45fd 100644 --- a/pkg/config/yaml/yaml.go +++ b/pkg/config/yaml/yaml.go @@ -40,7 +40,7 @@ import ( "strings" "sync" - "github.com/astaxie/beego/config" + "github.com/astaxie/beego/pkg/config" "github.com/beego/goyaml2" ) diff --git a/pkg/config/yaml/yaml_test.go b/pkg/config/yaml/yaml_test.go index 49cc1d1e..0e76457f 100644 --- a/pkg/config/yaml/yaml_test.go +++ b/pkg/config/yaml/yaml_test.go @@ -19,7 +19,7 @@ import ( "os" "testing" - "github.com/astaxie/beego/config" + "github.com/astaxie/beego/pkg/config" ) func TestYaml(t *testing.T) { diff --git a/pkg/config_test.go b/pkg/config_test.go index 5f71f1c3..c810a9e3 100644 --- a/pkg/config_test.go +++ b/pkg/config_test.go @@ -19,7 +19,7 @@ import ( "reflect" "testing" - "github.com/astaxie/beego/config" + beeJson "github.com/astaxie/beego/pkg/config/json" ) func TestDefaults(t *testing.T) { @@ -35,7 +35,7 @@ func TestDefaults(t *testing.T) { func TestAssignConfig_01(t *testing.T) { _BConfig := &Config{} _BConfig.AppName = "beego_test" - jcf := &config.JSONConfig{} + jcf := &beeJson.JSONConfig{} ac, _ := jcf.ParseData([]byte(`{"AppName":"beego_json"}`)) assignSingleConfig(_BConfig, ac) if _BConfig.AppName != "beego_json" { @@ -73,7 +73,7 @@ func TestAssignConfig_02(t *testing.T) { configMap["SessionProviderConfig"] = "file" configMap["FileLineNum"] = true - jcf := &config.JSONConfig{} + jcf := &beeJson.JSONConfig{} bs, _ = json.Marshal(configMap) ac, _ := jcf.ParseData(bs) @@ -109,7 +109,7 @@ func TestAssignConfig_02(t *testing.T) { } func TestAssignConfig_03(t *testing.T) { - jcf := &config.JSONConfig{} + jcf := &beeJson.JSONConfig{} ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`)) ac.Set("AppName", "test_app") ac.Set("RunMode", "online") diff --git a/pkg/context/context.go b/pkg/context/context.go index de248ed2..9326fa28 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -35,7 +35,7 @@ import ( "strings" "time" - "github.com/astaxie/beego/utils" + "github.com/astaxie/beego/pkg/utils" ) //commonly used mime-types diff --git a/pkg/context/input.go b/pkg/context/input.go index 385549c1..04347e04 100644 --- a/pkg/context/input.go +++ b/pkg/context/input.go @@ -29,7 +29,7 @@ import ( "strings" "sync" - "github.com/astaxie/beego/session" + "github.com/astaxie/beego/pkg/session" ) // Regexes for checking the accept headers diff --git a/pkg/context/param/conv.go b/pkg/context/param/conv.go index c200e008..d96f964c 100644 --- a/pkg/context/param/conv.go +++ b/pkg/context/param/conv.go @@ -4,8 +4,8 @@ import ( "fmt" "reflect" - beecontext "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" + beecontext "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/logs" ) // ConvertParams converts http method params to values that will be passed to the method controller as arguments diff --git a/pkg/controller.go b/pkg/controller.go index 0e8853b3..f3989a76 100644 --- a/pkg/controller.go +++ b/pkg/controller.go @@ -28,9 +28,9 @@ import ( "strconv" "strings" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/context/param" - "github.com/astaxie/beego/session" + "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/context/param" + "github.com/astaxie/beego/pkg/session" ) var ( diff --git a/pkg/controller_test.go b/pkg/controller_test.go index 1e53416d..f51cc109 100644 --- a/pkg/controller_test.go +++ b/pkg/controller_test.go @@ -19,7 +19,7 @@ import ( "strconv" "testing" - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg/context" "os" "path/filepath" ) diff --git a/pkg/doc.go b/pkg/doc.go index 8825bd29..a1cdbbb0 100644 --- a/pkg/doc.go +++ b/pkg/doc.go @@ -6,7 +6,7 @@ It is used for rapid development of RESTful APIs, web apps and backend services beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding. package main - import "github.com/astaxie/beego" + import "github.com/astaxie/beego/pkg" func main() { beego.Run() diff --git a/pkg/error.go b/pkg/error.go index f268f723..aff984c0 100644 --- a/pkg/error.go +++ b/pkg/error.go @@ -23,8 +23,8 @@ import ( "strconv" "strings" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/utils" + "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/utils" ) const ( diff --git a/pkg/filter.go b/pkg/filter.go index 9cc6e913..4e212e06 100644 --- a/pkg/filter.go +++ b/pkg/filter.go @@ -14,7 +14,7 @@ package beego -import "github.com/astaxie/beego/context" +import "github.com/astaxie/beego/pkg/context" // FilterFunc defines a filter function which is invoked before the controller handler is executed. type FilterFunc func(*context.Context) diff --git a/pkg/filter_test.go b/pkg/filter_test.go index 4ca4d2b8..3a1bcb07 100644 --- a/pkg/filter_test.go +++ b/pkg/filter_test.go @@ -19,7 +19,7 @@ import ( "net/http/httptest" "testing" - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg/context" ) var FilterUser = func(ctx *context.Context) { diff --git a/pkg/hooks.go b/pkg/hooks.go index 49c42d5a..8c782383 100644 --- a/pkg/hooks.go +++ b/pkg/hooks.go @@ -6,9 +6,9 @@ import ( "net/http" "path/filepath" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/session" + "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/session" ) // register MIME type with content type diff --git a/pkg/log.go b/pkg/log.go index cc4c0f81..785f96d8 100644 --- a/pkg/log.go +++ b/pkg/log.go @@ -17,11 +17,11 @@ package beego import ( "strings" - "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/pkg/logs" ) // Log levels to control the logging output. -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. const ( LevelEmergency = iota LevelAlert @@ -34,90 +34,90 @@ const ( ) // BeeLogger references the used application logger. -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. var BeeLogger = logs.GetBeeLogger() // SetLevel sets the global log level used by the simple logger. -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. func SetLevel(l int) { logs.SetLevel(l) } // SetLogFuncCall set the CallDepth, default is 3 -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. func SetLogFuncCall(b bool) { logs.SetLogFuncCall(b) } // SetLogger sets a new logger. -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. func SetLogger(adaptername string, config string) error { return logs.SetLogger(adaptername, config) } // Emergency logs a message at emergency level. -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. func Emergency(v ...interface{}) { logs.Emergency(generateFmtStr(len(v)), v...) } // Alert logs a message at alert level. -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. func Alert(v ...interface{}) { logs.Alert(generateFmtStr(len(v)), v...) } // Critical logs a message at critical level. -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. func Critical(v ...interface{}) { logs.Critical(generateFmtStr(len(v)), v...) } // Error logs a message at error level. -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. func Error(v ...interface{}) { logs.Error(generateFmtStr(len(v)), v...) } // Warning logs a message at warning level. -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. func Warning(v ...interface{}) { logs.Warning(generateFmtStr(len(v)), v...) } // Warn compatibility alias for Warning() -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. func Warn(v ...interface{}) { logs.Warn(generateFmtStr(len(v)), v...) } // Notice logs a message at notice level. -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. func Notice(v ...interface{}) { logs.Notice(generateFmtStr(len(v)), v...) } // Informational logs a message at info level. -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. func Informational(v ...interface{}) { logs.Informational(generateFmtStr(len(v)), v...) } // Info compatibility alias for Warning() -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. func Info(v ...interface{}) { logs.Info(generateFmtStr(len(v)), v...) } // Debug logs a message at debug level. -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. func Debug(v ...interface{}) { logs.Debug(generateFmtStr(len(v)), v...) } // Trace logs a message at trace level. // compatibility alias for Warning() -// Deprecated: use github.com/astaxie/beego/logs instead. +// Deprecated: use github.com/astaxie/beego/pkg/logs instead. func Trace(v ...interface{}) { logs.Trace(generateFmtStr(len(v)), v...) } diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go index 867ff4cb..8397b3da 100644 --- a/pkg/logs/alils/alils.go +++ b/pkg/logs/alils/alils.go @@ -6,7 +6,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/pkg/logs" "github.com/gogo/protobuf/proto" ) diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go index 2b7b1710..af6a7892 100644 --- a/pkg/logs/es/es.go +++ b/pkg/logs/es/es.go @@ -12,7 +12,7 @@ import ( "github.com/elastic/go-elasticsearch/v6" "github.com/elastic/go-elasticsearch/v6/esapi" - "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/pkg/logs" ) // NewES return a LoggerInterface diff --git a/pkg/metric/prometheus.go b/pkg/metric/prometheus.go index 86e2c1b1..6a97aec5 100644 --- a/pkg/metric/prometheus.go +++ b/pkg/metric/prometheus.go @@ -23,8 +23,8 @@ import ( "github.com/prometheus/client_golang/prometheus" - "github.com/astaxie/beego" - "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/pkg" + "github.com/astaxie/beego/pkg/logs" ) func PrometheusMiddleWare(next http.Handler) http.Handler { diff --git a/pkg/metric/prometheus_test.go b/pkg/metric/prometheus_test.go index d82a6dec..e04c3285 100644 --- a/pkg/metric/prometheus_test.go +++ b/pkg/metric/prometheus_test.go @@ -22,7 +22,7 @@ import ( "github.com/prometheus/client_golang/prometheus" - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg/context" ) func TestPrometheusMiddleWare(t *testing.T) { diff --git a/pkg/migration/ddl.go b/pkg/migration/ddl.go index cd2c1c49..b6d3a2d9 100644 --- a/pkg/migration/ddl.go +++ b/pkg/migration/ddl.go @@ -17,7 +17,7 @@ package migration import ( "fmt" - "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/pkg/logs" ) // Index struct defines the structure of Index Columns diff --git a/pkg/migration/migration.go b/pkg/migration/migration.go index 5ddfd972..c62fd901 100644 --- a/pkg/migration/migration.go +++ b/pkg/migration/migration.go @@ -33,8 +33,8 @@ import ( "strings" "time" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/orm" + "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/orm" ) // const the data format for the bee generate migration datatype diff --git a/pkg/namespace.go b/pkg/namespace.go index 4952c9d5..bda18f4b 100644 --- a/pkg/namespace.go +++ b/pkg/namespace.go @@ -18,7 +18,7 @@ import ( "net/http" "strings" - beecontext "github.com/astaxie/beego/context" + beecontext "github.com/astaxie/beego/pkg/context" ) type namespaceCond func(*beecontext.Context) bool @@ -97,91 +97,91 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { } // Router same as beego.Rourer -// refer: https://godoc.org/github.com/astaxie/beego#Router +// refer: https://godoc.org/github.com/astaxie/beego/pkg#Router func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { n.handlers.Add(rootpath, c, mappingMethods...) return n } // AutoRouter same as beego.AutoRouter -// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter +// refer: https://godoc.org/github.com/astaxie/beego/pkg#AutoRouter func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { n.handlers.AddAuto(c) return n } // AutoPrefix same as beego.AutoPrefix -// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix +// refer: https://godoc.org/github.com/astaxie/beego/pkg#AutoPrefix func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { n.handlers.AddAutoPrefix(prefix, c) return n } // Get same as beego.Get -// refer: https://godoc.org/github.com/astaxie/beego#Get +// refer: https://godoc.org/github.com/astaxie/beego/pkg#Get func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { n.handlers.Get(rootpath, f) return n } // Post same as beego.Post -// refer: https://godoc.org/github.com/astaxie/beego#Post +// refer: https://godoc.org/github.com/astaxie/beego/pkg#Post func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { n.handlers.Post(rootpath, f) return n } // Delete same as beego.Delete -// refer: https://godoc.org/github.com/astaxie/beego#Delete +// refer: https://godoc.org/github.com/astaxie/beego/pkg#Delete func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { n.handlers.Delete(rootpath, f) return n } // Put same as beego.Put -// refer: https://godoc.org/github.com/astaxie/beego#Put +// refer: https://godoc.org/github.com/astaxie/beego/pkg#Put func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { n.handlers.Put(rootpath, f) return n } // Head same as beego.Head -// refer: https://godoc.org/github.com/astaxie/beego#Head +// refer: https://godoc.org/github.com/astaxie/beego/pkg#Head func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { n.handlers.Head(rootpath, f) return n } // Options same as beego.Options -// refer: https://godoc.org/github.com/astaxie/beego#Options +// refer: https://godoc.org/github.com/astaxie/beego/pkg#Options func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { n.handlers.Options(rootpath, f) return n } // Patch same as beego.Patch -// refer: https://godoc.org/github.com/astaxie/beego#Patch +// refer: https://godoc.org/github.com/astaxie/beego/pkg#Patch func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { n.handlers.Patch(rootpath, f) return n } // Any same as beego.Any -// refer: https://godoc.org/github.com/astaxie/beego#Any +// refer: https://godoc.org/github.com/astaxie/beego/pkg#Any func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { n.handlers.Any(rootpath, f) return n } // Handler same as beego.Handler -// refer: https://godoc.org/github.com/astaxie/beego#Handler +// refer: https://godoc.org/github.com/astaxie/beego/pkg#Handler func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { n.handlers.Handler(rootpath, h) return n } // Include add include class -// refer: https://godoc.org/github.com/astaxie/beego#Include +// refer: https://godoc.org/github.com/astaxie/beego/pkg#Include func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { n.handlers.Include(cList...) return n diff --git a/pkg/namespace_test.go b/pkg/namespace_test.go index b3f20dff..bdf33b4f 100644 --- a/pkg/namespace_test.go +++ b/pkg/namespace_test.go @@ -20,7 +20,7 @@ import ( "strconv" "testing" - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg/context" ) func TestNamespaceGet(t *testing.T) { diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index b2f1e693..9a94fb11 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -63,7 +63,7 @@ import ( "reflect" "time" - "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/pkg/logs" ) // DebugQueries define the debug diff --git a/pkg/parser.go b/pkg/parser.go index 3a311894..606be190 100644 --- a/pkg/parser.go +++ b/pkg/parser.go @@ -30,16 +30,16 @@ import ( "strings" "unicode" - "github.com/astaxie/beego/context/param" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" + "github.com/astaxie/beego/pkg/context/param" + "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/utils" ) var globalRouterTemplate = `package {{.routersDir}} import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/context/param"{{.globalimport}} + "github.com/astaxie/beego/pkg" + "github.com/astaxie/beego/pkg/context/param"{{.globalimport}} ) func init() { diff --git a/pkg/plugins/apiauth/apiauth.go b/pkg/plugins/apiauth/apiauth.go index 10e25f3f..90360aba 100644 --- a/pkg/plugins/apiauth/apiauth.go +++ b/pkg/plugins/apiauth/apiauth.go @@ -65,8 +65,8 @@ import ( "sort" "time" - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg" + "github.com/astaxie/beego/pkg/context" ) // AppIDToAppSecret is used to get appsecret throw appid diff --git a/pkg/plugins/auth/basic.go b/pkg/plugins/auth/basic.go index c478044a..aa548f1a 100644 --- a/pkg/plugins/auth/basic.go +++ b/pkg/plugins/auth/basic.go @@ -40,8 +40,8 @@ import ( "net/http" "strings" - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg" + "github.com/astaxie/beego/pkg/context" ) var defaultRealm = "Authorization Required" diff --git a/pkg/plugins/authz/authz.go b/pkg/plugins/authz/authz.go index 9dc0db76..a375c593 100644 --- a/pkg/plugins/authz/authz.go +++ b/pkg/plugins/authz/authz.go @@ -40,8 +40,8 @@ package authz import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg" + "github.com/astaxie/beego/pkg/context" "github.com/casbin/casbin" "net/http" ) diff --git a/pkg/plugins/authz/authz_test.go b/pkg/plugins/authz/authz_test.go index 49aed84c..53e2652a 100644 --- a/pkg/plugins/authz/authz_test.go +++ b/pkg/plugins/authz/authz_test.go @@ -15,9 +15,9 @@ package authz import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/plugins/auth" + "github.com/astaxie/beego/pkg" + "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/plugins/auth" "github.com/casbin/casbin" "net/http" "net/http/httptest" diff --git a/pkg/plugins/cors/cors.go b/pkg/plugins/cors/cors.go index 45c327ab..a4fb3b39 100644 --- a/pkg/plugins/cors/cors.go +++ b/pkg/plugins/cors/cors.go @@ -42,8 +42,8 @@ import ( "strings" "time" - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg" + "github.com/astaxie/beego/pkg/context" ) const ( diff --git a/pkg/plugins/cors/cors_test.go b/pkg/plugins/cors/cors_test.go index 34039143..9757a32b 100644 --- a/pkg/plugins/cors/cors_test.go +++ b/pkg/plugins/cors/cors_test.go @@ -21,8 +21,8 @@ import ( "testing" "time" - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg" + "github.com/astaxie/beego/pkg/context" ) // HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header diff --git a/pkg/policy.go b/pkg/policy.go index ab23f927..4af240f1 100644 --- a/pkg/policy.go +++ b/pkg/policy.go @@ -17,7 +17,7 @@ package beego import ( "strings" - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg/context" ) // PolicyFunc defines a policy function which is invoked before the controller handler is executed. diff --git a/pkg/router.go b/pkg/router.go index 6a8ac6f7..995fb767 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -27,11 +27,11 @@ import ( "sync" "time" - beecontext "github.com/astaxie/beego/context" - "github.com/astaxie/beego/context/param" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/toolbox" - "github.com/astaxie/beego/utils" + beecontext "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/context/param" + "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/toolbox" + "github.com/astaxie/beego/pkg/utils" ) // default filter execution points diff --git a/pkg/router_test.go b/pkg/router_test.go index 8ec7927a..8a7862f6 100644 --- a/pkg/router_test.go +++ b/pkg/router_test.go @@ -21,8 +21,8 @@ import ( "strings" "testing" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/logs" ) type TestController struct { diff --git a/pkg/session/couchbase/sess_couchbase.go b/pkg/session/couchbase/sess_couchbase.go index 707d042c..227c0bc6 100644 --- a/pkg/session/couchbase/sess_couchbase.go +++ b/pkg/session/couchbase/sess_couchbase.go @@ -39,7 +39,7 @@ import ( couchbase "github.com/couchbase/go-couchbase" - "github.com/astaxie/beego/session" + "github.com/astaxie/beego/pkg/session" ) var couchbpder = &Provider{} diff --git a/pkg/session/ledis/ledis_session.go b/pkg/session/ledis/ledis_session.go index ee81df67..a0988327 100644 --- a/pkg/session/ledis/ledis_session.go +++ b/pkg/session/ledis/ledis_session.go @@ -10,7 +10,7 @@ import ( "github.com/ledisdb/ledisdb/config" "github.com/ledisdb/ledisdb/ledis" - "github.com/astaxie/beego/session" + "github.com/astaxie/beego/pkg/session" ) var ( diff --git a/pkg/session/memcache/sess_memcache.go b/pkg/session/memcache/sess_memcache.go index 85a2d815..6cd8acab 100644 --- a/pkg/session/memcache/sess_memcache.go +++ b/pkg/session/memcache/sess_memcache.go @@ -37,7 +37,7 @@ import ( "strings" "sync" - "github.com/astaxie/beego/session" + "github.com/astaxie/beego/pkg/session" "github.com/bradfitz/gomemcache/memcache" ) diff --git a/pkg/session/mysql/sess_mysql.go b/pkg/session/mysql/sess_mysql.go index 301353ab..73738496 100644 --- a/pkg/session/mysql/sess_mysql.go +++ b/pkg/session/mysql/sess_mysql.go @@ -46,7 +46,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/session" + "github.com/astaxie/beego/pkg/session" // import mysql driver _ "github.com/go-sql-driver/mysql" ) diff --git a/pkg/session/postgres/sess_postgresql.go b/pkg/session/postgres/sess_postgresql.go index 0b8b9645..e6c9ed89 100644 --- a/pkg/session/postgres/sess_postgresql.go +++ b/pkg/session/postgres/sess_postgresql.go @@ -56,7 +56,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/session" + "github.com/astaxie/beego/pkg/session" // import postgresql Driver _ "github.com/lib/pq" ) diff --git a/pkg/session/redis/sess_redis.go b/pkg/session/redis/sess_redis.go index 5c382d61..f569f9dd 100644 --- a/pkg/session/redis/sess_redis.go +++ b/pkg/session/redis/sess_redis.go @@ -39,7 +39,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/session" + "github.com/astaxie/beego/pkg/session" "github.com/gomodule/redigo/redis" ) diff --git a/pkg/session/redis_cluster/redis_cluster.go b/pkg/session/redis_cluster/redis_cluster.go index 262fa2e3..f7fc7845 100644 --- a/pkg/session/redis_cluster/redis_cluster.go +++ b/pkg/session/redis_cluster/redis_cluster.go @@ -33,7 +33,7 @@ package redis_cluster import ( - "github.com/astaxie/beego/session" + "github.com/astaxie/beego/pkg/session" rediss "github.com/go-redis/redis" "net/http" "strconv" diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel.go b/pkg/session/redis_sentinel/sess_redis_sentinel.go index 6ecb2977..23bebf2a 100644 --- a/pkg/session/redis_sentinel/sess_redis_sentinel.go +++ b/pkg/session/redis_sentinel/sess_redis_sentinel.go @@ -33,7 +33,7 @@ package redis_sentinel import ( - "github.com/astaxie/beego/session" + "github.com/astaxie/beego/pkg/session" "github.com/go-redis/redis" "net/http" "strconv" diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel_test.go b/pkg/session/redis_sentinel/sess_redis_sentinel_test.go index fd4155c6..bd31741f 100644 --- a/pkg/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/pkg/session/redis_sentinel/sess_redis_sentinel_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/astaxie/beego/session" + "github.com/astaxie/beego/pkg/session" ) func TestRedisSentinel(t *testing.T) { diff --git a/pkg/session/sess_utils.go b/pkg/session/sess_utils.go index 20915bb6..5b70afa8 100644 --- a/pkg/session/sess_utils.go +++ b/pkg/session/sess_utils.go @@ -29,7 +29,7 @@ import ( "strconv" "time" - "github.com/astaxie/beego/utils" + "github.com/astaxie/beego/pkg/utils" ) func init() { diff --git a/pkg/session/ssdb/sess_ssdb.go b/pkg/session/ssdb/sess_ssdb.go index de0c6360..1b382954 100644 --- a/pkg/session/ssdb/sess_ssdb.go +++ b/pkg/session/ssdb/sess_ssdb.go @@ -7,7 +7,7 @@ import ( "strings" "sync" - "github.com/astaxie/beego/session" + "github.com/astaxie/beego/pkg/session" "github.com/ssdb/gossdb/ssdb" ) diff --git a/pkg/staticfile.go b/pkg/staticfile.go index e26776c5..27e83395 100644 --- a/pkg/staticfile.go +++ b/pkg/staticfile.go @@ -26,8 +26,8 @@ import ( "sync" "time" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/logs" "github.com/hashicorp/golang-lru" ) diff --git a/pkg/template.go b/pkg/template.go index 59875be7..8edd9dc1 100644 --- a/pkg/template.go +++ b/pkg/template.go @@ -27,8 +27,8 @@ import ( "strings" "sync" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" + "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/utils" ) var ( diff --git a/pkg/template_test.go b/pkg/template_test.go index 287faadc..590a7bd6 100644 --- a/pkg/template_test.go +++ b/pkg/template_test.go @@ -16,7 +16,7 @@ package beego import ( "bytes" - "github.com/astaxie/beego/testdata" + "github.com/astaxie/beego/pkg/testdata" "github.com/elazarl/go-bindata-assetfs" "net/http" "os" diff --git a/pkg/testing/assertions.go b/pkg/testing/assertions.go deleted file mode 100644 index 96c5d4dd..00000000 --- a/pkg/testing/assertions.go +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testing diff --git a/pkg/testing/client.go b/pkg/testing/client.go index c3737e9c..0062b857 100644 --- a/pkg/testing/client.go +++ b/pkg/testing/client.go @@ -15,8 +15,8 @@ package testing import ( - "github.com/astaxie/beego/config" - "github.com/astaxie/beego/httplib" + "github.com/astaxie/beego/pkg/config" + "github.com/astaxie/beego/pkg/httplib" ) var port = "" diff --git a/pkg/tree.go b/pkg/tree.go index 9e53003b..785ba6a6 100644 --- a/pkg/tree.go +++ b/pkg/tree.go @@ -19,8 +19,8 @@ import ( "regexp" "strings" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/utils" + "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/utils" ) var ( diff --git a/pkg/tree_test.go b/pkg/tree_test.go index d412a348..8758e0c0 100644 --- a/pkg/tree_test.go +++ b/pkg/tree_test.go @@ -18,7 +18,7 @@ import ( "strings" "testing" - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg/context" ) type testinfo struct { diff --git a/pkg/utils/captcha/captcha.go b/pkg/utils/captcha/captcha.go index 42ac70d3..f0c37058 100644 --- a/pkg/utils/captcha/captcha.go +++ b/pkg/utils/captcha/captcha.go @@ -66,11 +66,11 @@ import ( "strings" "time" - "github.com/astaxie/beego" - "github.com/astaxie/beego/cache" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" + "github.com/astaxie/beego/pkg" + "github.com/astaxie/beego/pkg/cache" + "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/utils" ) var ( diff --git a/pkg/utils/captcha/image_test.go b/pkg/utils/captcha/image_test.go index 5e35b7f7..73d3361b 100644 --- a/pkg/utils/captcha/image_test.go +++ b/pkg/utils/captcha/image_test.go @@ -17,7 +17,7 @@ package captcha import ( "testing" - "github.com/astaxie/beego/utils" + "github.com/astaxie/beego/pkg/utils" ) type byteCounter struct { diff --git a/pkg/utils/pagination/controller.go b/pkg/utils/pagination/controller.go index 2f022d0c..b5b09a2f 100644 --- a/pkg/utils/pagination/controller.go +++ b/pkg/utils/pagination/controller.go @@ -15,7 +15,7 @@ package pagination import ( - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg/context" ) // SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). diff --git a/pkg/utils/pagination/doc.go b/pkg/utils/pagination/doc.go index 9abc6d78..718f5e7a 100644 --- a/pkg/utils/pagination/doc.go +++ b/pkg/utils/pagination/doc.go @@ -8,7 +8,7 @@ In your beego.Controller: package controllers - import "github.com/astaxie/beego/utils/pagination" + import "github.com/astaxie/beego/pkg/utils/pagination" type PostsController struct { beego.Controller diff --git a/pkg/validation/validators.go b/pkg/validation/validators.go index 38b6f1aa..87c83ccd 100644 --- a/pkg/validation/validators.go +++ b/pkg/validation/validators.go @@ -16,7 +16,7 @@ package validation import ( "fmt" - "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/pkg/logs" "reflect" "regexp" "strings" From 22b8cae73be75d009151222bc2788139374074d9 Mon Sep 17 00:00:00 2001 From: wangle <285273592@qq.com> Date: Wed, 29 Jul 2020 23:23:02 +0800 Subject: [PATCH 050/207] Add the operator(>,>=,<,<=,=,!=) of orm eg: qs.Filter("counts__>=","20") qs.Filter("counts__!=","20") --- orm/db.go | 6 ++++++ orm/db_mysql.go | 6 ++++++ orm/db_oracle.go | 5 +++++ orm/db_postgres.go | 6 ++++++ orm/db_sqlite.go | 6 ++++++ 5 files changed, 29 insertions(+) diff --git a/orm/db.go b/orm/db.go index 9a1827e8..5d175bf1 100644 --- a/orm/db.go +++ b/orm/db.go @@ -49,6 +49,12 @@ var ( "eq": true, "nq": true, "ne": true, + ">": true, + ">=": true, + "<": true, + "<=": true, + "=": true, + "!=": true, "startswith": true, "endswith": true, "istartswith": true, diff --git a/orm/db_mysql.go b/orm/db_mysql.go index 6e99058e..36f6f566 100644 --- a/orm/db_mysql.go +++ b/orm/db_mysql.go @@ -29,11 +29,17 @@ var mysqlOperators = map[string]string{ // "regex": "REGEXP BINARY ?", // "iregex": "REGEXP ?", "gt": "> ?", + ">": "> ?", "gte": ">= ?", + ">=": ">= ?", "lt": "< ?", + "<": "< ?", "lte": "<= ?", + "<=": "<= ?", "eq": "= ?", + "=": "= ?", "ne": "!= ?", + "!=": "!= ?", "startswith": "LIKE BINARY ?", "endswith": "LIKE BINARY ?", "istartswith": "LIKE ?", diff --git a/orm/db_oracle.go b/orm/db_oracle.go index 5d121f83..ed2ec74c 100644 --- a/orm/db_oracle.go +++ b/orm/db_oracle.go @@ -22,10 +22,15 @@ import ( // oracle operators. var oracleOperators = map[string]string{ "exact": "= ?", + "=": "= ?", "gt": "> ?", + ">": "> ?", "gte": ">= ?", + ">=": ">= ?", "lt": "< ?", + "<": "< ?", "lte": "<= ?", + "<=": "<= ?", "//iendswith": "LIKE ?", } diff --git a/orm/db_postgres.go b/orm/db_postgres.go index c488fb38..7eb88d7a 100644 --- a/orm/db_postgres.go +++ b/orm/db_postgres.go @@ -26,11 +26,17 @@ var postgresOperators = map[string]string{ "contains": "LIKE ?", "icontains": "LIKE UPPER(?)", "gt": "> ?", + ">": "> ?", "gte": ">= ?", + ">=": ">= ?", "lt": "< ?", + "<": "< ?", "lte": "<= ?", + "<=": "<= ?", "eq": "= ?", + "=": "= ?", "ne": "!= ?", + "!=": "!= ?", "startswith": "LIKE ?", "endswith": "LIKE ?", "istartswith": "LIKE UPPER(?)", diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go index 1d62ee34..bd9f5d3b 100644 --- a/orm/db_sqlite.go +++ b/orm/db_sqlite.go @@ -28,11 +28,17 @@ var sqliteOperators = map[string]string{ "contains": "LIKE ? ESCAPE '\\'", "icontains": "LIKE ? ESCAPE '\\'", "gt": "> ?", + ">": "> ?", "gte": ">= ?", + ">=": ">= ?", "lt": "< ?", + "<": "< ?", "lte": "<= ?", + "<=": "<= ?", "eq": "= ?", + "=": "= ?", "ne": "!= ?", + "!=": "!= ?", "startswith": "LIKE ? ESCAPE '\\'", "endswith": "LIKE ? ESCAPE '\\'", "istartswith": "LIKE ? ESCAPE '\\'", From 15e11931fcd85128cada95daba8b72b789d34a97 Mon Sep 17 00:00:00 2001 From: "Mr. Myy" <1135038815@qq.com> Date: Thu, 30 Jul 2020 10:53:30 +0800 Subject: [PATCH 051/207] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=AF=B9=20BConfig.L?= =?UTF-8?q?isten.ClientAuth=20=E5=AD=97=E6=AE=B5=E7=9A=84=E9=80=BB?= =?UTF-8?q?=E8=BE=91=E5=A4=84=E7=90=86=E3=80=82=E5=BD=93=E6=8C=87=E5=AE=9A?= =?UTF-8?q?=E4=BA=86=E8=AF=A5=E9=85=8D=E7=BD=AE=E6=97=B6=EF=BC=8C=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E9=85=8D=E7=BD=AE=E7=9A=84=E5=80=BC=E6=9D=A5=E4=BD=9C?= =?UTF-8?q?=E4=B8=BA=E9=AA=8C=E8=AF=81=E5=AE=A2=E6=88=B7=E7=AB=AF=E7=9A=84?= =?UTF-8?q?=E6=96=B9=E5=BC=8F=E3=80=82=E5=A6=82=E6=9E=9C=E6=B2=A1=E6=8C=87?= =?UTF-8?q?=E5=AE=9A=EF=BC=8C=E4=BD=BF=E7=94=A8=E9=BB=98=E8=AE=A4=E5=80=BC?= =?UTF-8?q?=20tls.RequireAndVerifyClientCert?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/app.go b/app.go index f3fe6f7b..5e595bb6 100644 --- a/app.go +++ b/app.go @@ -195,10 +195,15 @@ func (app *App) Run(mws ...MiddleWare) { return } pool.AppendCertsFromPEM(data) - app.Server.TLSConfig = &tls.Config{ + tlsConfig := tls.Config{ ClientCAs: pool, - ClientAuth: tls.RequireAndVerifyClientCert, } + if string(BConfig.Listen.ClientAuth) != "" { + tslConfig.ClientAuth = BConfig.Listen.ClientAuth + } else { + tslConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + app.Server.TLSConfig = &tslConfig } if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { logs.Critical("ListenAndServeTLS: ", err) From 513a4afff14c056f1fe5a844d8a60123987491f3 Mon Sep 17 00:00:00 2001 From: "Mr. Myy" <1135038815@qq.com> Date: Thu, 30 Jul 2020 10:59:32 +0800 Subject: [PATCH 052/207] =?UTF-8?q?=E5=AF=B9=20Listen=20=E7=BB=93=E6=9E=84?= =?UTF-8?q?=E4=BD=93=E5=A2=9E=E5=8A=A0=20ClientAuth=20=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 对 Listen 结构体增加 ClientAuth 字段,赋予默认配置对象该字段值为 tls.VerifyClientCertIfGiven,与原代码逻辑的默认值保持一致 --- config.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/config.go b/config.go index 92aa3bbd..fef6c482 100644 --- a/config.go +++ b/config.go @@ -21,6 +21,7 @@ import ( "reflect" "runtime" "strings" + "crypto/tls" "github.com/astaxie/beego/config" "github.com/astaxie/beego/context" @@ -65,6 +66,7 @@ type Listen struct { HTTPSCertFile string HTTPSKeyFile string TrustCaFile string + ClientAuth tls.ClientAuthType EnableAdmin bool AdminAddr string AdminPort int @@ -234,6 +236,7 @@ func newBConfig() *Config { AdminPort: 8088, EnableFcgi: false, EnableStdIo: false, + ClientAuth: tls.VerifyClientCertIfGiven, }, WebConfig: WebConfig{ AutoRender: true, From 9d23e5a3fb23df1ac647660c13f566fa17e81c1e Mon Sep 17 00:00:00 2001 From: "Mr. Myy" <1135038815@qq.com> Date: Thu, 30 Jul 2020 11:03:32 +0800 Subject: [PATCH 053/207] =?UTF-8?q?=E7=AE=80=E5=8C=96=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E5=86=99=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/app.go b/app.go index 5e595bb6..85fe7e5d 100644 --- a/app.go +++ b/app.go @@ -197,11 +197,10 @@ func (app *App) Run(mws ...MiddleWare) { pool.AppendCertsFromPEM(data) tlsConfig := tls.Config{ ClientCAs: pool, + ClientAuth: tls.RequireAndVerifyClientCert, } if string(BConfig.Listen.ClientAuth) != "" { tslConfig.ClientAuth = BConfig.Listen.ClientAuth - } else { - tslConfig.ClientAuth = tls.RequireAndVerifyClientCert } app.Server.TLSConfig = &tslConfig } From c46ba862157b9b6d5d39561985fa7f4d8a8bdd18 Mon Sep 17 00:00:00 2001 From: "Mr. Myy" <1135038815@qq.com> Date: Thu, 30 Jul 2020 11:18:14 +0800 Subject: [PATCH 054/207] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E7=AC=94=E8=AF=AF?= =?UTF-8?q?=E4=BA=A7=E7=94=9F=E7=9A=84=E6=8B=BC=E5=86=99=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.go b/app.go index 85fe7e5d..20af4ce8 100644 --- a/app.go +++ b/app.go @@ -200,7 +200,7 @@ func (app *App) Run(mws ...MiddleWare) { ClientAuth: tls.RequireAndVerifyClientCert, } if string(BConfig.Listen.ClientAuth) != "" { - tslConfig.ClientAuth = BConfig.Listen.ClientAuth + tlsConfig.ClientAuth = BConfig.Listen.ClientAuth } app.Server.TLSConfig = &tslConfig } From 0815e77f9af9336b17a50d6a4226aad9a1969731 Mon Sep 17 00:00:00 2001 From: "Mr. Myy" <1135038815@qq.com> Date: Thu, 30 Jul 2020 11:20:22 +0800 Subject: [PATCH 055/207] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E7=AC=94=E8=AF=AF?= =?UTF-8?q?=E4=BA=A7=E7=94=9F=E7=9A=84=E6=8B=BC=E5=86=99=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.go b/app.go index 20af4ce8..49cef256 100644 --- a/app.go +++ b/app.go @@ -202,7 +202,7 @@ func (app *App) Run(mws ...MiddleWare) { if string(BConfig.Listen.ClientAuth) != "" { tlsConfig.ClientAuth = BConfig.Listen.ClientAuth } - app.Server.TLSConfig = &tslConfig + app.Server.TLSConfig = &tlsConfig } if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { logs.Critical("ListenAndServeTLS: ", err) From 520380416557c76a59a2c30398e027c27ad70d36 Mon Sep 17 00:00:00 2001 From: "Mr. Myy" <1135038815@qq.com> Date: Thu, 30 Jul 2020 14:46:17 +0800 Subject: [PATCH 056/207] =?UTF-8?q?=E8=B0=83=E6=95=B4=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E4=B8=AD=E7=9A=84=20ClientAuth=20=E5=80=BC?= =?UTF-8?q?=EF=BC=8C=E4=BD=BF=E4=B9=8B=E4=B8=8E=E5=8E=9F=E6=9D=A5=E7=9A=84?= =?UTF-8?q?=E8=A1=8C=E4=B8=BA=E4=BF=9D=E6=8C=81=E4=B8=80=E8=87=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.go b/config.go index fef6c482..0c995293 100644 --- a/config.go +++ b/config.go @@ -236,7 +236,7 @@ func newBConfig() *Config { AdminPort: 8088, EnableFcgi: false, EnableStdIo: false, - ClientAuth: tls.VerifyClientCertIfGiven, + ClientAuth: tls.RequireAndVerifyClientCert, }, WebConfig: WebConfig{ AutoRender: true, From 7831638f3793d1b3058ffa0cbc1e8d0c818bd3e8 Mon Sep 17 00:00:00 2001 From: "Mr. Myy" <1135038815@qq.com> Date: Thu, 30 Jul 2020 14:48:46 +0800 Subject: [PATCH 057/207] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E5=A4=9A=E4=BD=99?= =?UTF-8?q?=E7=9A=84=E6=9D=A1=E4=BB=B6=E5=88=A4=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/app.go b/app.go index 49cef256..3dee8999 100644 --- a/app.go +++ b/app.go @@ -195,14 +195,10 @@ func (app *App) Run(mws ...MiddleWare) { return } pool.AppendCertsFromPEM(data) - tlsConfig := tls.Config{ + app.Server.TLSConfig = &tls.Config{ ClientCAs: pool, - ClientAuth: tls.RequireAndVerifyClientCert, + ClientAuth: BConfig.Listen.ClientAuth, } - if string(BConfig.Listen.ClientAuth) != "" { - tlsConfig.ClientAuth = BConfig.Listen.ClientAuth - } - app.Server.TLSConfig = &tlsConfig } if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { logs.Critical("ListenAndServeTLS: ", err) From 28e6b3b92450b0ca0e9c1342d400f8810e4d5e5a Mon Sep 17 00:00:00 2001 From: Phillip Stagnet Date: Mon, 3 Aug 2020 13:31:49 +0200 Subject: [PATCH 058/207] Add error to SessionExist interface Implement changed interface for all default providers as well and change tests accordingly --- session/couchbase/sess_couchbase.go | 6 +- session/ledis/ledis_session.go | 4 +- session/memcache/sess_memcache.go | 8 +-- session/mysql/sess_mysql.go | 10 ++- session/postgres/sess_postgresql.go | 10 ++- session/redis/sess_redis.go | 6 +- session/redis_cluster/redis_cluster.go | 6 +- session/redis_sentinel/sess_redis_sentinel.go | 6 +- session/sess_cookie.go | 4 +- session/sess_file.go | 9 ++- session/sess_file_test.go | 62 +++++++++++++++---- session/sess_mem.go | 6 +- session/session.go | 12 +++- session/ssdb/sess_ssdb.go | 8 +-- 14 files changed, 109 insertions(+), 48 deletions(-) diff --git a/session/couchbase/sess_couchbase.go b/session/couchbase/sess_couchbase.go index 707d042c..46ab07ab 100644 --- a/session/couchbase/sess_couchbase.go +++ b/session/couchbase/sess_couchbase.go @@ -179,16 +179,16 @@ func (cp *Provider) SessionRead(sid string) (session.Store, error) { // SessionExist Check couchbase session exist. // it checkes sid exist or not. -func (cp *Provider) SessionExist(sid string) bool { +func (cp *Provider) SessionExist(sid string) (bool, error) { cp.b = cp.getBucket() defer cp.b.Close() var doc []byte if err := cp.b.Get(sid, &doc); err != nil || doc == nil { - return false + return false, err } - return true + return true, nil } // SessionRegenerate remove oldsid and use sid to generate new session diff --git a/session/ledis/ledis_session.go b/session/ledis/ledis_session.go index ee81df67..4f578eac 100644 --- a/session/ledis/ledis_session.go +++ b/session/ledis/ledis_session.go @@ -132,9 +132,9 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check ledis session exist by sid -func (lp *Provider) SessionExist(sid string) bool { +func (lp *Provider) SessionExist(sid string) (bool, error) { count, _ := c.Exists([]byte(sid)) - return count != 0 + return count != 0, nil } // SessionRegenerate generate new sid for ledis session diff --git a/session/memcache/sess_memcache.go b/session/memcache/sess_memcache.go index 85a2d815..e76eb8a5 100644 --- a/session/memcache/sess_memcache.go +++ b/session/memcache/sess_memcache.go @@ -149,16 +149,16 @@ func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { } // SessionExist check memcache session exist by sid -func (rp *MemProvider) SessionExist(sid string) bool { +func (rp *MemProvider) SessionExist(sid string) (bool, error) { if client == nil { if err := rp.connectInit(); err != nil { - return false + return false, err } } if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { - return false + return false, err } - return true + return true, nil } // SessionRegenerate generate new sid for memcache session diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go index 301353ab..9f9547a7 100644 --- a/session/mysql/sess_mysql.go +++ b/session/mysql/sess_mysql.go @@ -164,13 +164,19 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check mysql session exist -func (mp *Provider) SessionExist(sid string) bool { +func (mp *Provider) SessionExist(sid string) (bool, error) { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte err := row.Scan(&sessiondata) - return err != sql.ErrNoRows + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil } // SessionRegenerate generate new sid for mysql session diff --git a/session/postgres/sess_postgresql.go b/session/postgres/sess_postgresql.go index 0b8b9645..d8a1e6de 100644 --- a/session/postgres/sess_postgresql.go +++ b/session/postgres/sess_postgresql.go @@ -178,13 +178,19 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check postgresql session exist -func (mp *Provider) SessionExist(sid string) bool { +func (mp *Provider) SessionExist(sid string) (bool, error) { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from session where session_key=$1", sid) var sessiondata []byte err := row.Scan(&sessiondata) - return err != sql.ErrNoRows + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil } // SessionRegenerate generate new sid for postgresql session diff --git a/session/redis/sess_redis.go b/session/redis/sess_redis.go index 5c382d61..439b14cb 100644 --- a/session/redis/sess_redis.go +++ b/session/redis/sess_redis.go @@ -211,14 +211,14 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis session exist by sid -func (rp *Provider) SessionExist(sid string) bool { +func (rp *Provider) SessionExist(sid string) (bool, error) { c := rp.poollist.Get() defer c.Close() if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 { - return false + return false, err } - return true + return true, nil } // SessionRegenerate generate new sid for redis session diff --git a/session/redis_cluster/redis_cluster.go b/session/redis_cluster/redis_cluster.go index 262fa2e3..d4e28327 100644 --- a/session/redis_cluster/redis_cluster.go +++ b/session/redis_cluster/redis_cluster.go @@ -176,12 +176,12 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis_cluster session exist by sid -func (rp *Provider) SessionExist(sid string) bool { +func (rp *Provider) SessionExist(sid string) (bool, error) { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { - return false + return false, err } - return true + return true, nil } // SessionRegenerate generate new sid for redis_cluster session diff --git a/session/redis_sentinel/sess_redis_sentinel.go b/session/redis_sentinel/sess_redis_sentinel.go index 6ecb2977..eead7a74 100644 --- a/session/redis_sentinel/sess_redis_sentinel.go +++ b/session/redis_sentinel/sess_redis_sentinel.go @@ -189,12 +189,12 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis_sentinel session exist by sid -func (rp *Provider) SessionExist(sid string) bool { +func (rp *Provider) SessionExist(sid string) (bool, error) { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { - return false + return false, err } - return true + return true, nil } // SessionRegenerate generate new sid for redis_sentinel session diff --git a/session/sess_cookie.go b/session/sess_cookie.go index 6ad5debc..30a7032e 100644 --- a/session/sess_cookie.go +++ b/session/sess_cookie.go @@ -147,8 +147,8 @@ func (pder *CookieProvider) SessionRead(sid string) (Store, error) { } // SessionExist Cookie session is always existed -func (pder *CookieProvider) SessionExist(sid string) bool { - return true +func (pder *CookieProvider) SessionExist(sid string) (bool, error) { + return true, nil } // SessionRegenerate Implement method, no used. diff --git a/session/sess_file.go b/session/sess_file.go index 47ad54a7..3345d5d0 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -176,17 +176,20 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { // SessionExist Check file session exist. // it checks the file named from sid exist or not. -func (fp *FileProvider) SessionExist(sid string) bool { +func (fp *FileProvider) SessionExist(sid string) (bool, error) { filepder.lock.Lock() defer filepder.lock.Unlock() if len(sid) < 2 { SLogger.Println("min length of session id is 2", sid) - return false + return false, nil } _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - return err == nil + if err != nil { + return false, nil + } + return true, nil } // SessionDestroy Remove all files in this save path diff --git a/session/sess_file_test.go b/session/sess_file_test.go index 021c43fc..1e155f91 100644 --- a/session/sess_file_test.go +++ b/session/sess_file_test.go @@ -56,16 +56,24 @@ func TestFileProvider_SessionExist(t *testing.T) { _ = fp.SessionInit(180, sessionPath) - if fp.SessionExist(sid) { + exists, err := fp.SessionExist(sid) + if err != nil{ + t.Error(err) + } + if exists { t.Error() } - _, err := fp.SessionRead(sid) + _, err = fp.SessionRead(sid) if err != nil { t.Error(err) } - if !fp.SessionExist(sid) { + exists, err = fp.SessionExist(sid) + if err != nil { + t.Error(err) + } + if !exists { t.Error() } } @@ -79,15 +87,27 @@ func TestFileProvider_SessionExist2(t *testing.T) { _ = fp.SessionInit(180, sessionPath) - if fp.SessionExist(sid) { + exists, err := fp.SessionExist(sid) + if err != nil { + t.Error(err) + } + if exists { t.Error() } - if fp.SessionExist("") { + exists, err = fp.SessionExist("") + if err != nil { + t.Error(err) + } + if exists { t.Error() } - if fp.SessionExist("1") { + exists, err = fp.SessionExist("1") + if err != nil { + t.Error(err) + } + if exists { t.Error() } } @@ -171,7 +191,11 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { t.Error(err) } - if !fp.SessionExist(sid) { + exists, err := fp.SessionExist(sid) + if err != nil { + t.Error(err) + } + if !exists { t.Error() } @@ -180,11 +204,19 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { t.Error(err) } - if fp.SessionExist(sid) { + exists, err = fp.SessionExist(sid) + if err != nil { + t.Error(err) + } + if exists { t.Error() } - if !fp.SessionExist(sidNew) { + exists, err = fp.SessionExist(sidNew) + if err != nil { + t.Error(err) + } + if !exists { t.Error() } } @@ -203,7 +235,11 @@ func TestFileProvider_SessionDestroy(t *testing.T) { t.Error(err) } - if !fp.SessionExist(sid) { + exists, err := fp.SessionExist(sid) + if err != nil { + t.Error(err) + } + if !exists { t.Error() } @@ -212,7 +248,11 @@ func TestFileProvider_SessionDestroy(t *testing.T) { t.Error(err) } - if fp.SessionExist(sid) { + exists, err = fp.SessionExist(sid) + if err != nil { + t.Error(err) + } + if exists { t.Error() } } diff --git a/session/sess_mem.go b/session/sess_mem.go index 64d8b056..bd69ff80 100644 --- a/session/sess_mem.go +++ b/session/sess_mem.go @@ -109,13 +109,13 @@ func (pder *MemProvider) SessionRead(sid string) (Store, error) { } // SessionExist check session store exist in memory session by sid -func (pder *MemProvider) SessionExist(sid string) bool { +func (pder *MemProvider) SessionExist(sid string) (bool, error) { pder.lock.RLock() defer pder.lock.RUnlock() if _, ok := pder.sessions[sid]; ok { - return true + return true, nil } - return false + return false, nil } // SessionRegenerate generate new sid for session store in memory session diff --git a/session/session.go b/session/session.go index eb85360a..92e35de4 100644 --- a/session/session.go +++ b/session/session.go @@ -56,7 +56,7 @@ type Store interface { type Provider interface { SessionInit(gclifetime int64, config string) error SessionRead(sid string) (Store, error) - SessionExist(sid string) bool + SessionExist(sid string) (bool, error) SessionRegenerate(oldsid, sid string) (Store, error) SessionDestroy(sid string) error SessionAll() int //get all active session @@ -211,8 +211,14 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se return nil, errs } - if sid != "" && manager.provider.SessionExist(sid) { - return manager.provider.SessionRead(sid) + if sid != "" { + exists, err := manager.provider.SessionExist(sid) + if err != nil { + return nil, err + } + if exists { + return manager.provider.SessionRead(sid) + } } // Generate a new session diff --git a/session/ssdb/sess_ssdb.go b/session/ssdb/sess_ssdb.go index de0c6360..9b9eee94 100644 --- a/session/ssdb/sess_ssdb.go +++ b/session/ssdb/sess_ssdb.go @@ -68,7 +68,7 @@ func (p *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist judged whether sid is exist in session -func (p *Provider) SessionExist(sid string) bool { +func (p *Provider) SessionExist(sid string) (bool, error) { if p.client == nil { if err := p.connectInit(); err != nil { panic(err) @@ -76,12 +76,12 @@ func (p *Provider) SessionExist(sid string) bool { } value, err := p.client.Get(sid) if err != nil { - panic(err) + return false, err } if value == nil || len(value.(string)) == 0 { - return false + return false, nil } - return true + return true, nil } // SessionRegenerate regenerate session with new sid and delete oldsid From 6f5c5bd3a65561db56aca26eae4a50abef8fa5b4 Mon Sep 17 00:00:00 2001 From: Phillip Stagnet Date: Mon, 3 Aug 2020 13:33:30 +0200 Subject: [PATCH 059/207] Change interface in session README --- session/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/session/README.md b/session/README.md index 6d0a297e..a5c3bd6d 100644 --- a/session/README.md +++ b/session/README.md @@ -101,7 +101,7 @@ Maybe you will find the **memory** provider is a good example. type Provider interface { SessionInit(gclifetime int64, config string) error SessionRead(sid string) (SessionStore, error) - SessionExist(sid string) bool + SessionExist(sid string) (bool, error) SessionRegenerate(oldsid, sid string) (SessionStore, error) SessionDestroy(sid string) error SessionAll() int //get all active session From a0d1c42daca7af6cbf3a6c73a89793a06cdbd4c7 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 3 Aug 2020 21:03:08 +0800 Subject: [PATCH 060/207] XSRF add secure and http only flag --- context/context.go | 2 +- context/context_test.go | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/context/context.go b/context/context.go index de248ed2..7c161ac0 100644 --- a/context/context.go +++ b/context/context.go @@ -150,7 +150,7 @@ func (ctx *Context) XSRFToken(key string, expire int64) string { token, ok := ctx.GetSecureCookie(key, "_xsrf") if !ok { token = string(utils.RandomCreateBytes(32)) - ctx.SetSecureCookie(key, "_xsrf", token, expire) + ctx.SetSecureCookie(key, "_xsrf", token, expire, "", "", true, true) } ctx._xsrfToken = token } diff --git a/context/context_test.go b/context/context_test.go index 7c0535e0..e81e8191 100644 --- a/context/context_test.go +++ b/context/context_test.go @@ -17,7 +17,10 @@ package context import ( "net/http" "net/http/httptest" + "strings" "testing" + + "github.com/stretchr/testify/assert" ) func TestXsrfReset_01(t *testing.T) { @@ -44,4 +47,8 @@ func TestXsrfReset_01(t *testing.T) { if token == c._xsrfToken { t.FailNow() } + + ck := c.ResponseWriter.Header().Get("Set-Cookie") + assert.True(t, strings.Contains(ck, "Secure")) + assert.True(t, strings.Contains(ck, "HttpOnly")) } From 79ffef90e37ac2692fe694b8ebae01ce2fbe2794 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Fri, 31 Jul 2020 21:43:11 +0800 Subject: [PATCH 061/207] support filter chain --- .gitignore | 3 + admin_test.go | 18 +++- pkg/app.go | 7 ++ pkg/controller_test.go | 8 +- pkg/filter.go | 62 ++++++++++++- pkg/filter_chain_test.go | 49 +++++++++++ pkg/router.go | 185 ++++++++++++++++++++------------------- pkg/template_test.go | 37 +++++--- 8 files changed, 261 insertions(+), 108 deletions(-) create mode 100644 pkg/filter_chain_test.go diff --git a/.gitignore b/.gitignore index e1b65291..43adebd5 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ *.swp *.swo beego.iml + +_beeTmp +_beeTmp2 diff --git a/admin_test.go b/admin_test.go index 3f3612e4..205c76c2 100644 --- a/admin_test.go +++ b/admin_test.go @@ -6,10 +6,11 @@ import ( "fmt" "net/http" "net/http/httptest" - "reflect" "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/astaxie/beego/toolbox" ) @@ -230,10 +231,19 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { t.Errorf("invalid response map length: got %d want %d", len(decodedResponseBody), len(expectedResponseBody)) } + assert.Equal(t, len(expectedResponseBody), len(decodedResponseBody)) + assert.Equal(t, 2, len(decodedResponseBody)) - if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { - t.Errorf("handler returned unexpected body: got %v want %v", - decodedResponseBody, expectedResponseBody) + var database, cache map[string]interface{} + if decodedResponseBody[0]["message"] == "database" { + database = decodedResponseBody[0] + cache = decodedResponseBody[1] + } else { + database = decodedResponseBody[1] + cache = decodedResponseBody[0] } + assert.Equal(t, expectedResponseBody[0], database) + assert.Equal(t, expectedResponseBody[1], cache) + } diff --git a/pkg/app.go b/pkg/app.go index eb672b1f..d94d56b5 100644 --- a/pkg/app.go +++ b/pkg/app.go @@ -495,3 +495,10 @@ func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *A BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) return BeeApp } + +// InsertFilterChain adds a FilterFunc built by filterChain. +// This filter will be executed before all filters. +func InsertFilterChain(pattern string, filterChain FilterChain, params ...bool) *App { + BeeApp.Handlers.InsertFilterChain(pattern, filterChain, params...) + return BeeApp +} diff --git a/pkg/controller_test.go b/pkg/controller_test.go index f51cc109..e30f7211 100644 --- a/pkg/controller_test.go +++ b/pkg/controller_test.go @@ -19,6 +19,8 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/astaxie/beego/pkg/context" "os" "path/filepath" @@ -125,8 +127,10 @@ func TestGetUint64(t *testing.T) { } func TestAdditionalViewPaths(t *testing.T) { - dir1 := "_beeTmp" - dir2 := "_beeTmp2" + wkdir, err := os.Getwd() + assert.Nil(t, err) + dir1 := filepath.Join(wkdir, "_beeTmp", "TestAdditionalViewPaths") + dir2 := filepath.Join(wkdir, "_beeTmp2", "TestAdditionalViewPaths") defer os.RemoveAll(dir1) defer os.RemoveAll(dir2) diff --git a/pkg/filter.go b/pkg/filter.go index 4e212e06..543d7901 100644 --- a/pkg/filter.go +++ b/pkg/filter.go @@ -14,10 +14,19 @@ package beego -import "github.com/astaxie/beego/pkg/context" +import ( + "strings" + + "github.com/astaxie/beego/pkg/context" +) + +// FilterChain is different from pure FilterFunc +// when you use this, you must invoke next(ctx) inside the FilterFunc which is returned +// And all those FilterChain will be invoked before other FilterFunc +type FilterChain func(next FilterFunc) FilterFunc // FilterFunc defines a filter function which is invoked before the controller handler is executed. -type FilterFunc func(*context.Context) +type FilterFunc func(ctx *context.Context) // FilterRouter defines a filter operation which is invoked before the controller handler is executed. // It can match the URL against a pattern, and execute a filter function @@ -30,6 +39,55 @@ type FilterRouter struct { resetParams bool } +// params is for: +// 1. setting the returnOnOutput value (false allows multiple filters to execute) +// 2. determining whether or not params need to be reset. +func newFilterRouter(pattern string, routerCaseSensitive bool, filter FilterFunc, params ...bool) *FilterRouter { + mr := &FilterRouter{ + tree: NewTree(), + pattern: pattern, + filterFunc: filter, + returnOnOutput: true, + } + if !routerCaseSensitive { + mr.pattern = strings.ToLower(pattern) + } + + paramsLen := len(params) + if paramsLen > 0 { + mr.returnOnOutput = params[0] + } + if paramsLen > 1 { + mr.resetParams = params[1] + } + mr.tree.AddRouter(pattern, true) + return mr +} + +// filter will check whether we need to execute the filter logic +// return (started, done) +func (f *FilterRouter) filter(ctx *context.Context, urlPath string, preFilterParams map[string]string) (bool, bool) { + if f.returnOnOutput && ctx.ResponseWriter.Started { + return true, true + } + if f.resetParams { + preFilterParams = ctx.Input.Params() + } + if ok := f.ValidRouter(urlPath, ctx); ok { + f.filterFunc(ctx) + if f.resetParams { + ctx.Input.ResetParams() + for k, v := range preFilterParams { + ctx.Input.SetParam(k, v) + } + } + } + if f.returnOnOutput && ctx.ResponseWriter.Started { + return true, true + } + return false, false +} + // ValidRouter checks if the current request is matched by this filter. // If the request is matched, the values of the URL parameters defined // by the filter pattern are also returned. diff --git a/pkg/filter_chain_test.go b/pkg/filter_chain_test.go new file mode 100644 index 00000000..42397a60 --- /dev/null +++ b/pkg/filter_chain_test.go @@ -0,0 +1,49 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/context" +) + +func TestControllerRegister_InsertFilterChain(t *testing.T) { + + InsertFilterChain("/*", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("filter", "filter-chain") + next(ctx) + } + }) + + ns := NewNamespace("/chain") + + ns.Get("/*", func(ctx *context.Context) { + ctx.Output.Body([]byte("hello")) + }) + + + r, _ := http.NewRequest("GET", "/chain/user", nil) + w := httptest.NewRecorder() + + BeeApp.Handlers.ServeHTTP(w, r) + + assert.Equal(t, "filter-chain", w.Header().Get("filter")) +} diff --git a/pkg/router.go b/pkg/router.go index 995fb767..b0c23003 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -134,11 +134,14 @@ type ControllerRegister struct { enableFilter bool filters [FinishRouter + 1][]*FilterRouter pool sync.Pool + + // the filter created by FilterChain + chainRoot *FilterRouter } // NewControllerRegister returns a new ControllerRegister. func NewControllerRegister() *ControllerRegister { - return &ControllerRegister{ + res := &ControllerRegister{ routers: make(map[string]*Tree), policies: make(map[string]*Tree), pool: sync.Pool{ @@ -147,6 +150,8 @@ func NewControllerRegister() *ControllerRegister { }, }, } + res.chainRoot = newFilterRouter("/*", false, res.serveHttp) + return res } // Add controller handler and pattern rules to ControllerRegister. @@ -489,27 +494,28 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) // 1. setting the returnOnOutput value (false allows multiple filters to execute) // 2. determining whether or not params need to be reset. func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { - mr := &FilterRouter{ - tree: NewTree(), - pattern: pattern, - filterFunc: filter, - returnOnOutput: true, - } - if !BConfig.RouterCaseSensitive { - mr.pattern = strings.ToLower(pattern) - } - - paramsLen := len(params) - if paramsLen > 0 { - mr.returnOnOutput = params[0] - } - if paramsLen > 1 { - mr.resetParams = params[1] - } - mr.tree.AddRouter(pattern, true) + mr := newFilterRouter(pattern, BConfig.RouterCaseSensitive, filter, params...) return p.insertFilterRouter(pos, mr) } +// InsertFilterChain is similar to InsertFilter, +// but it will using chainRoot.filterFunc as input to build a new filterFunc +// for example, assume that chainRoot is funcA +// and we add new FilterChain +// fc := func(next) { +// return func(ctx) { +// // do something +// next(ctx) +// // do something +// } +// } +func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, params...bool) { + root := p.chainRoot + filterFunc := chain(root.filterFunc) + p.chainRoot = newFilterRouter(pattern, BConfig.RouterCaseSensitive, filterFunc, params...) +} + + // add Filter into func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { if pos < BeforeStatic || pos > FinishRouter { @@ -668,23 +674,9 @@ func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName str func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) { var preFilterParams map[string]string for _, filterR := range p.filters[pos] { - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true - } - if filterR.resetParams { - preFilterParams = context.Input.Params() - } - if ok := filterR.ValidRouter(urlPath, context); ok { - filterR.filterFunc(context) - if filterR.resetParams { - context.Input.ResetParams() - for k, v := range preFilterParams { - context.Input.SetParam(k, v) - } - } - } - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true + b, done := filterR.filter(context, urlPath, preFilterParams) + if done { + return b } } return false @@ -692,7 +684,20 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str // Implement http.Handler interface. func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + + ctx := p.GetContext() + + ctx.Reset(rw, r) + defer p.GiveBackContext(ctx) + + var preFilterParams map[string]string + p.chainRoot.filter(ctx, p.getUrlPath(ctx), preFilterParams) +} + +func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { startTime := time.Now() + r := ctx.Request + rw := ctx.ResponseWriter.ResponseWriter var ( runRouter reflect.Type findRouter bool @@ -701,108 +706,100 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) routerInfo *ControllerInfo isRunnable bool ) - context := p.GetContext() - context.Reset(rw, r) - - defer p.GiveBackContext(context) if BConfig.RecoverFunc != nil { - defer BConfig.RecoverFunc(context) + defer BConfig.RecoverFunc(ctx) } - context.Output.EnableGzip = BConfig.EnableGzip + ctx.Output.EnableGzip = BConfig.EnableGzip if BConfig.RunMode == DEV { - context.Output.Header("Server", BConfig.ServerName) + ctx.Output.Header("Server", BConfig.ServerName) } - var urlPath = r.URL.Path - - if !BConfig.RouterCaseSensitive { - urlPath = strings.ToLower(urlPath) - } + urlPath := p.getUrlPath(ctx) // filter wrong http method if !HTTPMETHOD[r.Method] { - exception("405", context) + exception("405", ctx) goto Admin } // filter for static file - if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) { + if len(p.filters[BeforeStatic]) > 0 && p.execFilter(ctx, urlPath, BeforeStatic) { goto Admin } - serverStaticRouter(context) + serverStaticRouter(ctx) - if context.ResponseWriter.Started { + if ctx.ResponseWriter.Started { findRouter = true goto Admin } if r.Method != http.MethodGet && r.Method != http.MethodHead { - if BConfig.CopyRequestBody && !context.Input.IsUpload() { + if BConfig.CopyRequestBody && !ctx.Input.IsUpload() { // connection will close if the incoming data are larger (RFC 7231, 6.5.11) if r.ContentLength > BConfig.MaxMemory { logs.Error(errors.New("payload too large")) - exception("413", context) + exception("413", ctx) goto Admin } - context.Input.CopyBody(BConfig.MaxMemory) + ctx.Input.CopyBody(BConfig.MaxMemory) } - context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) + ctx.Input.ParseFormOrMulitForm(BConfig.MaxMemory) } // session init if BConfig.WebConfig.Session.SessionOn { var err error - context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) + ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) if err != nil { logs.Error(err) - exception("503", context) + exception("503", ctx) goto Admin } defer func() { - if context.Input.CruSession != nil { - context.Input.CruSession.SessionRelease(rw) + if ctx.Input.CruSession != nil { + ctx.Input.CruSession.SessionRelease(rw) } }() } - if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) { + if len(p.filters[BeforeRouter]) > 0 && p.execFilter(ctx, urlPath, BeforeRouter) { goto Admin } // User can define RunController and RunMethod in filter - if context.Input.RunController != nil && context.Input.RunMethod != "" { + if ctx.Input.RunController != nil && ctx.Input.RunMethod != "" { findRouter = true - runMethod = context.Input.RunMethod - runRouter = context.Input.RunController + runMethod = ctx.Input.RunMethod + runRouter = ctx.Input.RunController } else { - routerInfo, findRouter = p.FindRouter(context) + routerInfo, findRouter = p.FindRouter(ctx) } // if no matches to url, throw a not found exception if !findRouter { - exception("404", context) + exception("404", ctx) goto Admin } - if splat := context.Input.Param(":splat"); splat != "" { + if splat := ctx.Input.Param(":splat"); splat != "" { for k, v := range strings.Split(splat, "/") { - context.Input.SetParam(strconv.Itoa(k), v) + ctx.Input.SetParam(strconv.Itoa(k), v) } } if routerInfo != nil { // store router pattern into context - context.Input.SetData("RouterPattern", routerInfo.pattern) + ctx.Input.SetData("RouterPattern", routerInfo.pattern) } // execute middleware filters - if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) { + if len(p.filters[BeforeExec]) > 0 && p.execFilter(ctx, urlPath, BeforeExec) { goto Admin } // check policies - if p.execPolicy(context, urlPath) { + if p.execPolicy(ctx, urlPath) { goto Admin } @@ -810,22 +807,22 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if routerInfo.routerType == routerTypeRESTFul { if _, ok := routerInfo.methods[r.Method]; ok { isRunnable = true - routerInfo.runFunction(context) + routerInfo.runFunction(ctx) } else { - exception("405", context) + exception("405", ctx) goto Admin } } else if routerInfo.routerType == routerTypeHandler { isRunnable = true - routerInfo.handler.ServeHTTP(context.ResponseWriter, context.Request) + routerInfo.handler.ServeHTTP(ctx.ResponseWriter, ctx.Request) } else { runRouter = routerInfo.controllerType methodParams = routerInfo.methodParams method := r.Method - if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut { + if r.Method == http.MethodPost && ctx.Input.Query("_method") == http.MethodPut { method = http.MethodPut } - if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete { + if r.Method == http.MethodPost && ctx.Input.Query("_method") == http.MethodDelete { method = http.MethodDelete } if m, ok := routerInfo.methods[method]; ok { @@ -854,7 +851,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } // call the controller init function - execController.Init(context, runRouter.Name(), runMethod, execController) + execController.Init(ctx, runRouter.Name(), runMethod, execController) // call prepare function execController.Prepare() @@ -863,14 +860,14 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if BConfig.WebConfig.EnableXSRF { execController.XSRFToken() if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut || - (r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) { + (r.Method == http.MethodPost && (ctx.Input.Query("_method") == http.MethodDelete || ctx.Input.Query("_method") == http.MethodPut)) { execController.CheckXSRFCookie() } } execController.URLMapping() - if !context.ResponseWriter.Started { + if !ctx.ResponseWriter.Started { // exec main logic switch runMethod { case http.MethodGet: @@ -893,18 +890,18 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if !execController.HandlerFunc(runMethod) { vc := reflect.ValueOf(execController) method := vc.MethodByName(runMethod) - in := param.ConvertParams(methodParams, method.Type(), context) + in := param.ConvertParams(methodParams, method.Type(), ctx) out := method.Call(in) // For backward compatibility we only handle response if we had incoming methodParams if methodParams != nil { - p.handleParamResponse(context, execController, out) + p.handleParamResponse(ctx, execController, out) } } } // render template - if !context.ResponseWriter.Started && context.Output.Status == 0 { + if !ctx.ResponseWriter.Started && ctx.Output.Status == 0 { if BConfig.WebConfig.AutoRender { if err := execController.Render(); err != nil { logs.Error(err) @@ -918,26 +915,26 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } // execute middleware filters - if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) { + if len(p.filters[AfterExec]) > 0 && p.execFilter(ctx, urlPath, AfterExec) { goto Admin } - if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) { + if len(p.filters[FinishRouter]) > 0 && p.execFilter(ctx, urlPath, FinishRouter) { goto Admin } Admin: // admin module record QPS - statusCode := context.ResponseWriter.Status + statusCode := ctx.ResponseWriter.Status if statusCode == 0 { statusCode = 200 } - LogAccess(context, &startTime, statusCode) + LogAccess(ctx, &startTime, statusCode) timeDur := time.Since(startTime) - context.ResponseWriter.Elapsed = timeDur + ctx.ResponseWriter.Elapsed = timeDur if BConfig.Listen.EnableAdmin { pattern := "" if routerInfo != nil { @@ -956,7 +953,7 @@ Admin: if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs { match := map[bool]string{true: "match", false: "nomatch"} devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", - context.Input.IP(), + ctx.Input.IP(), logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(), timeDur.String(), match[findRouter], @@ -969,11 +966,19 @@ Admin: logs.Debug(devInfo) } // Call WriteHeader if status code has been set changed - if context.Output.Status != 0 { - context.ResponseWriter.WriteHeader(context.Output.Status) + if ctx.Output.Status != 0 { + ctx.ResponseWriter.WriteHeader(ctx.Output.Status) } } +func (p *ControllerRegister) getUrlPath(ctx *beecontext.Context) string { + urlPath := ctx.Request.URL.Path + if !BConfig.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + return urlPath +} + func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) { // looping in reverse order for the case when both error and value are returned and error sets the response status code for i := len(results) - 1; i >= 0; i-- { diff --git a/pkg/template_test.go b/pkg/template_test.go index 590a7bd6..af948190 100644 --- a/pkg/template_test.go +++ b/pkg/template_test.go @@ -16,12 +16,15 @@ package beego import ( "bytes" - "github.com/astaxie/beego/pkg/testdata" - "github.com/elazarl/go-bindata-assetfs" "net/http" "os" "path/filepath" "testing" + + "github.com/elazarl/go-bindata-assetfs" + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/testdata" ) var header = `{{define "header"}} @@ -46,7 +49,9 @@ var block = `{{define "block"}} {{end}}` func TestTemplate(t *testing.T) { - dir := "_beeTmp" + wkdir, err := os.Getwd() + assert.Nil(t, err) + dir := filepath.Join(wkdir, "_beeTmp", "TestTemplate") files := []string{ "header.tpl", "index.tpl", @@ -56,7 +61,8 @@ func TestTemplate(t *testing.T) { t.Fatal(err) } for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + assert.Nil(t, dirErr) if f, err := os.Create(filepath.Join(dir, name)); err != nil { t.Fatal(err) } else { @@ -107,7 +113,9 @@ var user = ` ` func TestRelativeTemplate(t *testing.T) { - dir := "_beeTmp" + wkdir, err := os.Getwd() + assert.Nil(t, err) + dir := filepath.Join(wkdir, "_beeTmp") //Just add dir to known viewPaths if err := AddViewPath(dir); err != nil { @@ -218,7 +226,10 @@ var output = ` ` func TestTemplateLayout(t *testing.T) { - dir := "_beeTmp" + wkdir, err := os.Getwd() + assert.Nil(t, err) + + dir := filepath.Join(wkdir, "_beeTmp", "TestTemplateLayout") files := []string{ "add.tpl", "layout_blog.tpl", @@ -226,17 +237,22 @@ func TestTemplateLayout(t *testing.T) { if err := os.MkdirAll(dir, 0777); err != nil { t.Fatal(err) } + for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + assert.Nil(t, dirErr) if f, err := os.Create(filepath.Join(dir, name)); err != nil { t.Fatal(err) } else { if k == 0 { - f.WriteString(add) + _, writeErr := f.WriteString(add) + assert.Nil(t, writeErr) } else if k == 1 { - f.WriteString(layoutBlog) + _, writeErr := f.WriteString(layoutBlog) + assert.Nil(t, writeErr) } - f.Close() + clErr := f.Close() + assert.Nil(t, clErr) } } if err := AddViewPath(dir); err != nil { @@ -247,6 +263,7 @@ func TestTemplateLayout(t *testing.T) { t.Fatalf("should be 2 but got %v", len(beeTemplates)) } out := bytes.NewBufferString("") + if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { t.Fatal(err) } From 12b984861dcae1903cdff40844b85a801614e1ca Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Tue, 4 Aug 2020 19:21:26 +0800 Subject: [PATCH 062/207] fix CI fail for connection log test --- logs/conn_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/logs/conn_test.go b/logs/conn_test.go index bb377d41..32f96302 100644 --- a/logs/conn_test.go +++ b/logs/conn_test.go @@ -18,6 +18,7 @@ import ( "net" "os" "testing" + "time" ) // ConnTCPListener takes a TCP listener and accepts n TCP connections @@ -73,7 +74,7 @@ func TestReconnect(t *testing.T) { select { case second := <-newConns: second.Close() - default: + case <-time.After(5 * time.Second): t.Error("Did not reconnect") } } From 1961c1e4413b7aeb436b3fb276d2e1707c647282 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 4 Aug 2020 21:37:46 +0800 Subject: [PATCH 063/207] Revert "fix CI fail for connection log test" --- logs/conn_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/logs/conn_test.go b/logs/conn_test.go index 32f96302..bb377d41 100644 --- a/logs/conn_test.go +++ b/logs/conn_test.go @@ -18,7 +18,6 @@ import ( "net" "os" "testing" - "time" ) // ConnTCPListener takes a TCP listener and accepts n TCP connections @@ -74,7 +73,7 @@ func TestReconnect(t *testing.T) { select { case second := <-newConns: second.Close() - case <-time.After(5 * time.Second): + default: t.Error("Did not reconnect") } } From 6c6cf91741d6e4ab751132cdb9c75808c2ca004e Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 4 Aug 2020 22:16:55 +0800 Subject: [PATCH 064/207] Support prometheus and opentracing filter --- go.mod | 5 +- go.sum | 47 +++++++++++++-- metric/prometheus.go | 2 + pkg/web/doc.go | 16 ++++++ pkg/web/filter/opentracing/filter.go | 56 ++++++++++++++++++ pkg/web/filter/opentracing/filter_test.go | 47 +++++++++++++++ .../filter/prometheus/filter.go} | 57 +++++++------------ .../filter/prometheus/filter_test.go} | 34 ++++++----- 8 files changed, 206 insertions(+), 58 deletions(-) create mode 100644 pkg/web/doc.go create mode 100644 pkg/web/filter/opentracing/filter.go create mode 100644 pkg/web/filter/opentracing/filter_test.go rename pkg/{metric/prometheus.go => web/filter/prometheus/filter.go} (57%) rename pkg/{metric/prometheus_test.go => web/filter/prometheus/filter_test.go} (51%) diff --git a/go.mod b/go.mod index adca28ad..a6c27488 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a // indirect github.com/elastic/go-elasticsearch/v6 v6.8.5 github.com/elazarl/go-bindata-assetfs v1.0.0 + github.com/go-kit/kit v0.9.0 github.com/go-redis/redis v6.14.2+incompatible github.com/go-sql-driver/mysql v1.5.0 github.com/gogo/protobuf v1.1.1 @@ -21,6 +22,7 @@ require ( github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6 github.com/lib/pq v1.0.0 github.com/mattn/go-sqlite3 v2.0.3+incompatible + github.com/opentracing/opentracing-go v1.2.0 github.com/pelletier/go-toml v1.2.0 // indirect github.com/prometheus/client_golang v1.7.0 github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644 @@ -29,7 +31,8 @@ require ( github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c // indirect github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b // indirect golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 - golang.org/x/tools v0.0.0-20200117065230-39095c1d176c + golang.org/x/net v0.0.0-20190620200207-3b0461eec859 // indirect + google.golang.org/grpc v1.31.0 // indirect gopkg.in/yaml.v2 v2.2.8 ) diff --git a/go.sum b/go.sum index c7b861ac..12b76333 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,4 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Knetic/govaluate v3.0.0+incompatible h1:7o6+MAPhYTCF0+fdvoz1xDedhRb4f6s9Tn1Tt7/WTEg= @@ -20,10 +21,13 @@ github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 h1:rRISKWyXfVx github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= github.com/casbin/casbin v1.7.0 h1:PuzlE8w0JBg/DhIqnkF1Dewf3z+qmUZMVN07PonvVUQ= github.com/casbin/casbin v1.7.0/go.mod h1:c67qKN6Oum3UF5Q1+BByfFxkwKvhwW57ITjqwtzR1KE= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 h1:F1EaeKL/ta07PY/k9Os/UFtwERei2/XzGemhpGnBKNg= github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/couchbase/go-couchbase v0.0.0-20200519150804-63f3cdb75e0d h1:OMrhQqj1QCyDT2sxHCDjE+k8aMdn2ngTCGG7g4wrdLo= github.com/couchbase/go-couchbase v0.0.0-20200519150804-63f3cdb75e0d/go.mod h1:TWI8EKQMs5u5jLKW/tsb9VwauIrMIxQG1r5fMsswK5U= github.com/couchbase/gomemcached v0.0.0-20200526233749-ec430f949808 h1:8s2l8TVUwMXl6tZMe3+hPCRJ25nQXiA3d1x622JtOqc= @@ -41,25 +45,34 @@ github.com/elastic/go-elasticsearch/v6 v6.8.5 h1:U2HtkBseC1FNBmDr0TR2tKltL6FxoY+ github.com/elastic/go-elasticsearch/v6 v6.8.5/go.mod h1:UwaDJsD3rWLM5rKNFzv9hgox93HoX8utj1kxD9aFUcI= github.com/elazarl/go-bindata-assetfs v1.0.0 h1:G/bYguwHIzWq9ZoyUQqrjTmJbbYn3j3CKKpKinvZLFk= github.com/elazarl/go-bindata-assetfs v1.0.0/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/glendc/gopher-json v0.0.0-20170414221815-dc4743023d0c/go.mod h1:Gja1A+xZ9BoviGJNA2E9vFkPjjsl+CoJxSXiQM1UXtw= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.9.0 h1:wDJmvq38kDhkVxi50ni9ykkdUr1PKgqKOoi01fa0Mdk= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0 h1:MP4Eh7ZCb31lleYCFuwm0oe4/YGak+5l1vA2NOE80nA= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-redis/redis v6.14.2+incompatible h1:UE9pLhzmWf+xHNmZsoccjXosPicuiNaInPgym8nzfg0= github.com/go-redis/redis v6.14.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d h1:xy93KVe+KrIIwWDEAfQBdIfsiHJkepbYsDr+VY3g9/o= github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= @@ -72,6 +85,7 @@ github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pO github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= @@ -85,6 +99,7 @@ github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCV github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj2vyii2bbUNDw3kt9VxK2EY= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -110,11 +125,12 @@ github.com/onsi/ginkgo v1.12.0 h1:Iw5WCbBcaAAd0fpRb1c9r5YCylv4XDoCSigm1zLevwU= github.com/onsi/ginkgo v1.12.0/go.mod h1:oUhWkIvk5aDxtKvDDuw8gItl8pKl42LzjC9KZE0HfGg= github.com/onsi/gomega v1.7.1 h1:K0jcRCwNQM3vFGh1ppMtDh/+7ApJrjldlX8fA0jDTLQ= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= +github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/pelletier/go-toml v1.0.1/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/peterh/liner v1.0.1-0.20171122030339-3681c2a91233/go.mod h1:xIteQHvHuaLYG9IFj6mSxM0fCKrs34IrEQUhOYuGPHc= -github.com/pingcap/tidb v2.0.11+incompatible/go.mod h1:I8C6jrPINP2rrVunTRd7C9fRRhQrtR43S1/CL5ix/yQ= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= @@ -127,6 +143,7 @@ github.com/prometheus/client_golang v1.7.0 h1:wCi7urQOGBsYcQROHqpUUX4ct84xp40t9R github.com/prometheus/client_golang v1.7.0/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= @@ -164,19 +181,28 @@ golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnf golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a h1:gOpx8G595UYyvj8UK4+OFyY4rx037g3fmfhe5SasG3U= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859 h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -189,10 +215,21 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20200117065230-39095c1d176c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.31.0 h1:T7P4R73V3SSDPhH7WW7ATbfViLtmamH0DKrP3f9AuDI= +google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -215,3 +252,5 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/metric/prometheus.go b/metric/prometheus.go index 86e2c1b1..215896bd 100644 --- a/metric/prometheus.go +++ b/metric/prometheus.go @@ -27,6 +27,8 @@ import ( "github.com/astaxie/beego/logs" ) +// Deprecated: we will removed this function in 2.1.0 +// please use pkg/web/filter/prometheus#FilterChain func PrometheusMiddleWare(next http.Handler) http.Handler { summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ Name: "beego", diff --git a/pkg/web/doc.go b/pkg/web/doc.go new file mode 100644 index 00000000..2001f4ca --- /dev/null +++ b/pkg/web/doc.go @@ -0,0 +1,16 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// we will move all web related codes here +package web diff --git a/pkg/web/filter/opentracing/filter.go b/pkg/web/filter/opentracing/filter.go new file mode 100644 index 00000000..ef6521f7 --- /dev/null +++ b/pkg/web/filter/opentracing/filter.go @@ -0,0 +1,56 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentracing + +import ( + "github.com/opentracing/opentracing-go" + + beego "github.com/astaxie/beego/pkg" + "github.com/astaxie/beego/pkg/context" +) + +// FilterChainBuilder provides an extension point that we can support more configurations if necessary +type FilterChainBuilder struct { + // CustomSpanFunc makes users to custom the span. + CustomSpanFunc func(span opentracing.Span, ctx *context.Context) +} + + +func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.FilterFunc { + // TODO, if we support multiple servers, this need to be changed + cr := beego.BeeApp.Handlers + return func(ctx *context.Context) { + span := opentracing.SpanFromContext(ctx.Request.Context()) + spanCtx := ctx.Request.Context() + if span == nil { + operationName := ctx.Input.URL() + // it means that there is not any span, so we create a span as the root span. + route, found := cr.FindRouter(ctx) + if found { + operationName = route.GetPattern() + } + span, spanCtx = opentracing.StartSpanFromContext(spanCtx, operationName) + } + defer span.Finish() + next(ctx) + // if you think we need to do more things, feel free to create an issue to tell us + span.SetTag("status", ctx.Output.Status) + span.SetTag("method", ctx.Input.Method()) + span.SetTag("route", ctx.Input.GetData("RouterPattern")) + if builder.CustomSpanFunc != nil { + builder.CustomSpanFunc(span, ctx) + } + } +} diff --git a/pkg/web/filter/opentracing/filter_test.go b/pkg/web/filter/opentracing/filter_test.go new file mode 100644 index 00000000..65f1f24e --- /dev/null +++ b/pkg/web/filter/opentracing/filter_test.go @@ -0,0 +1,47 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentracing + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/opentracing/opentracing-go" + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/context" +) + +func TestFilterChainBuilder_FilterChain(t *testing.T) { + builder := &FilterChainBuilder{ + CustomSpanFunc: func(span opentracing.Span, ctx *context.Context) { + span.SetTag("aa", "bbb") + }, + } + + ctx := context.NewContext() + r, _ := http.NewRequest("GET", "/prometheus/user", nil) + w := httptest.NewRecorder() + ctx.Reset(w, r) + ctx.Input.SetData("RouterPattern", "my-route") + + filterFunc := builder.FilterChain(func(ctx *context.Context) { + ctx.Input.SetData("opentracing", true) + }) + + filterFunc(ctx) + assert.True(t, ctx.Input.GetData("opentracing").(bool)) +} diff --git a/pkg/metric/prometheus.go b/pkg/web/filter/prometheus/filter.go similarity index 57% rename from pkg/metric/prometheus.go rename to pkg/web/filter/prometheus/filter.go index 6a97aec5..bd47dcec 100644 --- a/pkg/metric/prometheus.go +++ b/pkg/web/filter/prometheus/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 astaxie +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,22 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -package metric +package prometheus import ( - "net/http" - "reflect" "strconv" "strings" "time" "github.com/prometheus/client_golang/prometheus" - "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/logs" + beego "github.com/astaxie/beego/pkg" + "github.com/astaxie/beego/pkg/context" ) -func PrometheusMiddleWare(next http.Handler) http.Handler { +// FilterChainBuilder is an extension point, +// when we want to support some configuration, +// please use this structure +type FilterChainBuilder struct { +} + +// FilterChain returns a FilterFunc. The filter will records some metrics +func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.FilterFunc { summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ Name: "beego", Subsystem: "http_request", @@ -43,12 +48,12 @@ func PrometheusMiddleWare(next http.Handler) http.Handler { registerBuildInfo() - return http.HandlerFunc(func(writer http.ResponseWriter, q *http.Request) { - start := time.Now() - next.ServeHTTP(writer, q) - end := time.Now() - go report(end.Sub(start), writer, q, summaryVec) - }) + return func(ctx *context.Context) { + startTime := time.Now() + next(ctx) + endTime := time.Now() + go report(endTime.Sub(startTime), ctx, summaryVec) + } } func registerBuildInfo() { @@ -73,27 +78,9 @@ func registerBuildInfo() { buildInfo.WithLabelValues().Set(1) } -func report(dur time.Duration, writer http.ResponseWriter, q *http.Request, vec *prometheus.SummaryVec) { - ctrl := beego.BeeApp.Handlers - ctx := ctrl.GetContext() - ctx.Reset(writer, q) - defer ctrl.GiveBackContext(ctx) - - // We cannot read the status code from q.Response.StatusCode - // since the http server does not set q.Response. So q.Response is nil - // Thus, we use reflection to read the status from writer whose concrete type is http.response - responseVal := reflect.ValueOf(writer).Elem() - field := responseVal.FieldByName("status") - status := -1 - if field.IsValid() && field.Kind() == reflect.Int { - status = int(field.Int()) - } - ptn := "UNKNOWN" - if rt, found := ctrl.FindRouter(ctx); found { - ptn = rt.GetPattern() - } else { - logs.Warn("we can not find the router info for this request, so request will be recorded as UNKNOWN: " + q.URL.String()) - } +func report(dur time.Duration, ctx *context.Context, vec *prometheus.SummaryVec) { + status := ctx.Output.Status + ptn := ctx.Input.GetData("RouterPattern").(string) ms := dur / time.Millisecond - vec.WithLabelValues(ptn, q.Method, strconv.Itoa(status), strconv.Itoa(int(ms))).Observe(float64(ms)) + vec.WithLabelValues(ptn, ctx.Input.Method(), strconv.Itoa(status), strconv.Itoa(int(ms))).Observe(float64(ms)) } diff --git a/pkg/metric/prometheus_test.go b/pkg/web/filter/prometheus/filter_test.go similarity index 51% rename from pkg/metric/prometheus_test.go rename to pkg/web/filter/prometheus/filter_test.go index e04c3285..7d2e2acf 100644 --- a/pkg/metric/prometheus_test.go +++ b/pkg/web/filter/prometheus/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 astaxie +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,31 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. -package metric +package prometheus import ( "net/http" - "net/url" + "net/http/httptest" "testing" - "time" - "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" "github.com/astaxie/beego/pkg/context" ) -func TestPrometheusMiddleWare(t *testing.T) { - middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) - writer := &context.Response{} - request := &http.Request{ - URL: &url.URL{ - Host: "localhost", - RawPath: "/a/b/c", - }, - Method: "POST", - } - vec := prometheus.NewSummaryVec(prometheus.SummaryOpts{}, []string{"pattern", "method", "status", "duration"}) +func TestFilterChain(t *testing.T) { + filter := (&FilterChainBuilder{}).FilterChain(func(ctx *context.Context) { + // do nothing + ctx.Input.SetData("invocation", true) + }) - report(time.Second, writer, request, vec) - middleware.ServeHTTP(writer, request) + ctx := context.NewContext() + r, _ := http.NewRequest("GET", "/prometheus/user", nil) + w := httptest.NewRecorder() + ctx.Reset(w, r) + ctx.Input.SetData("RouterPattern", "my-route") + filter(ctx) + assert.True(t, ctx.Input.GetData("invocation").(bool)) } From 261b704d8b4968e31a68fb31740c02cf315c837a Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 5 Aug 2020 06:34:54 +0000 Subject: [PATCH 065/207] Fix UT --- .gitignore | 9 +- pkg/template_test.go | 4 +- pkg/web/filter/opentracing/filter.go | 8 +- template_test.go | 4 +- {pkg/testdata => test}/Makefile | 0 {pkg/testdata => test}/bindata.go | 2 +- {pkg/testdata => test}/views/blocks/block.tpl | 0 {pkg/testdata => test}/views/header.tpl | 0 {pkg/testdata => test}/views/index.tpl | 0 testdata/Makefile | 2 - testdata/bindata.go | 296 ------------------ testdata/views/blocks/block.tpl | 3 - testdata/views/header.tpl | 3 - testdata/views/index.tpl | 15 - utils/file_test.go | 7 +- 15 files changed, 19 insertions(+), 334 deletions(-) rename {pkg/testdata => test}/Makefile (100%) rename {pkg/testdata => test}/bindata.go (99%) rename {pkg/testdata => test}/views/blocks/block.tpl (100%) rename {pkg/testdata => test}/views/header.tpl (100%) rename {pkg/testdata => test}/views/index.tpl (100%) delete mode 100644 testdata/Makefile delete mode 100644 testdata/bindata.go delete mode 100644 testdata/views/blocks/block.tpl delete mode 100644 testdata/views/header.tpl delete mode 100644 testdata/views/index.tpl diff --git a/.gitignore b/.gitignore index b70c76c4..304c4b73 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,8 @@ *.swo beego.iml -_beeTmp -_beeTmp2 -pkg/_beeTmp -pkg/_beeTmp2 +_beeTmp/ +_beeTmp2/ +pkg/_beeTmp/ +pkg/_beeTmp2/ +test/tmp/ diff --git a/pkg/template_test.go b/pkg/template_test.go index af948190..6e4a27fc 100644 --- a/pkg/template_test.go +++ b/pkg/template_test.go @@ -24,7 +24,7 @@ import ( "github.com/elazarl/go-bindata-assetfs" "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/testdata" + "github.com/astaxie/beego/test" ) var header = `{{define "header"}} @@ -308,7 +308,7 @@ var outputBinData = ` func TestFsBinData(t *testing.T) { SetTemplateFSFunc(func() http.FileSystem { - return TestingFileSystem{&assetfs.AssetFS{Asset: testdata.Asset, AssetDir: testdata.AssetDir, AssetInfo: testdata.AssetInfo}} + return TestingFileSystem{&assetfs.AssetFS{Asset: test.Asset, AssetDir: test.AssetDir, AssetInfo: test.AssetInfo}} }) dir := "views" if err := AddViewPath("views"); err != nil { diff --git a/pkg/web/filter/opentracing/filter.go b/pkg/web/filter/opentracing/filter.go index ef6521f7..8e332c7d 100644 --- a/pkg/web/filter/opentracing/filter.go +++ b/pkg/web/filter/opentracing/filter.go @@ -29,20 +29,22 @@ type FilterChainBuilder struct { func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.FilterFunc { - // TODO, if we support multiple servers, this need to be changed - cr := beego.BeeApp.Handlers return func(ctx *context.Context) { span := opentracing.SpanFromContext(ctx.Request.Context()) spanCtx := ctx.Request.Context() if span == nil { operationName := ctx.Input.URL() // it means that there is not any span, so we create a span as the root span. - route, found := cr.FindRouter(ctx) + // TODO, if we support multiple servers, this need to be changed + route, found := beego.BeeApp.Handlers.FindRouter(ctx) if found { operationName = route.GetPattern() } span, spanCtx = opentracing.StartSpanFromContext(spanCtx, operationName) + newReq := ctx.Request.Clone(spanCtx) + ctx.Reset(ctx.ResponseWriter.ResponseWriter, newReq) } + defer span.Finish() next(ctx) // if you think we need to do more things, feel free to create an issue to tell us diff --git a/template_test.go b/template_test.go index 287faadc..bde9c100 100644 --- a/template_test.go +++ b/template_test.go @@ -16,7 +16,7 @@ package beego import ( "bytes" - "github.com/astaxie/beego/testdata" + "github.com/astaxie/beego/test" "github.com/elazarl/go-bindata-assetfs" "net/http" "os" @@ -291,7 +291,7 @@ var outputBinData = ` func TestFsBinData(t *testing.T) { SetTemplateFSFunc(func() http.FileSystem { - return TestingFileSystem{&assetfs.AssetFS{Asset: testdata.Asset, AssetDir: testdata.AssetDir, AssetInfo: testdata.AssetInfo}} + return TestingFileSystem{&assetfs.AssetFS{Asset: test.Asset, AssetDir: test.AssetDir, AssetInfo: test.AssetInfo}} }) dir := "views" if err := AddViewPath("views"); err != nil { diff --git a/pkg/testdata/Makefile b/test/Makefile similarity index 100% rename from pkg/testdata/Makefile rename to test/Makefile diff --git a/pkg/testdata/bindata.go b/test/bindata.go similarity index 99% rename from pkg/testdata/bindata.go rename to test/bindata.go index beade103..9fda5075 100644 --- a/pkg/testdata/bindata.go +++ b/test/bindata.go @@ -5,7 +5,7 @@ // views/index.tpl // DO NOT EDIT! -package testdata +package test import ( "bytes" diff --git a/pkg/testdata/views/blocks/block.tpl b/test/views/blocks/block.tpl similarity index 100% rename from pkg/testdata/views/blocks/block.tpl rename to test/views/blocks/block.tpl diff --git a/pkg/testdata/views/header.tpl b/test/views/header.tpl similarity index 100% rename from pkg/testdata/views/header.tpl rename to test/views/header.tpl diff --git a/pkg/testdata/views/index.tpl b/test/views/index.tpl similarity index 100% rename from pkg/testdata/views/index.tpl rename to test/views/index.tpl diff --git a/testdata/Makefile b/testdata/Makefile deleted file mode 100644 index e80e8238..00000000 --- a/testdata/Makefile +++ /dev/null @@ -1,2 +0,0 @@ -build_view: - $(GOPATH)/bin/go-bindata-assetfs -pkg testdata views/... \ No newline at end of file diff --git a/testdata/bindata.go b/testdata/bindata.go deleted file mode 100644 index beade103..00000000 --- a/testdata/bindata.go +++ /dev/null @@ -1,296 +0,0 @@ -// Code generated by go-bindata. -// sources: -// views/blocks/block.tpl -// views/header.tpl -// views/index.tpl -// DO NOT EDIT! - -package testdata - -import ( - "bytes" - "compress/gzip" - "fmt" - "github.com/elazarl/go-bindata-assetfs" - "io" - "io/ioutil" - "os" - "path/filepath" - "strings" - "time" -) - -func bindataRead(data []byte, name string) ([]byte, error) { - gz, err := gzip.NewReader(bytes.NewBuffer(data)) - if err != nil { - return nil, fmt.Errorf("Read %q: %v", name, err) - } - - var buf bytes.Buffer - _, err = io.Copy(&buf, gz) - clErr := gz.Close() - - if err != nil { - return nil, fmt.Errorf("Read %q: %v", name, err) - } - if clErr != nil { - return nil, err - } - - return buf.Bytes(), nil -} - -type asset struct { - bytes []byte - info os.FileInfo -} - -type bindataFileInfo struct { - name string - size int64 - mode os.FileMode - modTime time.Time -} - -func (fi bindataFileInfo) Name() string { - return fi.name -} -func (fi bindataFileInfo) Size() int64 { - return fi.size -} -func (fi bindataFileInfo) Mode() os.FileMode { - return fi.mode -} -func (fi bindataFileInfo) ModTime() time.Time { - return fi.modTime -} -func (fi bindataFileInfo) IsDir() bool { - return false -} -func (fi bindataFileInfo) Sys() interface{} { - return nil -} - -var _viewsBlocksBlockTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xaa\xae\x4e\x49\x4d\xcb\xcc\x4b\x55\x50\x4a\xca\xc9\x4f\xce\x56\xaa\xad\xe5\xb2\xc9\x30\xb4\xf3\x48\xcd\xc9\xc9\xd7\x51\x00\x8b\x15\x2b\xda\xe8\x67\x18\xda\x71\x55\x57\xa7\xe6\xa5\xd4\xd6\x02\x02\x00\x00\xff\xff\xfd\xa1\x7a\xf6\x32\x00\x00\x00") - -func viewsBlocksBlockTplBytes() ([]byte, error) { - return bindataRead( - _viewsBlocksBlockTpl, - "views/blocks/block.tpl", - ) -} - -func viewsBlocksBlockTpl() (*asset, error) { - bytes, err := viewsBlocksBlockTplBytes() - if err != nil { - return nil, err - } - - info := bindataFileInfo{name: "views/blocks/block.tpl", size: 50, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)} - a := &asset{bytes: bytes, info: info} - return a, nil -} - -var _viewsHeaderTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xaa\xae\x4e\x49\x4d\xcb\xcc\x4b\x55\x50\xca\x48\x4d\x4c\x49\x2d\x52\xaa\xad\xe5\xb2\xc9\x30\xb4\xf3\x48\xcd\xc9\xc9\xd7\x51\x48\x2c\x2e\x49\xac\xc8\x4c\x55\xb4\xd1\xcf\x30\xb4\xe3\xaa\xae\x4e\xcd\x4b\xa9\xad\x05\x04\x00\x00\xff\xff\xe4\x12\x47\x01\x34\x00\x00\x00") - -func viewsHeaderTplBytes() ([]byte, error) { - return bindataRead( - _viewsHeaderTpl, - "views/header.tpl", - ) -} - -func viewsHeaderTpl() (*asset, error) { - bytes, err := viewsHeaderTplBytes() - if err != nil { - return nil, err - } - - info := bindataFileInfo{name: "views/header.tpl", size: 52, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)} - a := &asset{bytes: bytes, info: info} - return a, nil -} - -var _viewsIndexTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x64\x8f\xbd\x8a\xc3\x30\x10\x84\x6b\xeb\x29\xe6\xfc\x00\x16\xb8\x3c\x16\x35\x77\xa9\x13\x88\x09\xa4\xf4\xcf\x12\x99\x48\x48\xd8\x82\x10\x84\xde\x3d\xc8\x8a\x8b\x90\x6a\xa4\xd9\x6f\xd8\x59\xfa\xf9\x3f\xfe\x75\xd7\xd3\x01\x3a\x58\xa3\x04\x15\x01\x48\x73\x3f\xe5\x07\x40\x61\x0e\x86\xd5\xc0\x7c\x73\x78\xb0\x19\x9d\x65\x04\xb6\xde\xf4\x81\x49\x96\x69\x8e\xc8\x3d\x43\x83\x9b\x9e\x4a\x88\x2a\xc6\x9d\x43\x3d\x18\x37\xde\xeb\x94\x3e\xdd\x1c\xe1\xe5\xcb\xde\xe0\x55\x6e\xd2\x04\x6f\x32\x20\x2a\xd2\xad\x8a\x11\x4d\x97\x57\x22\x25\x92\xba\x55\xa2\x22\xaf\xd0\xe9\x79\xc5\xbc\xe2\xec\x2c\x5f\xfa\xe5\x17\x99\x7b\x7f\x36\xd2\x97\x8a\xa5\x19\xc9\x72\xe7\x2b\x00\x00\xff\xff\xb2\x39\xca\x9f\xff\x00\x00\x00") - -func viewsIndexTplBytes() ([]byte, error) { - return bindataRead( - _viewsIndexTpl, - "views/index.tpl", - ) -} - -func viewsIndexTpl() (*asset, error) { - bytes, err := viewsIndexTplBytes() - if err != nil { - return nil, err - } - - info := bindataFileInfo{name: "views/index.tpl", size: 255, mode: os.FileMode(436), modTime: time.Unix(1541434906, 0)} - a := &asset{bytes: bytes, info: info} - return a, nil -} - -// Asset loads and returns the asset for the given name. -// It returns an error if the asset could not be found or -// could not be loaded. -func Asset(name string) ([]byte, error) { - cannonicalName := strings.Replace(name, "\\", "/", -1) - if f, ok := _bindata[cannonicalName]; ok { - a, err := f() - if err != nil { - return nil, fmt.Errorf("Asset %s can't read by error: %v", name, err) - } - return a.bytes, nil - } - return nil, fmt.Errorf("Asset %s not found", name) -} - -// MustAsset is like Asset but panics when Asset would return an error. -// It simplifies safe initialization of global variables. -func MustAsset(name string) []byte { - a, err := Asset(name) - if err != nil { - panic("asset: Asset(" + name + "): " + err.Error()) - } - - return a -} - -// AssetInfo loads and returns the asset info for the given name. -// It returns an error if the asset could not be found or -// could not be loaded. -func AssetInfo(name string) (os.FileInfo, error) { - cannonicalName := strings.Replace(name, "\\", "/", -1) - if f, ok := _bindata[cannonicalName]; ok { - a, err := f() - if err != nil { - return nil, fmt.Errorf("AssetInfo %s can't read by error: %v", name, err) - } - return a.info, nil - } - return nil, fmt.Errorf("AssetInfo %s not found", name) -} - -// AssetNames returns the names of the assets. -func AssetNames() []string { - names := make([]string, 0, len(_bindata)) - for name := range _bindata { - names = append(names, name) - } - return names -} - -// _bindata is a table, holding each asset generator, mapped to its name. -var _bindata = map[string]func() (*asset, error){ - "views/blocks/block.tpl": viewsBlocksBlockTpl, - "views/header.tpl": viewsHeaderTpl, - "views/index.tpl": viewsIndexTpl, -} - -// AssetDir returns the file names below a certain -// directory embedded in the file by go-bindata. -// For example if you run go-bindata on data/... and data contains the -// following hierarchy: -// data/ -// foo.txt -// img/ -// a.png -// b.png -// then AssetDir("data") would return []string{"foo.txt", "img"} -// AssetDir("data/img") would return []string{"a.png", "b.png"} -// AssetDir("foo.txt") and AssetDir("notexist") would return an error -// AssetDir("") will return []string{"data"}. -func AssetDir(name string) ([]string, error) { - node := _bintree - if len(name) != 0 { - cannonicalName := strings.Replace(name, "\\", "/", -1) - pathList := strings.Split(cannonicalName, "/") - for _, p := range pathList { - node = node.Children[p] - if node == nil { - return nil, fmt.Errorf("Asset %s not found", name) - } - } - } - if node.Func != nil { - return nil, fmt.Errorf("Asset %s not found", name) - } - rv := make([]string, 0, len(node.Children)) - for childName := range node.Children { - rv = append(rv, childName) - } - return rv, nil -} - -type bintree struct { - Func func() (*asset, error) - Children map[string]*bintree -} - -var _bintree = &bintree{nil, map[string]*bintree{ - "views": &bintree{nil, map[string]*bintree{ - "blocks": &bintree{nil, map[string]*bintree{ - "block.tpl": &bintree{viewsBlocksBlockTpl, map[string]*bintree{}}, - }}, - "header.tpl": &bintree{viewsHeaderTpl, map[string]*bintree{}}, - "index.tpl": &bintree{viewsIndexTpl, map[string]*bintree{}}, - }}, -}} - -// RestoreAsset restores an asset under the given directory -func RestoreAsset(dir, name string) error { - data, err := Asset(name) - if err != nil { - return err - } - info, err := AssetInfo(name) - if err != nil { - return err - } - err = os.MkdirAll(_filePath(dir, filepath.Dir(name)), os.FileMode(0755)) - if err != nil { - return err - } - err = ioutil.WriteFile(_filePath(dir, name), data, info.Mode()) - if err != nil { - return err - } - err = os.Chtimes(_filePath(dir, name), info.ModTime(), info.ModTime()) - if err != nil { - return err - } - return nil -} - -// RestoreAssets restores an asset under the given directory recursively -func RestoreAssets(dir, name string) error { - children, err := AssetDir(name) - // File - if err != nil { - return RestoreAsset(dir, name) - } - // Dir - for _, child := range children { - err = RestoreAssets(dir, filepath.Join(name, child)) - if err != nil { - return err - } - } - return nil -} - -func _filePath(dir, name string) string { - cannonicalName := strings.Replace(name, "\\", "/", -1) - return filepath.Join(append([]string{dir}, strings.Split(cannonicalName, "/")...)...) -} - -func assetFS() *assetfs.AssetFS { - assetInfo := func(path string) (os.FileInfo, error) { - return os.Stat(path) - } - for k := range _bintree.Children { - return &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, AssetInfo: assetInfo, Prefix: k} - } - panic("unreachable") -} diff --git a/testdata/views/blocks/block.tpl b/testdata/views/blocks/block.tpl deleted file mode 100644 index 2a9c57fc..00000000 --- a/testdata/views/blocks/block.tpl +++ /dev/null @@ -1,3 +0,0 @@ -{{define "block"}} -

Hello, blocks!

-{{end}} \ No newline at end of file diff --git a/testdata/views/header.tpl b/testdata/views/header.tpl deleted file mode 100644 index 041fa403..00000000 --- a/testdata/views/header.tpl +++ /dev/null @@ -1,3 +0,0 @@ -{{define "header"}} -

Hello, astaxie!

-{{end}} \ No newline at end of file diff --git a/testdata/views/index.tpl b/testdata/views/index.tpl deleted file mode 100644 index 21b7fc06..00000000 --- a/testdata/views/index.tpl +++ /dev/null @@ -1,15 +0,0 @@ - - - - beego welcome template - - - - {{template "block"}} - {{template "header"}} - {{template "blocks/block.tpl"}} - -

{{ .Title }}

-

This is SomeVar: {{ .SomeVar }}

- - diff --git a/utils/file_test.go b/utils/file_test.go index b2644157..84443e20 100644 --- a/utils/file_test.go +++ b/utils/file_test.go @@ -18,6 +18,8 @@ import ( "path/filepath" "reflect" "testing" + + "github.com/stretchr/testify/assert" ) var noExistedFile = "/tmp/not_existed_file" @@ -66,9 +68,8 @@ func TestGrepFile(t *testing.T) { path := filepath.Join(".", "testdata", "grepe.test") lines, err := GrepFile(`^\s*[^#]+`, path) - if err != nil { - t.Error(err) - } + assert.Nil(t, err) + if !reflect.DeepEqual(lines, []string{"hello", "world"}) { t.Errorf("expect [hello world], but receive %v", lines) } From 882aa9b9674961f1708473fa26d7097a8bae8ce5 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 5 Aug 2020 21:57:20 +0800 Subject: [PATCH 066/207] Deprecated old web module --- admin.go | 3 +++ app.go | 20 +++++++++++++++ beego.go | 9 +++++++ build_info.go | 7 +++++ config.go | 27 ++++++++++++++++++++ controller.go | 68 +++++++++++++++++++++++++++++++++++++++++++++++++ doc.go | 2 ++ error.go | 4 +++ filter.go | 3 +++ flash.go | 9 +++++++ fs.go | 3 +++ namespace.go | 37 +++++++++++++++++++++++++++ policy.go | 3 +++ router.go | 29 +++++++++++++++++++++ template.go | 11 ++++++++ templatefunc.go | 18 +++++++++++++ tree.go | 5 ++++ 17 files changed, 258 insertions(+) diff --git a/admin.go b/admin.go index db52647e..d9c96dfd 100644 --- a/admin.go +++ b/admin.go @@ -52,6 +52,7 @@ var beeAdminApp *adminApp // return true // } // beego.FilterMonitorFunc = MyFilterMonitor. +// Deprecated: using pkg/, we will delete this in v2.1.0 var FilterMonitorFunc func(string, string, time.Duration, string, int) bool func init() { @@ -201,6 +202,7 @@ func list(root string, p interface{}, m M) { } // PrintTree prints all registered routers. +// Deprecated: using pkg/, we will delete this in v2.1.0 func PrintTree() M { var ( content = M{} @@ -432,6 +434,7 @@ func (admin *adminApp) Route(pattern string, f http.HandlerFunc) { // Run adminApp http server. // Its addr is defined in configuration file as adminhttpaddr and adminhttpport. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (admin *adminApp) Run() { if len(toolbox.AdminTaskList) > 0 { toolbox.StartTask() diff --git a/app.go b/app.go index 3dee8999..d86188c0 100644 --- a/app.go +++ b/app.go @@ -35,6 +35,7 @@ import ( var ( // BeeApp is an application instance + // Deprecated: using pkg/, we will delete this in v2.1.0 BeeApp *App ) @@ -50,6 +51,7 @@ type App struct { } // NewApp returns a new beego application. +// Deprecated: using pkg/, we will delete this in v2.1.0 func NewApp() *App { cr := NewControllerRegister() app := &App{Handlers: cr, Server: &http.Server{}} @@ -57,9 +59,11 @@ func NewApp() *App { } // MiddleWare function for http.Handler +// Deprecated: using pkg/, we will delete this in v2.1.0 type MiddleWare func(http.Handler) http.Handler // Run beego application. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (app *App) Run(mws ...MiddleWare) { addr := BConfig.Listen.HTTPAddr @@ -254,6 +258,7 @@ func (app *App) Run(mws ...MiddleWare) { // beego.Router("/api/create",&RestController{},"post:CreateFood") // beego.Router("/api/update",&RestController{},"put:UpdateFood") // beego.Router("/api/delete",&RestController{},"delete:DeleteFood") +// Deprecated: using pkg/, we will delete this in v2.1.0 func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { BeeApp.Handlers.Add(rootpath, c, mappingMethods...) return BeeApp @@ -268,6 +273,7 @@ func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *A // Usage (replace "GET" with "*" for all methods): // beego.UnregisterFixedRoute("/yourpreviouspath", "GET") // beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage") +// Deprecated: using pkg/, we will delete this in v2.1.0 func UnregisterFixedRoute(fixedRoute string, method string) *App { subPaths := splitPath(fixedRoute) if method == "" || method == "*" { @@ -364,6 +370,7 @@ func findAndRemoveSingleTree(entryPointTree *Tree) { // the comments @router url methodlist // url support all the function Router's pattern // methodlist [get post head put delete options *] +// Deprecated: using pkg/, we will delete this in v2.1.0 func Include(cList ...ControllerInterface) *App { BeeApp.Handlers.Include(cList...) return BeeApp @@ -372,6 +379,7 @@ func Include(cList ...ControllerInterface) *App { // RESTRouter adds a restful controller handler to BeeApp. // its' controller implements beego.ControllerInterface and // defines a param "pattern/:objectId" to visit each resource. +// Deprecated: using pkg/, we will delete this in v2.1.0 func RESTRouter(rootpath string, c ControllerInterface) *App { Router(rootpath, c) Router(path.Join(rootpath, ":objectId"), c) @@ -382,6 +390,7 @@ func RESTRouter(rootpath string, c ControllerInterface) *App { // it's same to App.AutoRouter. // if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, // visit the url /main/list to exec List function or /main/page to exec Page function. +// Deprecated: using pkg/, we will delete this in v2.1.0 func AutoRouter(c ControllerInterface) *App { BeeApp.Handlers.AddAuto(c) return BeeApp @@ -391,6 +400,7 @@ func AutoRouter(c ControllerInterface) *App { // it's same to App.AutoRouterWithPrefix. // if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, // visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. +// Deprecated: using pkg/, we will delete this in v2.1.0 func AutoPrefix(prefix string, c ControllerInterface) *App { BeeApp.Handlers.AddAutoPrefix(prefix, c) return BeeApp @@ -401,6 +411,7 @@ func AutoPrefix(prefix string, c ControllerInterface) *App { // beego.Get("/", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func Get(rootpath string, f FilterFunc) *App { BeeApp.Handlers.Get(rootpath, f) return BeeApp @@ -411,6 +422,7 @@ func Get(rootpath string, f FilterFunc) *App { // beego.Post("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func Post(rootpath string, f FilterFunc) *App { BeeApp.Handlers.Post(rootpath, f) return BeeApp @@ -421,6 +433,7 @@ func Post(rootpath string, f FilterFunc) *App { // beego.Delete("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func Delete(rootpath string, f FilterFunc) *App { BeeApp.Handlers.Delete(rootpath, f) return BeeApp @@ -431,6 +444,7 @@ func Delete(rootpath string, f FilterFunc) *App { // beego.Put("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func Put(rootpath string, f FilterFunc) *App { BeeApp.Handlers.Put(rootpath, f) return BeeApp @@ -441,6 +455,7 @@ func Put(rootpath string, f FilterFunc) *App { // beego.Head("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func Head(rootpath string, f FilterFunc) *App { BeeApp.Handlers.Head(rootpath, f) return BeeApp @@ -451,6 +466,7 @@ func Head(rootpath string, f FilterFunc) *App { // beego.Options("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func Options(rootpath string, f FilterFunc) *App { BeeApp.Handlers.Options(rootpath, f) return BeeApp @@ -461,6 +477,7 @@ func Options(rootpath string, f FilterFunc) *App { // beego.Patch("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func Patch(rootpath string, f FilterFunc) *App { BeeApp.Handlers.Patch(rootpath, f) return BeeApp @@ -471,6 +488,7 @@ func Patch(rootpath string, f FilterFunc) *App { // beego.Any("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func Any(rootpath string, f FilterFunc) *App { BeeApp.Handlers.Any(rootpath, f) return BeeApp @@ -481,6 +499,7 @@ func Any(rootpath string, f FilterFunc) *App { // beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { // fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) // })) +// Deprecated: using pkg/, we will delete this in v2.1.0 func Handler(rootpath string, h http.Handler, options ...interface{}) *App { BeeApp.Handlers.Handler(rootpath, h, options...) return BeeApp @@ -490,6 +509,7 @@ func Handler(rootpath string, h http.Handler, options ...interface{}) *App { // The pos means action constant including // beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) +// Deprecated: using pkg/, we will delete this in v2.1.0 func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) return BeeApp diff --git a/beego.go b/beego.go index 8ebe0bab..ef93134d 100644 --- a/beego.go +++ b/beego.go @@ -23,15 +23,19 @@ import ( const ( // VERSION represent beego web framework version. + // Deprecated: using pkg/, we will delete this in v2.1.0 VERSION = "1.12.2" // DEV is for develop + // Deprecated: using pkg/, we will delete this in v2.1.0 DEV = "dev" // PROD is for production + // Deprecated: using pkg/, we will delete this in v2.1.0 PROD = "prod" ) // M is Map shortcut +// Deprecated: using pkg/, we will delete this in v2.1.0 type M map[string]interface{} // Hook function to run @@ -44,6 +48,7 @@ var ( // AddAPPStartHook is used to register the hookfunc // The hookfuncs will run in beego.Run() // such as initiating session , starting middleware , building template, starting admin control and so on. +// Deprecated: using pkg/, we will delete this in v2.1.0 func AddAPPStartHook(hf ...hookfunc) { hooks = append(hooks, hf...) } @@ -53,6 +58,7 @@ func AddAPPStartHook(hf ...hookfunc) { // beego.Run("localhost") // beego.Run(":8089") // beego.Run("127.0.0.1:8089") +// Deprecated: using pkg/, we will delete this in v2.1.0 func Run(params ...string) { initBeforeHTTPRun() @@ -73,6 +79,7 @@ func Run(params ...string) { } // RunWithMiddleWares Run beego application with middlewares. +// Deprecated: using pkg/, we will delete this in v2.1.0 func RunWithMiddleWares(addr string, mws ...MiddleWare) { initBeforeHTTPRun() @@ -107,6 +114,7 @@ func initBeforeHTTPRun() { } // TestBeegoInit is for test package init +// Deprecated: using pkg/, we will delete this in v2.1.0 func TestBeegoInit(ap string) { path := filepath.Join(ap, "conf", "app.conf") os.Chdir(ap) @@ -114,6 +122,7 @@ func TestBeegoInit(ap string) { } // InitBeegoBeforeTest is for test package init +// Deprecated: using pkg/, we will delete this in v2.1.0 func InitBeegoBeforeTest(appConfigPath string) { if err := LoadAppConfig(appConfigProvider, appConfigPath); err != nil { panic(err) diff --git a/build_info.go b/build_info.go index c31152ea..896bbdf3 100644 --- a/build_info.go +++ b/build_info.go @@ -15,13 +15,20 @@ package beego var ( + // Deprecated: using pkg/, we will delete this in v2.1.0 BuildVersion string + // Deprecated: using pkg/, we will delete this in v2.1.0 BuildGitRevision string + // Deprecated: using pkg/, we will delete this in v2.1.0 BuildStatus string + // Deprecated: using pkg/, we will delete this in v2.1.0 BuildTag string + // Deprecated: using pkg/, we will delete this in v2.1.0 BuildTime string + // Deprecated: using pkg/, we will delete this in v2.1.0 GoVersion string + // Deprecated: using pkg/, we will delete this in v2.1.0 GitBranch string ) diff --git a/config.go b/config.go index 0c995293..d707542a 100644 --- a/config.go +++ b/config.go @@ -31,6 +31,7 @@ import ( ) // Config is the main struct for BConfig +// Deprecated: using pkg/, we will delete this in v2.1.0 type Config struct { AppName string //Application name RunMode string //Running Mode: dev | prod @@ -49,6 +50,7 @@ type Config struct { } // Listen holds for http and https related config +// Deprecated: using pkg/, we will delete this in v2.1.0 type Listen struct { Graceful bool // Graceful means use graceful module to start the server ServerTimeOut int64 @@ -75,6 +77,7 @@ type Listen struct { } // WebConfig holds web related config +// Deprecated: using pkg/, we will delete this in v2.1.0 type WebConfig struct { AutoRender bool EnableDocs bool @@ -95,6 +98,7 @@ type WebConfig struct { } // SessionConfig holds session related config +// Deprecated: using pkg/, we will delete this in v2.1.0 type SessionConfig struct { SessionOn bool SessionProvider string @@ -111,6 +115,7 @@ type SessionConfig struct { } // LogConfig holds Log related config +// Deprecated: using pkg/, we will delete this in v2.1.0 type LogConfig struct { AccessLogs bool EnableStaticLogs bool //log static files requests default: false @@ -121,12 +126,16 @@ type LogConfig struct { var ( // BConfig is the default config for Application + // Deprecated: using pkg/, we will delete this in v2.1.0 BConfig *Config // AppConfig is the instance of Config, store the config information from file + // Deprecated: using pkg/, we will delete this in v2.1.0 AppConfig *beegoAppConfig // AppPath is the absolute path to the app + // Deprecated: using pkg/, we will delete this in v2.1.0 AppPath string // GlobalSessions is the instance for the session manager + // Deprecated: using pkg/, we will delete this in v2.1.0 GlobalSessions *session.Manager // appConfigPath is the path to the config files @@ -134,6 +143,7 @@ var ( // appConfigProvider is the provider for the config, default is ini appConfigProvider = "ini" // WorkPath is the absolute path to project root directory + // Deprecated: using pkg/, we will delete this in v2.1.0 WorkPath string ) @@ -398,6 +408,7 @@ func assignSingleConfig(p interface{}, ac config.Configer) { } // LoadAppConfig allow developer to apply a config file +// Deprecated: using pkg/, we will delete this in v2.1.0 func LoadAppConfig(adapterName, configPath string) error { absConfigPath, err := filepath.Abs(configPath) if err != nil { @@ -426,6 +437,7 @@ func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, err return &beegoAppConfig{ac}, nil } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) Set(key, val string) error { if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { return b.innerConfig.Set(key, val) @@ -433,6 +445,7 @@ func (b *beegoAppConfig) Set(key, val string) error { return nil } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) String(key string) string { if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" { return v @@ -440,6 +453,7 @@ func (b *beegoAppConfig) String(key string) string { return b.innerConfig.String(key) } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) Strings(key string) []string { if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 { return v @@ -447,6 +461,7 @@ func (b *beegoAppConfig) Strings(key string) []string { return b.innerConfig.Strings(key) } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) Int(key string) (int, error) { if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil { return v, nil @@ -454,6 +469,7 @@ func (b *beegoAppConfig) Int(key string) (int, error) { return b.innerConfig.Int(key) } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) Int64(key string) (int64, error) { if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil { return v, nil @@ -461,6 +477,7 @@ func (b *beegoAppConfig) Int64(key string) (int64, error) { return b.innerConfig.Int64(key) } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) Bool(key string) (bool, error) { if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil { return v, nil @@ -468,6 +485,7 @@ func (b *beegoAppConfig) Bool(key string) (bool, error) { return b.innerConfig.Bool(key) } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) Float(key string) (float64, error) { if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil { return v, nil @@ -475,6 +493,7 @@ func (b *beegoAppConfig) Float(key string) (float64, error) { return b.innerConfig.Float(key) } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { if v := b.String(key); v != "" { return v @@ -482,6 +501,7 @@ func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { return defaultVal } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { if v := b.Strings(key); len(v) != 0 { return v @@ -489,6 +509,7 @@ func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []strin return defaultVal } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { if v, err := b.Int(key); err == nil { return v @@ -496,6 +517,7 @@ func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { return defaultVal } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { if v, err := b.Int64(key); err == nil { return v @@ -503,6 +525,7 @@ func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { return defaultVal } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { if v, err := b.Bool(key); err == nil { return v @@ -510,6 +533,7 @@ func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { return defaultVal } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { if v, err := b.Float(key); err == nil { return v @@ -517,14 +541,17 @@ func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { return defaultVal } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) DIY(key string) (interface{}, error) { return b.innerConfig.DIY(key) } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { return b.innerConfig.GetSection(section) } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (b *beegoAppConfig) SaveConfigFile(filename string) error { return b.innerConfig.SaveConfigFile(filename) } diff --git a/controller.go b/controller.go index 0e8853b3..0b4a79a8 100644 --- a/controller.go +++ b/controller.go @@ -35,12 +35,15 @@ import ( var ( // ErrAbort custom error when user stop request handler manually. + // Deprecated: using pkg/, we will delete this in v2.1.0 ErrAbort = errors.New("user stop run") // GlobalControllerRouter store comments with controller. pkgpath+controller:comments + // Deprecated: using pkg/, we will delete this in v2.1.0 GlobalControllerRouter = make(map[string][]ControllerComments) ) // ControllerFilter store the filter for controller +// Deprecated: using pkg/, we will delete this in v2.1.0 type ControllerFilter struct { Pattern string Pos int @@ -50,6 +53,7 @@ type ControllerFilter struct { } // ControllerFilterComments store the comment for controller level filter +// Deprecated: using pkg/, we will delete this in v2.1.0 type ControllerFilterComments struct { Pattern string Pos int @@ -59,12 +63,14 @@ type ControllerFilterComments struct { } // ControllerImportComments store the import comment for controller needed +// Deprecated: using pkg/, we will delete this in v2.1.0 type ControllerImportComments struct { ImportPath string ImportAlias string } // ControllerComments store the comment for the controller method +// Deprecated: using pkg/, we will delete this in v2.1.0 type ControllerComments struct { Method string Router string @@ -77,6 +83,7 @@ type ControllerComments struct { } // ControllerCommentsSlice implements the sort interface +// Deprecated: using pkg/, we will delete this in v2.1.0 type ControllerCommentsSlice []ControllerComments func (p ControllerCommentsSlice) Len() int { return len(p) } @@ -85,6 +92,7 @@ func (p ControllerCommentsSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // Controller defines some basic http request handler operations, such as // http context, template and view, session and xsrf. +// Deprecated: using pkg/, we will delete this in v2.1.0 type Controller struct { // context data Ctx *context.Context @@ -115,6 +123,7 @@ type Controller struct { } // ControllerInterface is an interface to uniform all controller handler. +// Deprecated: using pkg/, we will delete this in v2.1.0 type ControllerInterface interface { Init(ct *context.Context, controllerName, actionName string, app interface{}) Prepare() @@ -135,6 +144,7 @@ type ControllerInterface interface { } // Init generates default values of controller operations. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { c.Layout = "" c.TplName = "" @@ -150,42 +160,51 @@ func (c *Controller) Init(ctx *context.Context, controllerName, actionName strin } // Prepare runs after Init before request function execution. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Prepare() {} // Finish runs after request function execution. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Finish() {} // Get adds a request function to handle GET request. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Get() { http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) } // Post adds a request function to handle POST request. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Post() { http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) } // Delete adds a request function to handle DELETE request. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Delete() { http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) } // Put adds a request function to handle PUT request. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Put() { http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) } // Head adds a request function to handle HEAD request. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Head() { http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) } // Patch adds a request function to handle PATCH request. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Patch() { http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) } // Options adds a request function to handle OPTIONS request. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Options() { http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) } @@ -198,6 +217,7 @@ func (c *Controller) Options() { // reflect the message received, excluding some fields described below, // back to the client as the message body of a 200 (OK) response with a // Content-Type of "message/http" (Section 8.3.1 of [RFC7230]). +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Trace() { ts := func(h http.Header) (hs string) { for k, v := range h { @@ -213,6 +233,7 @@ func (c *Controller) Trace() { } // HandlerFunc call function with the name +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) HandlerFunc(fnname string) bool { if v, ok := c.methodMapping[fnname]; ok { v() @@ -222,14 +243,17 @@ func (c *Controller) HandlerFunc(fnname string) bool { } // URLMapping register the internal Controller router. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) URLMapping() {} // Mapping the method to function +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Mapping(method string, fn func()) { c.methodMapping[method] = fn } // Render sends the response with rendered template bytes as text/html type. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Render() error { if !c.EnableRender { return nil @@ -247,12 +271,14 @@ func (c *Controller) Render() error { } // RenderString returns the rendered template string. Do not send out response. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) RenderString() (string, error) { b, e := c.RenderBytes() return string(b), e } // RenderBytes returns the bytes of rendered template string. Do not send out response. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) RenderBytes() ([]byte, error) { buf, err := c.renderTemplate() //if the controller has set layout, then first get the tplName's content set the content to the layout @@ -314,12 +340,14 @@ func (c *Controller) viewPath() string { } // Redirect sends the redirection response to url with status code. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Redirect(url string, code int) { LogAccess(c.Ctx, nil, code) c.Ctx.Redirect(code, url) } // SetData set the data depending on the accepted +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) SetData(data interface{}) { accept := c.Ctx.Input.Header("Accept") switch accept { @@ -333,6 +361,7 @@ func (c *Controller) SetData(data interface{}) { } // Abort stops controller handler and show the error data if code is defined in ErrorMap or code string. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Abort(code string) { status, err := strconv.Atoi(code) if err != nil { @@ -342,6 +371,7 @@ func (c *Controller) Abort(code string) { } // CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) CustomAbort(status int, body string) { // first panic from ErrorMaps, it is user defined error functions. if _, ok := ErrorMaps[body]; ok { @@ -355,12 +385,14 @@ func (c *Controller) CustomAbort(status int, body string) { } // StopRun makes panic of USERSTOPRUN error and go to recover function if defined. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) StopRun() { panic(ErrAbort) } // URLFor does another controller handler in this request function. // it goes to this controller method if endpoint is not clear. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) URLFor(endpoint string, values ...interface{}) string { if len(endpoint) == 0 { return "" @@ -372,6 +404,7 @@ func (c *Controller) URLFor(endpoint string, values ...interface{}) string { } // ServeJSON sends a json response with encoding charset. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) ServeJSON(encoding ...bool) { var ( hasIndent = BConfig.RunMode != PROD @@ -382,23 +415,27 @@ func (c *Controller) ServeJSON(encoding ...bool) { } // ServeJSONP sends a jsonp response. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) ServeJSONP() { hasIndent := BConfig.RunMode != PROD c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent) } // ServeXML sends xml response. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) ServeXML() { hasIndent := BConfig.RunMode != PROD c.Ctx.Output.XML(c.Data["xml"], hasIndent) } // ServeYAML sends yaml response. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) ServeYAML() { c.Ctx.Output.YAML(c.Data["yaml"]) } // ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) ServeFormatted(encoding ...bool) { hasIndent := BConfig.RunMode != PROD hasEncoding := len(encoding) > 0 && encoding[0] @@ -406,6 +443,7 @@ func (c *Controller) ServeFormatted(encoding ...bool) { } // Input returns the input data map from POST or PUT request body and query string. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) Input() url.Values { if c.Ctx.Request.Form == nil { c.Ctx.Request.ParseForm() @@ -414,11 +452,13 @@ func (c *Controller) Input() url.Values { } // ParseForm maps input data map to obj struct. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) ParseForm(obj interface{}) error { return ParseForm(c.Input(), obj) } // GetString returns the input value by key string or the default value while it's present and input is blank +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetString(key string, def ...string) string { if v := c.Ctx.Input.Query(key); v != "" { return v @@ -431,6 +471,7 @@ func (c *Controller) GetString(key string, def ...string) string { // GetStrings returns the input string slice by key string or the default value while it's present and input is blank // it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetStrings(key string, def ...[]string) []string { var defv []string if len(def) > 0 { @@ -447,6 +488,7 @@ func (c *Controller) GetStrings(key string, def ...[]string) []string { } // GetInt returns input as an int or the default value while it's present and input is blank +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetInt(key string, def ...int) (int, error) { strv := c.Ctx.Input.Query(key) if len(strv) == 0 && len(def) > 0 { @@ -456,6 +498,7 @@ func (c *Controller) GetInt(key string, def ...int) (int, error) { } // GetInt8 return input as an int8 or the default value while it's present and input is blank +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { strv := c.Ctx.Input.Query(key) if len(strv) == 0 && len(def) > 0 { @@ -466,6 +509,7 @@ func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { } // GetUint8 return input as an uint8 or the default value while it's present and input is blank +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) { strv := c.Ctx.Input.Query(key) if len(strv) == 0 && len(def) > 0 { @@ -476,6 +520,7 @@ func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) { } // GetInt16 returns input as an int16 or the default value while it's present and input is blank +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { strv := c.Ctx.Input.Query(key) if len(strv) == 0 && len(def) > 0 { @@ -486,6 +531,7 @@ func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { } // GetUint16 returns input as an uint16 or the default value while it's present and input is blank +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) { strv := c.Ctx.Input.Query(key) if len(strv) == 0 && len(def) > 0 { @@ -496,6 +542,7 @@ func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) { } // GetInt32 returns input as an int32 or the default value while it's present and input is blank +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { strv := c.Ctx.Input.Query(key) if len(strv) == 0 && len(def) > 0 { @@ -506,6 +553,7 @@ func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { } // GetUint32 returns input as an uint32 or the default value while it's present and input is blank +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) { strv := c.Ctx.Input.Query(key) if len(strv) == 0 && len(def) > 0 { @@ -516,6 +564,7 @@ func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) { } // GetInt64 returns input value as int64 or the default value while it's present and input is blank. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { strv := c.Ctx.Input.Query(key) if len(strv) == 0 && len(def) > 0 { @@ -525,6 +574,7 @@ func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { } // GetUint64 returns input value as uint64 or the default value while it's present and input is blank. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) { strv := c.Ctx.Input.Query(key) if len(strv) == 0 && len(def) > 0 { @@ -534,6 +584,7 @@ func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) { } // GetBool returns input value as bool or the default value while it's present and input is blank. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetBool(key string, def ...bool) (bool, error) { strv := c.Ctx.Input.Query(key) if len(strv) == 0 && len(def) > 0 { @@ -543,6 +594,7 @@ func (c *Controller) GetBool(key string, def ...bool) (bool, error) { } // GetFloat returns input value as float64 or the default value while it's present and input is blank. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetFloat(key string, def ...float64) (float64, error) { strv := c.Ctx.Input.Query(key) if len(strv) == 0 && len(def) > 0 { @@ -553,6 +605,7 @@ func (c *Controller) GetFloat(key string, def ...float64) (float64, error) { // GetFile returns the file data in file upload field named as key. // it returns the first one of multi-uploaded files. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, error) { return c.Ctx.Request.FormFile(key) } @@ -584,6 +637,7 @@ func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, // return // } // } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { if files, ok := c.Ctx.Request.MultipartForm.File[key]; ok { return files, nil @@ -593,6 +647,7 @@ func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { // SaveToFile saves uploaded file to new path. // it only operates the first one of mutil-upload form file field. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) SaveToFile(fromfile, tofile string) error { file, _, err := c.Ctx.Request.FormFile(fromfile) if err != nil { @@ -609,6 +664,7 @@ func (c *Controller) SaveToFile(fromfile, tofile string) error { } // StartSession starts session and load old session data info this controller. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) StartSession() session.Store { if c.CruSession == nil { c.CruSession = c.Ctx.Input.CruSession @@ -617,6 +673,7 @@ func (c *Controller) StartSession() session.Store { } // SetSession puts value into session. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) SetSession(name interface{}, value interface{}) { if c.CruSession == nil { c.StartSession() @@ -625,6 +682,7 @@ func (c *Controller) SetSession(name interface{}, value interface{}) { } // GetSession gets value from session. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetSession(name interface{}) interface{} { if c.CruSession == nil { c.StartSession() @@ -633,6 +691,7 @@ func (c *Controller) GetSession(name interface{}) interface{} { } // DelSession removes value from session. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) DelSession(name interface{}) { if c.CruSession == nil { c.StartSession() @@ -642,6 +701,7 @@ func (c *Controller) DelSession(name interface{}) { // SessionRegenerateID regenerates session id for this session. // the session data have no changes. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) SessionRegenerateID() { if c.CruSession != nil { c.CruSession.SessionRelease(c.Ctx.ResponseWriter) @@ -651,6 +711,7 @@ func (c *Controller) SessionRegenerateID() { } // DestroySession cleans session data and session cookie. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) DestroySession() { c.Ctx.Input.CruSession.Flush() c.Ctx.Input.CruSession = nil @@ -658,21 +719,25 @@ func (c *Controller) DestroySession() { } // IsAjax returns this request is ajax or not. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) IsAjax() bool { return c.Ctx.Input.IsAjax() } // GetSecureCookie returns decoded cookie value from encoded browser cookie values. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) { return c.Ctx.GetSecureCookie(Secret, key) } // SetSecureCookie puts value into cookie after encoded the value. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) { c.Ctx.SetSecureCookie(Secret, name, value, others...) } // XSRFToken creates a CSRF token string and returns. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) XSRFToken() string { if c._xsrfToken == "" { expire := int64(BConfig.WebConfig.XSRFExpire) @@ -687,6 +752,7 @@ func (c *Controller) XSRFToken() string { // CheckXSRFCookie checks xsrf token in this request is valid or not. // the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" // or in form field value named as "_xsrf". +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) CheckXSRFCookie() bool { if !c.EnableXSRF { return true @@ -695,12 +761,14 @@ func (c *Controller) CheckXSRFCookie() bool { } // XSRFFormHTML writes an input field contains xsrf token value. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) XSRFFormHTML() string { return `` } // GetControllerAndAction gets the executing controller name and action name. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *Controller) GetControllerAndAction() (string, string) { return c.controllerName, c.actionName } diff --git a/doc.go b/doc.go index 8825bd29..72284c67 100644 --- a/doc.go +++ b/doc.go @@ -13,5 +13,7 @@ beego is inspired by Tornado, Sinatra and Flask with the added benefit of some G } more information: http://beego.me + +Deprecated: using pkg/, we will delete this in v2.1.0 */ package beego diff --git a/error.go b/error.go index f268f723..40eea5fa 100644 --- a/error.go +++ b/error.go @@ -205,6 +205,7 @@ type errorInfo struct { // ErrorMaps holds map of http handlers for each error string. // there is 10 kinds default error(40x and 50x) +// Deprecated: using pkg/, we will delete this in v2.1.0 var ErrorMaps = make(map[string]*errorInfo, 10) // show 401 unauthorized error. @@ -387,6 +388,7 @@ func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errCont // usage: // beego.ErrorHandler("404",NotFound) // beego.ErrorHandler("500",InternalServerError) +// Deprecated: using pkg/, we will delete this in v2.1.0 func ErrorHandler(code string, h http.HandlerFunc) *App { ErrorMaps[code] = &errorInfo{ errorType: errorTypeHandler, @@ -399,6 +401,7 @@ func ErrorHandler(code string, h http.HandlerFunc) *App { // ErrorController registers ControllerInterface to each http err code string. // usage: // beego.ErrorController(&controllers.ErrorController{}) +// Deprecated: using pkg/, we will delete this in v2.1.0 func ErrorController(c ControllerInterface) *App { reflectVal := reflect.ValueOf(c) rt := reflectVal.Type() @@ -418,6 +421,7 @@ func ErrorController(c ControllerInterface) *App { } // Exception Write HttpStatus with errCode and Exec error handler if exist. +// Deprecated: using pkg/, we will delete this in v2.1.0 func Exception(errCode uint64, ctx *context.Context) { exception(strconv.FormatUint(errCode, 10), ctx) } diff --git a/filter.go b/filter.go index 9cc6e913..8596d288 100644 --- a/filter.go +++ b/filter.go @@ -17,11 +17,13 @@ package beego import "github.com/astaxie/beego/context" // FilterFunc defines a filter function which is invoked before the controller handler is executed. +// Deprecated: using pkg/, we will delete this in v2.1.0 type FilterFunc func(*context.Context) // FilterRouter defines a filter operation which is invoked before the controller handler is executed. // It can match the URL against a pattern, and execute a filter function // when a request with a matching URL arrives. +// Deprecated: using pkg/, we will delete this in v2.1.0 type FilterRouter struct { filterFunc FilterFunc tree *Tree @@ -33,6 +35,7 @@ type FilterRouter struct { // ValidRouter checks if the current request is matched by this filter. // If the request is matched, the values of the URL parameters defined // by the filter pattern are also returned. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { isOk := f.tree.Match(url, ctx) if isOk != nil { diff --git a/flash.go b/flash.go index a6485a17..fe3fb974 100644 --- a/flash.go +++ b/flash.go @@ -21,11 +21,13 @@ import ( ) // FlashData is a tools to maintain data when using across request. +// Deprecated: using pkg/, we will delete this in v2.1.0 type FlashData struct { Data map[string]string } // NewFlash return a new empty FlashData struct. +// Deprecated: using pkg/, we will delete this in v2.1.0 func NewFlash() *FlashData { return &FlashData{ Data: make(map[string]string), @@ -33,6 +35,7 @@ func NewFlash() *FlashData { } // Set message to flash +// Deprecated: using pkg/, we will delete this in v2.1.0 func (fd *FlashData) Set(key string, msg string, args ...interface{}) { if len(args) == 0 { fd.Data[key] = msg @@ -42,6 +45,7 @@ func (fd *FlashData) Set(key string, msg string, args ...interface{}) { } // Success writes success message to flash. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (fd *FlashData) Success(msg string, args ...interface{}) { if len(args) == 0 { fd.Data["success"] = msg @@ -51,6 +55,7 @@ func (fd *FlashData) Success(msg string, args ...interface{}) { } // Notice writes notice message to flash. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (fd *FlashData) Notice(msg string, args ...interface{}) { if len(args) == 0 { fd.Data["notice"] = msg @@ -60,6 +65,7 @@ func (fd *FlashData) Notice(msg string, args ...interface{}) { } // Warning writes warning message to flash. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (fd *FlashData) Warning(msg string, args ...interface{}) { if len(args) == 0 { fd.Data["warning"] = msg @@ -69,6 +75,7 @@ func (fd *FlashData) Warning(msg string, args ...interface{}) { } // Error writes error message to flash. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (fd *FlashData) Error(msg string, args ...interface{}) { if len(args) == 0 { fd.Data["error"] = msg @@ -79,6 +86,7 @@ func (fd *FlashData) Error(msg string, args ...interface{}) { // Store does the saving operation of flash data. // the data are encoded and saved in cookie. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (fd *FlashData) Store(c *Controller) { c.Data["flash"] = fd.Data var flashValue string @@ -89,6 +97,7 @@ func (fd *FlashData) Store(c *Controller) { } // ReadFromRequest parsed flash data from encoded values in cookie. +// Deprecated: using pkg/, we will delete this in v2.1.0 func ReadFromRequest(c *Controller) *FlashData { flash := NewFlash() if cookie, err := c.Ctx.Request.Cookie(BConfig.WebConfig.FlashName); err == nil { diff --git a/fs.go b/fs.go index 41cc6f6e..3300813d 100644 --- a/fs.go +++ b/fs.go @@ -6,9 +6,11 @@ import ( "path/filepath" ) +// Deprecated: using pkg/, we will delete this in v2.1.0 type FileSystem struct { } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (d FileSystem) Open(name string) (http.File, error) { return os.Open(name) } @@ -16,6 +18,7 @@ func (d FileSystem) Open(name string) (http.File, error) { // Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or // directory in the tree, including root. All errors that arise visiting files // and directories are filtered by walkFn. +// Deprecated: using pkg/, we will delete this in v2.1.0 func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error { f, err := fs.Open(root) diff --git a/namespace.go b/namespace.go index 4952c9d5..a6962994 100644 --- a/namespace.go +++ b/namespace.go @@ -24,15 +24,18 @@ import ( type namespaceCond func(*beecontext.Context) bool // LinkNamespace used as link action +// Deprecated: using pkg/, we will delete this in v2.1.0 type LinkNamespace func(*Namespace) // Namespace is store all the info +// Deprecated: using pkg/, we will delete this in v2.1.0 type Namespace struct { prefix string handlers *ControllerRegister } // NewNamespace get new Namespace +// Deprecated: using pkg/, we will delete this in v2.1.0 func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { ns := &Namespace{ prefix: prefix, @@ -54,6 +57,7 @@ func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { // return false // }) // Cond as the first filter +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) Cond(cond namespaceCond) *Namespace { fn := func(ctx *beecontext.Context) { if !cond(ctx) { @@ -83,6 +87,7 @@ func (n *Namespace) Cond(cond namespaceCond) *Namespace { // ctx.Redirect(302, "/login") // } // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { var a int if action == "before" { @@ -98,6 +103,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { // Router same as beego.Rourer // refer: https://godoc.org/github.com/astaxie/beego#Router +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { n.handlers.Add(rootpath, c, mappingMethods...) return n @@ -105,6 +111,7 @@ func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethod // AutoRouter same as beego.AutoRouter // refer: https://godoc.org/github.com/astaxie/beego#AutoRouter +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { n.handlers.AddAuto(c) return n @@ -112,6 +119,7 @@ func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { // AutoPrefix same as beego.AutoPrefix // refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { n.handlers.AddAutoPrefix(prefix, c) return n @@ -119,6 +127,7 @@ func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace // Get same as beego.Get // refer: https://godoc.org/github.com/astaxie/beego#Get +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { n.handlers.Get(rootpath, f) return n @@ -126,6 +135,7 @@ func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { // Post same as beego.Post // refer: https://godoc.org/github.com/astaxie/beego#Post +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { n.handlers.Post(rootpath, f) return n @@ -133,6 +143,7 @@ func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { // Delete same as beego.Delete // refer: https://godoc.org/github.com/astaxie/beego#Delete +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { n.handlers.Delete(rootpath, f) return n @@ -140,6 +151,7 @@ func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { // Put same as beego.Put // refer: https://godoc.org/github.com/astaxie/beego#Put +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { n.handlers.Put(rootpath, f) return n @@ -147,6 +159,7 @@ func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { // Head same as beego.Head // refer: https://godoc.org/github.com/astaxie/beego#Head +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { n.handlers.Head(rootpath, f) return n @@ -154,6 +167,7 @@ func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { // Options same as beego.Options // refer: https://godoc.org/github.com/astaxie/beego#Options +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { n.handlers.Options(rootpath, f) return n @@ -161,6 +175,7 @@ func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { // Patch same as beego.Patch // refer: https://godoc.org/github.com/astaxie/beego#Patch +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { n.handlers.Patch(rootpath, f) return n @@ -168,6 +183,7 @@ func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { // Any same as beego.Any // refer: https://godoc.org/github.com/astaxie/beego#Any +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { n.handlers.Any(rootpath, f) return n @@ -175,6 +191,7 @@ func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { // Handler same as beego.Handler // refer: https://godoc.org/github.com/astaxie/beego#Handler +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { n.handlers.Handler(rootpath, h) return n @@ -182,6 +199,7 @@ func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { // Include add include class // refer: https://godoc.org/github.com/astaxie/beego#Include +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { n.handlers.Include(cList...) return n @@ -204,6 +222,7 @@ func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { // ctx.Output.Body([]byte("crminfo")) // }), //) +// Deprecated: using pkg/, we will delete this in v2.1.0 func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { for _, ni := range ns { for k, v := range ni.handlers.routers { @@ -233,6 +252,7 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { // AddNamespace register Namespace into beego.Handler // support multi Namespace +// Deprecated: using pkg/, we will delete this in v2.1.0 func AddNamespace(nl ...*Namespace) { for _, n := range nl { for k, v := range n.handlers.routers { @@ -276,6 +296,7 @@ func addPrefix(t *Tree, prefix string) { } // NSCond is Namespace Condition +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSCond(cond namespaceCond) LinkNamespace { return func(ns *Namespace) { ns.Cond(cond) @@ -283,6 +304,7 @@ func NSCond(cond namespaceCond) LinkNamespace { } // NSBefore Namespace BeforeRouter filter +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSBefore(filterList ...FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Filter("before", filterList...) @@ -290,6 +312,7 @@ func NSBefore(filterList ...FilterFunc) LinkNamespace { } // NSAfter add Namespace FinishRouter filter +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSAfter(filterList ...FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Filter("after", filterList...) @@ -297,6 +320,7 @@ func NSAfter(filterList ...FilterFunc) LinkNamespace { } // NSInclude Namespace Include ControllerInterface +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSInclude(cList ...ControllerInterface) LinkNamespace { return func(ns *Namespace) { ns.Include(cList...) @@ -304,6 +328,7 @@ func NSInclude(cList ...ControllerInterface) LinkNamespace { } // NSRouter call Namespace Router +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { return func(ns *Namespace) { ns.Router(rootpath, c, mappingMethods...) @@ -311,6 +336,7 @@ func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) } // NSGet call Namespace Get +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSGet(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Get(rootpath, f) @@ -318,6 +344,7 @@ func NSGet(rootpath string, f FilterFunc) LinkNamespace { } // NSPost call Namespace Post +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSPost(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Post(rootpath, f) @@ -325,6 +352,7 @@ func NSPost(rootpath string, f FilterFunc) LinkNamespace { } // NSHead call Namespace Head +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSHead(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Head(rootpath, f) @@ -332,6 +360,7 @@ func NSHead(rootpath string, f FilterFunc) LinkNamespace { } // NSPut call Namespace Put +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSPut(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Put(rootpath, f) @@ -339,6 +368,7 @@ func NSPut(rootpath string, f FilterFunc) LinkNamespace { } // NSDelete call Namespace Delete +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSDelete(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Delete(rootpath, f) @@ -346,6 +376,7 @@ func NSDelete(rootpath string, f FilterFunc) LinkNamespace { } // NSAny call Namespace Any +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSAny(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Any(rootpath, f) @@ -353,6 +384,7 @@ func NSAny(rootpath string, f FilterFunc) LinkNamespace { } // NSOptions call Namespace Options +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSOptions(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Options(rootpath, f) @@ -360,6 +392,7 @@ func NSOptions(rootpath string, f FilterFunc) LinkNamespace { } // NSPatch call Namespace Patch +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSPatch(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Patch(rootpath, f) @@ -367,6 +400,7 @@ func NSPatch(rootpath string, f FilterFunc) LinkNamespace { } // NSAutoRouter call Namespace AutoRouter +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSAutoRouter(c ControllerInterface) LinkNamespace { return func(ns *Namespace) { ns.AutoRouter(c) @@ -374,6 +408,7 @@ func NSAutoRouter(c ControllerInterface) LinkNamespace { } // NSAutoPrefix call Namespace AutoPrefix +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace { return func(ns *Namespace) { ns.AutoPrefix(prefix, c) @@ -381,6 +416,7 @@ func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace { } // NSNamespace add sub Namespace +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { return func(ns *Namespace) { n := NewNamespace(prefix, params...) @@ -389,6 +425,7 @@ func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { } // NSHandler add handler +// Deprecated: using pkg/, we will delete this in v2.1.0 func NSHandler(rootpath string, h http.Handler) LinkNamespace { return func(ns *Namespace) { ns.Handler(rootpath, h) diff --git a/policy.go b/policy.go index ab23f927..358a0539 100644 --- a/policy.go +++ b/policy.go @@ -21,9 +21,11 @@ import ( ) // PolicyFunc defines a policy function which is invoked before the controller handler is executed. +// Deprecated: using pkg/, we will delete this in v2.1.0 type PolicyFunc func(*context.Context) // FindPolicy Find Router info for URL +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { var urlPath = cont.Input.URL() if !BConfig.RouterCaseSensitive { @@ -72,6 +74,7 @@ func (p *ControllerRegister) addToPolicy(method, pattern string, r ...PolicyFunc } // Policy Register new policy in beego +// Deprecated: using pkg/, we will delete this in v2.1.0 func Policy(pattern, method string, policy ...PolicyFunc) { BeeApp.Handlers.addToPolicy(method, pattern, policy...) } diff --git a/router.go b/router.go index 92316480..1be495ab 100644 --- a/router.go +++ b/router.go @@ -51,6 +51,7 @@ const ( var ( // HTTPMETHOD list the supported http methods. + // Deprecated: using pkg/, we will delete this in v2.1.0 HTTPMETHOD = map[string]bool{ "GET": true, "POST": true, @@ -80,10 +81,12 @@ var ( urlPlaceholder = "{{placeholder}}" // DefaultAccessLogFilter will skip the accesslog if return true + // Deprecated: using pkg/, we will delete this in v2.1.0 DefaultAccessLogFilter FilterHandler = &logFilter{} ) // FilterHandler is an interface for +// Deprecated: using pkg/, we will delete this in v2.1.0 type FilterHandler interface { Filter(*beecontext.Context) bool } @@ -92,6 +95,7 @@ type FilterHandler interface { type logFilter struct { } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (l *logFilter) Filter(ctx *beecontext.Context) bool { requestPath := path.Clean(ctx.Request.URL.Path) if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { @@ -106,6 +110,7 @@ func (l *logFilter) Filter(ctx *beecontext.Context) bool { } // ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter +// Deprecated: using pkg/, we will delete this in v2.1.0 func ExceptMethodAppend(action string) { exceptMethod = append(exceptMethod, action) } @@ -122,11 +127,13 @@ type ControllerInfo struct { methodParams []*param.MethodParam } +// Deprecated: using pkg/, we will delete this in v2.1.0 func (c *ControllerInfo) GetPattern() string { return c.pattern } // ControllerRegister containers registered router rules, controller handlers and filters. +// Deprecated: using pkg/, we will delete this in v2.1.0 type ControllerRegister struct { routers map[string]*Tree enablePolicy bool @@ -137,6 +144,7 @@ type ControllerRegister struct { } // NewControllerRegister returns a new ControllerRegister. +// Deprecated: using pkg/, we will delete this in v2.1.0 func NewControllerRegister() *ControllerRegister { return &ControllerRegister{ routers: make(map[string]*Tree), @@ -159,6 +167,7 @@ func NewControllerRegister() *ControllerRegister { // Add("/api/delete",&RestController{},"delete:DeleteFood") // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { p.addWithMethodParams(pattern, c, nil, mappingMethods...) } @@ -251,6 +260,7 @@ func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerIn // Include only when the Runmode is dev will generate router file in the router/auto.go from the controller // Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) Include(cList ...ControllerInterface) { if BConfig.RunMode == DEV { skip := make(map[string]bool, 10) @@ -313,11 +323,13 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { // ctx := p.GetContext() // ctx.Reset(w, q) // defer p.GiveBackContext(ctx) +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) GetContext() *beecontext.Context { return p.pool.Get().(*beecontext.Context) } // GiveBackContext put the ctx into pool so that it could be reuse +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { // clear input cached data ctx.Input.Clear() @@ -331,6 +343,7 @@ func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { // Get("/", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) Get(pattern string, f FilterFunc) { p.AddMethod("get", pattern, f) } @@ -340,6 +353,7 @@ func (p *ControllerRegister) Get(pattern string, f FilterFunc) { // Post("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) Post(pattern string, f FilterFunc) { p.AddMethod("post", pattern, f) } @@ -349,6 +363,7 @@ func (p *ControllerRegister) Post(pattern string, f FilterFunc) { // Put("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) Put(pattern string, f FilterFunc) { p.AddMethod("put", pattern, f) } @@ -358,6 +373,7 @@ func (p *ControllerRegister) Put(pattern string, f FilterFunc) { // Delete("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { p.AddMethod("delete", pattern, f) } @@ -367,6 +383,7 @@ func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { // Head("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) Head(pattern string, f FilterFunc) { p.AddMethod("head", pattern, f) } @@ -376,6 +393,7 @@ func (p *ControllerRegister) Head(pattern string, f FilterFunc) { // Patch("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { p.AddMethod("patch", pattern, f) } @@ -385,6 +403,7 @@ func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { // Options("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) Options(pattern string, f FilterFunc) { p.AddMethod("options", pattern, f) } @@ -394,6 +413,7 @@ func (p *ControllerRegister) Options(pattern string, f FilterFunc) { // Any("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) Any(pattern string, f FilterFunc) { p.AddMethod("*", pattern, f) } @@ -403,6 +423,7 @@ func (p *ControllerRegister) Any(pattern string, f FilterFunc) { // AddMethod("get","/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { method = strings.ToUpper(method) if method != "*" && !HTTPMETHOD[method] { @@ -433,6 +454,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { } // Handler add user defined Handler +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { route := &ControllerInfo{} route.pattern = pattern @@ -453,6 +475,7 @@ func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ... // MainController has method List and Page. // visit the url /main/list to execute List function // /main/page to execute Page function. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) AddAuto(c ControllerInterface) { p.AddAutoPrefix("/", c) } @@ -462,6 +485,7 @@ func (p *ControllerRegister) AddAuto(c ControllerInterface) { // MainController has method List and Page. // visit the url /admin/main/list to execute List function // /admin/main/page to execute Page function. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) { reflectVal := reflect.ValueOf(c) rt := reflectVal.Type() @@ -492,6 +516,7 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) // params is for: // 1. setting the returnOnOutput value (false allows multiple filters to execute) // 2. determining whether or not params need to be reset. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { mr := &FilterRouter{ tree: NewTree(), @@ -526,6 +551,7 @@ func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err // URLFor does another controller handler in this request function. // it can access any controller method. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { paths := strings.Split(endpoint, ".") if len(paths) <= 1 { @@ -695,6 +721,7 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str } // Implement http.Handler interface. +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { startTime := time.Now() var ( @@ -993,6 +1020,7 @@ func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, ex } // FindRouter Find Router info for URL +// Deprecated: using pkg/, we will delete this in v2.1.0 func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { var urlPath = context.Input.URL() if !BConfig.RouterCaseSensitive { @@ -1020,6 +1048,7 @@ func toURL(params map[string]string) string { } // LogAccess logging info HTTP Access +// Deprecated: using pkg/, we will delete this in v2.1.0 func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { // Skip logging if AccessLogs config is false if !BConfig.Log.AccessLogs { diff --git a/template.go b/template.go index 59875be7..69b178ca 100644 --- a/template.go +++ b/template.go @@ -47,6 +47,7 @@ var ( // ExecuteTemplate applies the template with name to the specified data object, // writing the output to wr. // A template will be executed safely in parallel. +// Deprecated: using pkg/, we will delete this in v2.1.0 func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { return ExecuteViewPathTemplate(wr, name, BConfig.WebConfig.ViewsPath, data) } @@ -54,6 +55,7 @@ func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { // ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object, // writing the output to wr. // A template will be executed safely in parallel. +// Deprecated: using pkg/, we will delete this in v2.1.0 func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error { if BConfig.RunMode == DEV { templatesLock.RLock() @@ -143,6 +145,7 @@ func (tf *templateFile) visit(paths string, f os.FileInfo, err error) error { } // HasTemplateExt return this path contains supported template extension of beego or not. +// Deprecated: using pkg/, we will delete this in v2.1.0 func HasTemplateExt(paths string) bool { for _, v := range beeTemplateExt { if strings.HasSuffix(paths, "."+v) { @@ -153,6 +156,7 @@ func HasTemplateExt(paths string) bool { } // AddTemplateExt add new extension for template. +// Deprecated: using pkg/, we will delete this in v2.1.0 func AddTemplateExt(ext string) { for _, v := range beeTemplateExt { if v == ext { @@ -165,6 +169,7 @@ func AddTemplateExt(ext string) { // AddViewPath adds a new path to the supported view paths. //Can later be used by setting a controller ViewPath to this folder //will panic if called after beego.Run() +// Deprecated: using pkg/, we will delete this in v2.1.0 func AddViewPath(viewPath string) error { if beeViewPathTemplateLocked { if _, exist := beeViewPathTemplates[viewPath]; exist { @@ -182,6 +187,7 @@ func lockViewPaths() { // BuildTemplate will build all template files in a directory. // it makes beego can render any template file in view directory. +// Deprecated: using pkg/, we will delete this in v2.1.0 func BuildTemplate(dir string, files ...string) error { var err error fs := beeTemplateFS() @@ -363,11 +369,13 @@ func defaultFSFunc() http.FileSystem { } // SetTemplateFSFunc set default filesystem function +// Deprecated: using pkg/, we will delete this in v2.1.0 func SetTemplateFSFunc(fnt templateFSFunc) { beeTemplateFS = fnt } // SetViewsPath sets view directory path in beego application. +// Deprecated: using pkg/, we will delete this in v2.1.0 func SetViewsPath(path string) *App { BConfig.WebConfig.ViewsPath = path return BeeApp @@ -375,6 +383,7 @@ func SetViewsPath(path string) *App { // SetStaticPath sets static directory path and proper url pattern in beego application. // if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public". +// Deprecated: using pkg/, we will delete this in v2.1.0 func SetStaticPath(url string, path string) *App { if !strings.HasPrefix(url, "/") { url = "/" + url @@ -387,6 +396,7 @@ func SetStaticPath(url string, path string) *App { } // DelStaticPath removes the static folder setting in this url pattern in beego application. +// Deprecated: using pkg/, we will delete this in v2.1.0 func DelStaticPath(url string) *App { if !strings.HasPrefix(url, "/") { url = "/" + url @@ -399,6 +409,7 @@ func DelStaticPath(url string) *App { } // AddTemplateEngine add a new templatePreProcessor which support extension +// Deprecated: using pkg/, we will delete this in v2.1.0 func AddTemplateEngine(extension string, fn templatePreProcessor) *App { AddTemplateExt(extension) beeTemplateEngines[extension] = fn diff --git a/templatefunc.go b/templatefunc.go index 6f02b8d6..9e7c42fc 100644 --- a/templatefunc.go +++ b/templatefunc.go @@ -35,6 +35,7 @@ const ( ) // Substr returns the substr from start to length. +// Deprecated: using pkg/, we will delete this in v2.1.0 func Substr(s string, start, length int) string { bt := []rune(s) if start < 0 { @@ -53,6 +54,7 @@ func Substr(s string, start, length int) string { } // HTML2str returns escaping text convert from html. +// Deprecated: using pkg/, we will delete this in v2.1.0 func HTML2str(html string) string { re := regexp.MustCompile(`\<[\S\s]+?\>`) @@ -76,6 +78,7 @@ func HTML2str(html string) string { } // DateFormat takes a time and a layout string and returns a string with the formatted date. Used by the template parser as "dateformat" +// Deprecated: using pkg/, we will delete this in v2.1.0 func DateFormat(t time.Time, layout string) (datestring string) { datestring = t.Format(layout) return @@ -123,6 +126,7 @@ var datePatterns = []string{ } // DateParse Parse Date use PHP time format. +// Deprecated: using pkg/, we will delete this in v2.1.0 func DateParse(dateString, format string) (time.Time, error) { replacer := strings.NewReplacer(datePatterns...) format = replacer.Replace(format) @@ -130,6 +134,7 @@ func DateParse(dateString, format string) (time.Time, error) { } // Date takes a PHP like date func to Go's time format. +// Deprecated: using pkg/, we will delete this in v2.1.0 func Date(t time.Time, format string) string { replacer := strings.NewReplacer(datePatterns...) format = replacer.Replace(format) @@ -138,6 +143,7 @@ func Date(t time.Time, format string) string { // Compare is a quick and dirty comparison function. It will convert whatever you give it to strings and see if the two values are equal. // Whitespace is trimmed. Used by the template parser as "eq". +// Deprecated: using pkg/, we will delete this in v2.1.0 func Compare(a, b interface{}) (equal bool) { equal = false if strings.TrimSpace(fmt.Sprintf("%v", a)) == strings.TrimSpace(fmt.Sprintf("%v", b)) { @@ -147,16 +153,19 @@ func Compare(a, b interface{}) (equal bool) { } // CompareNot !Compare +// Deprecated: using pkg/, we will delete this in v2.1.0 func CompareNot(a, b interface{}) (equal bool) { return !Compare(a, b) } // NotNil the same as CompareNot +// Deprecated: using pkg/, we will delete this in v2.1.0 func NotNil(a interface{}) (isNil bool) { return CompareNot(a, nil) } // GetConfig get the Appconfig +// Deprecated: using pkg/, we will delete this in v2.1.0 func GetConfig(returnType, key string, defaultVal interface{}) (value interface{}, err error) { switch returnType { case "String": @@ -195,11 +204,13 @@ func GetConfig(returnType, key string, defaultVal interface{}) (value interface{ } // Str2html Convert string to template.HTML type. +// Deprecated: using pkg/, we will delete this in v2.1.0 func Str2html(raw string) template.HTML { return template.HTML(raw) } // Htmlquote returns quoted html string. +// Deprecated: using pkg/, we will delete this in v2.1.0 func Htmlquote(text string) string { //HTML编码为实体符号 /* @@ -219,6 +230,7 @@ func Htmlquote(text string) string { } // Htmlunquote returns unquoted html string. +// Deprecated: using pkg/, we will delete this in v2.1.0 func Htmlunquote(text string) string { //实体符号解释为HTML /* @@ -249,11 +261,13 @@ func Htmlunquote(text string) string { // /user/John%20Doe // // more detail http://beego.me/docs/mvc/controller/urlbuilding.md +// Deprecated: using pkg/, we will delete this in v2.1.0 func URLFor(endpoint string, values ...interface{}) string { return BeeApp.Handlers.URLFor(endpoint, values...) } // AssetsJs returns script tag with src string. +// Deprecated: using pkg/, we will delete this in v2.1.0 func AssetsJs(text string) template.HTML { text = "" @@ -262,6 +276,7 @@ func AssetsJs(text string) template.HTML { } // AssetsCSS returns stylesheet link tag with src string. +// Deprecated: using pkg/, we will delete this in v2.1.0 func AssetsCSS(text string) template.HTML { text = "" @@ -411,6 +426,7 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e } // ParseForm will parse form values to struct via tag. +// Deprecated: using pkg/, we will delete this in v2.1.0 func ParseForm(form url.Values, obj interface{}) error { objT := reflect.TypeOf(obj) objV := reflect.ValueOf(obj) @@ -442,6 +458,7 @@ var unKind = map[reflect.Kind]bool{ // RenderForm will render object to form html. // obj must be a struct pointer. +// Deprecated: using pkg/, we will delete this in v2.1.0 func RenderForm(obj interface{}) template.HTML { objT := reflect.TypeOf(obj) objV := reflect.ValueOf(obj) @@ -715,6 +732,7 @@ func ge(arg1, arg2 interface{}) (bool, error) { // // {{ map_get m "a" }} // return 1 // {{ map_get m 1 "c" }} // return 4 +// Deprecated: using pkg/, we will delete this in v2.1.0 func MapGet(arg1 interface{}, arg2 ...interface{}) (interface{}, error) { arg1Type := reflect.TypeOf(arg1) arg1Val := reflect.ValueOf(arg1) diff --git a/tree.go b/tree.go index 9e53003b..7fa3a7cb 100644 --- a/tree.go +++ b/tree.go @@ -31,6 +31,7 @@ var ( // fixRouter stores Fixed Router // wildcard stores params // leaves store the endpoint information +// Deprecated: using pkg/, we will delete this in v2.1.0 type Tree struct { //prefix set for static router prefix string @@ -43,12 +44,14 @@ type Tree struct { } // NewTree return a new Tree +// Deprecated: using pkg/, we will delete this in v2.1.0 func NewTree() *Tree { return &Tree{} } // AddTree will add tree to the exist Tree // prefix should has no params +// Deprecated: using pkg/, we will delete this in v2.1.0 func (t *Tree) AddTree(prefix string, tree *Tree) { t.addtree(splitPath(prefix), tree, nil, "") } @@ -200,6 +203,7 @@ func filterTreeWithPrefix(t *Tree, wildcards []string, reg string) { } // AddRouter call addseg function +// Deprecated: using pkg/, we will delete this in v2.1.0 func (t *Tree) AddRouter(pattern string, runObject interface{}) { t.addseg(splitPath(pattern), runObject, nil, "") } @@ -283,6 +287,7 @@ func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, } // Match router to runObject & params +// Deprecated: using pkg/, we will delete this in v2.1.0 func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) { if len(pattern) == 0 || pattern[0] != '/' { return nil From 5f2f6e4f86c36cf91ab11b227ca1d13a799c88b1 Mon Sep 17 00:00:00 2001 From: Phillip Stagnet Date: Wed, 5 Aug 2020 18:29:22 +0200 Subject: [PATCH 067/207] Add interface change in pkg folder --- pkg/session/couchbase/sess_couchbase.go | 6 +- pkg/session/ledis/ledis_session.go | 4 +- pkg/session/memcache/sess_memcache.go | 8 +-- pkg/session/mysql/sess_mysql.go | 10 ++- pkg/session/postgres/sess_postgresql.go | 10 ++- pkg/session/redis/sess_redis.go | 6 +- pkg/session/redis_cluster/redis_cluster.go | 6 +- .../redis_sentinel/sess_redis_sentinel.go | 6 +- pkg/session/sess_cookie.go | 4 +- pkg/session/sess_file.go | 8 +-- pkg/session/sess_file_test.go | 62 +++++++++++++++---- pkg/session/sess_mem.go | 6 +- pkg/session/session.go | 12 +++- pkg/session/ssdb/sess_ssdb.go | 8 +-- 14 files changed, 107 insertions(+), 49 deletions(-) diff --git a/pkg/session/couchbase/sess_couchbase.go b/pkg/session/couchbase/sess_couchbase.go index 227c0bc6..b824a938 100644 --- a/pkg/session/couchbase/sess_couchbase.go +++ b/pkg/session/couchbase/sess_couchbase.go @@ -179,16 +179,16 @@ func (cp *Provider) SessionRead(sid string) (session.Store, error) { // SessionExist Check couchbase session exist. // it checkes sid exist or not. -func (cp *Provider) SessionExist(sid string) bool { +func (cp *Provider) SessionExist(sid string) (bool, error) { cp.b = cp.getBucket() defer cp.b.Close() var doc []byte if err := cp.b.Get(sid, &doc); err != nil || doc == nil { - return false + return false, err } - return true + return true, nil } // SessionRegenerate remove oldsid and use sid to generate new session diff --git a/pkg/session/ledis/ledis_session.go b/pkg/session/ledis/ledis_session.go index a0988327..e43d70a0 100644 --- a/pkg/session/ledis/ledis_session.go +++ b/pkg/session/ledis/ledis_session.go @@ -132,9 +132,9 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check ledis session exist by sid -func (lp *Provider) SessionExist(sid string) bool { +func (lp *Provider) SessionExist(sid string) (bool, error) { count, _ := c.Exists([]byte(sid)) - return count != 0 + return count != 0, nil } // SessionRegenerate generate new sid for ledis session diff --git a/pkg/session/memcache/sess_memcache.go b/pkg/session/memcache/sess_memcache.go index 6cd8acab..7fab842a 100644 --- a/pkg/session/memcache/sess_memcache.go +++ b/pkg/session/memcache/sess_memcache.go @@ -149,16 +149,16 @@ func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { } // SessionExist check memcache session exist by sid -func (rp *MemProvider) SessionExist(sid string) bool { +func (rp *MemProvider) SessionExist(sid string) (bool, error) { if client == nil { if err := rp.connectInit(); err != nil { - return false + return false, err } } if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { - return false + return false, err } - return true + return true, nil } // SessionRegenerate generate new sid for memcache session diff --git a/pkg/session/mysql/sess_mysql.go b/pkg/session/mysql/sess_mysql.go index 73738496..c641a4bf 100644 --- a/pkg/session/mysql/sess_mysql.go +++ b/pkg/session/mysql/sess_mysql.go @@ -164,13 +164,19 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check mysql session exist -func (mp *Provider) SessionExist(sid string) bool { +func (mp *Provider) SessionExist(sid string) (bool, error) { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte err := row.Scan(&sessiondata) - return err != sql.ErrNoRows + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil } // SessionRegenerate generate new sid for mysql session diff --git a/pkg/session/postgres/sess_postgresql.go b/pkg/session/postgres/sess_postgresql.go index e6c9ed89..688c0e36 100644 --- a/pkg/session/postgres/sess_postgresql.go +++ b/pkg/session/postgres/sess_postgresql.go @@ -178,13 +178,19 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check postgresql session exist -func (mp *Provider) SessionExist(sid string) bool { +func (mp *Provider) SessionExist(sid string) (bool, error) { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from session where session_key=$1", sid) var sessiondata []byte err := row.Scan(&sessiondata) - return err != sql.ErrNoRows + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil } // SessionRegenerate generate new sid for postgresql session diff --git a/pkg/session/redis/sess_redis.go b/pkg/session/redis/sess_redis.go index f569f9dd..cf2dddf4 100644 --- a/pkg/session/redis/sess_redis.go +++ b/pkg/session/redis/sess_redis.go @@ -211,14 +211,14 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis session exist by sid -func (rp *Provider) SessionExist(sid string) bool { +func (rp *Provider) SessionExist(sid string) (bool, error) { c := rp.poollist.Get() defer c.Close() if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 { - return false + return false, err } - return true + return true, nil } // SessionRegenerate generate new sid for redis session diff --git a/pkg/session/redis_cluster/redis_cluster.go b/pkg/session/redis_cluster/redis_cluster.go index f7fc7845..8b20ab19 100644 --- a/pkg/session/redis_cluster/redis_cluster.go +++ b/pkg/session/redis_cluster/redis_cluster.go @@ -176,12 +176,12 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis_cluster session exist by sid -func (rp *Provider) SessionExist(sid string) bool { +func (rp *Provider) SessionExist(sid string) (bool, error) { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { - return false + return false, err } - return true + return true, nil } // SessionRegenerate generate new sid for redis_cluster session diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel.go b/pkg/session/redis_sentinel/sess_redis_sentinel.go index 23bebf2a..f78486be 100644 --- a/pkg/session/redis_sentinel/sess_redis_sentinel.go +++ b/pkg/session/redis_sentinel/sess_redis_sentinel.go @@ -189,12 +189,12 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis_sentinel session exist by sid -func (rp *Provider) SessionExist(sid string) bool { +func (rp *Provider) SessionExist(sid string) (bool, error) { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { - return false + return false, err } - return true + return true, nil } // SessionRegenerate generate new sid for redis_sentinel session diff --git a/pkg/session/sess_cookie.go b/pkg/session/sess_cookie.go index 6ad5debc..30a7032e 100644 --- a/pkg/session/sess_cookie.go +++ b/pkg/session/sess_cookie.go @@ -147,8 +147,8 @@ func (pder *CookieProvider) SessionRead(sid string) (Store, error) { } // SessionExist Cookie session is always existed -func (pder *CookieProvider) SessionExist(sid string) bool { - return true +func (pder *CookieProvider) SessionExist(sid string) (bool, error) { + return true, nil } // SessionRegenerate Implement method, no used. diff --git a/pkg/session/sess_file.go b/pkg/session/sess_file.go index 47ad54a7..37d5bd68 100644 --- a/pkg/session/sess_file.go +++ b/pkg/session/sess_file.go @@ -176,17 +176,17 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { // SessionExist Check file session exist. // it checks the file named from sid exist or not. -func (fp *FileProvider) SessionExist(sid string) bool { +func (fp *FileProvider) SessionExist(sid string) (bool, error) { filepder.lock.Lock() defer filepder.lock.Unlock() if len(sid) < 2 { - SLogger.Println("min length of session id is 2", sid) - return false + SLogger.Println("min length of session id is 2 but got length: ", sid) + return false, errors.New("min length of session id is 2") } _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - return err == nil + return err == nil, nil } // SessionDestroy Remove all files in this save path diff --git a/pkg/session/sess_file_test.go b/pkg/session/sess_file_test.go index 021c43fc..64b8d94a 100644 --- a/pkg/session/sess_file_test.go +++ b/pkg/session/sess_file_test.go @@ -56,16 +56,24 @@ func TestFileProvider_SessionExist(t *testing.T) { _ = fp.SessionInit(180, sessionPath) - if fp.SessionExist(sid) { + exists, err := fp.SessionExist(sid) + if err != nil{ + t.Error(err) + } + if exists { t.Error() } - _, err := fp.SessionRead(sid) + _, err = fp.SessionRead(sid) if err != nil { t.Error(err) } - if !fp.SessionExist(sid) { + exists, err = fp.SessionExist(sid) + if err != nil { + t.Error(err) + } + if !exists { t.Error() } } @@ -79,15 +87,27 @@ func TestFileProvider_SessionExist2(t *testing.T) { _ = fp.SessionInit(180, sessionPath) - if fp.SessionExist(sid) { + exists, err := fp.SessionExist(sid) + if err != nil { + t.Error(err) + } + if exists { t.Error() } - if fp.SessionExist("") { + exists, err = fp.SessionExist("") + if err == nil { + t.Error() + } + if exists { t.Error() } - if fp.SessionExist("1") { + exists, err = fp.SessionExist("1") + if err == nil { + t.Error() + } + if exists { t.Error() } } @@ -171,7 +191,11 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { t.Error(err) } - if !fp.SessionExist(sid) { + exists, err := fp.SessionExist(sid) + if err != nil { + t.Error(err) + } + if !exists { t.Error() } @@ -180,11 +204,19 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { t.Error(err) } - if fp.SessionExist(sid) { + exists, err = fp.SessionExist(sid) + if err != nil { + t.Error(err) + } + if exists { t.Error() } - if !fp.SessionExist(sidNew) { + exists, err = fp.SessionExist(sidNew) + if err != nil { + t.Error(err) + } + if !exists { t.Error() } } @@ -203,7 +235,11 @@ func TestFileProvider_SessionDestroy(t *testing.T) { t.Error(err) } - if !fp.SessionExist(sid) { + exists, err := fp.SessionExist(sid) + if err != nil { + t.Error(err) + } + if !exists { t.Error() } @@ -212,7 +248,11 @@ func TestFileProvider_SessionDestroy(t *testing.T) { t.Error(err) } - if fp.SessionExist(sid) { + exists, err = fp.SessionExist(sid) + if err != nil { + t.Error(err) + } + if exists { t.Error() } } diff --git a/pkg/session/sess_mem.go b/pkg/session/sess_mem.go index 64d8b056..bd69ff80 100644 --- a/pkg/session/sess_mem.go +++ b/pkg/session/sess_mem.go @@ -109,13 +109,13 @@ func (pder *MemProvider) SessionRead(sid string) (Store, error) { } // SessionExist check session store exist in memory session by sid -func (pder *MemProvider) SessionExist(sid string) bool { +func (pder *MemProvider) SessionExist(sid string) (bool, error) { pder.lock.RLock() defer pder.lock.RUnlock() if _, ok := pder.sessions[sid]; ok { - return true + return true, nil } - return false + return false, nil } // SessionRegenerate generate new sid for session store in memory session diff --git a/pkg/session/session.go b/pkg/session/session.go index eb85360a..92e35de4 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -56,7 +56,7 @@ type Store interface { type Provider interface { SessionInit(gclifetime int64, config string) error SessionRead(sid string) (Store, error) - SessionExist(sid string) bool + SessionExist(sid string) (bool, error) SessionRegenerate(oldsid, sid string) (Store, error) SessionDestroy(sid string) error SessionAll() int //get all active session @@ -211,8 +211,14 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se return nil, errs } - if sid != "" && manager.provider.SessionExist(sid) { - return manager.provider.SessionRead(sid) + if sid != "" { + exists, err := manager.provider.SessionExist(sid) + if err != nil { + return nil, err + } + if exists { + return manager.provider.SessionRead(sid) + } } // Generate a new session diff --git a/pkg/session/ssdb/sess_ssdb.go b/pkg/session/ssdb/sess_ssdb.go index 1b382954..f950b835 100644 --- a/pkg/session/ssdb/sess_ssdb.go +++ b/pkg/session/ssdb/sess_ssdb.go @@ -68,10 +68,10 @@ func (p *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist judged whether sid is exist in session -func (p *Provider) SessionExist(sid string) bool { +func (p *Provider) SessionExist(sid string) (bool, error) { if p.client == nil { if err := p.connectInit(); err != nil { - panic(err) + return false, err } } value, err := p.client.Get(sid) @@ -79,9 +79,9 @@ func (p *Provider) SessionExist(sid string) bool { panic(err) } if value == nil || len(value.(string)) == 0 { - return false + return false, nil } - return true + return true, nil } // SessionRegenerate regenerate session with new sid and delete oldsid From 3052c64b6c47488c4c4c949629c9fabe59616447 Mon Sep 17 00:00:00 2001 From: Phillip Stagnet Date: Wed, 5 Aug 2020 18:29:47 +0200 Subject: [PATCH 068/207] Revert "Add error to SessionExist interface" This reverts commit 28e6b3b92450b0ca0e9c1342d400f8810e4d5e5a. --- session/couchbase/sess_couchbase.go | 6 +- session/ledis/ledis_session.go | 4 +- session/memcache/sess_memcache.go | 8 +-- session/mysql/sess_mysql.go | 10 +-- session/postgres/sess_postgresql.go | 10 +-- session/redis/sess_redis.go | 6 +- session/redis_cluster/redis_cluster.go | 6 +- session/redis_sentinel/sess_redis_sentinel.go | 6 +- session/sess_cookie.go | 4 +- session/sess_file.go | 9 +-- session/sess_file_test.go | 62 ++++--------------- session/sess_mem.go | 6 +- session/session.go | 12 +--- session/ssdb/sess_ssdb.go | 8 +-- 14 files changed, 48 insertions(+), 109 deletions(-) diff --git a/session/couchbase/sess_couchbase.go b/session/couchbase/sess_couchbase.go index 46ab07ab..707d042c 100644 --- a/session/couchbase/sess_couchbase.go +++ b/session/couchbase/sess_couchbase.go @@ -179,16 +179,16 @@ func (cp *Provider) SessionRead(sid string) (session.Store, error) { // SessionExist Check couchbase session exist. // it checkes sid exist or not. -func (cp *Provider) SessionExist(sid string) (bool, error) { +func (cp *Provider) SessionExist(sid string) bool { cp.b = cp.getBucket() defer cp.b.Close() var doc []byte if err := cp.b.Get(sid, &doc); err != nil || doc == nil { - return false, err + return false } - return true, nil + return true } // SessionRegenerate remove oldsid and use sid to generate new session diff --git a/session/ledis/ledis_session.go b/session/ledis/ledis_session.go index 4f578eac..ee81df67 100644 --- a/session/ledis/ledis_session.go +++ b/session/ledis/ledis_session.go @@ -132,9 +132,9 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check ledis session exist by sid -func (lp *Provider) SessionExist(sid string) (bool, error) { +func (lp *Provider) SessionExist(sid string) bool { count, _ := c.Exists([]byte(sid)) - return count != 0, nil + return count != 0 } // SessionRegenerate generate new sid for ledis session diff --git a/session/memcache/sess_memcache.go b/session/memcache/sess_memcache.go index e76eb8a5..85a2d815 100644 --- a/session/memcache/sess_memcache.go +++ b/session/memcache/sess_memcache.go @@ -149,16 +149,16 @@ func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { } // SessionExist check memcache session exist by sid -func (rp *MemProvider) SessionExist(sid string) (bool, error) { +func (rp *MemProvider) SessionExist(sid string) bool { if client == nil { if err := rp.connectInit(); err != nil { - return false, err + return false } } if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { - return false, err + return false } - return true, nil + return true } // SessionRegenerate generate new sid for memcache session diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go index 9f9547a7..301353ab 100644 --- a/session/mysql/sess_mysql.go +++ b/session/mysql/sess_mysql.go @@ -164,19 +164,13 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check mysql session exist -func (mp *Provider) SessionExist(sid string) (bool, error) { +func (mp *Provider) SessionExist(sid string) bool { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte err := row.Scan(&sessiondata) - if err != nil { - if err == sql.ErrNoRows { - return false, nil - } - return false, err - } - return true, nil + return err != sql.ErrNoRows } // SessionRegenerate generate new sid for mysql session diff --git a/session/postgres/sess_postgresql.go b/session/postgres/sess_postgresql.go index d8a1e6de..0b8b9645 100644 --- a/session/postgres/sess_postgresql.go +++ b/session/postgres/sess_postgresql.go @@ -178,19 +178,13 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check postgresql session exist -func (mp *Provider) SessionExist(sid string) (bool, error) { +func (mp *Provider) SessionExist(sid string) bool { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from session where session_key=$1", sid) var sessiondata []byte err := row.Scan(&sessiondata) - if err != nil { - if err == sql.ErrNoRows { - return false, nil - } - return false, err - } - return true, nil + return err != sql.ErrNoRows } // SessionRegenerate generate new sid for postgresql session diff --git a/session/redis/sess_redis.go b/session/redis/sess_redis.go index 439b14cb..5c382d61 100644 --- a/session/redis/sess_redis.go +++ b/session/redis/sess_redis.go @@ -211,14 +211,14 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis session exist by sid -func (rp *Provider) SessionExist(sid string) (bool, error) { +func (rp *Provider) SessionExist(sid string) bool { c := rp.poollist.Get() defer c.Close() if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 { - return false, err + return false } - return true, nil + return true } // SessionRegenerate generate new sid for redis session diff --git a/session/redis_cluster/redis_cluster.go b/session/redis_cluster/redis_cluster.go index d4e28327..262fa2e3 100644 --- a/session/redis_cluster/redis_cluster.go +++ b/session/redis_cluster/redis_cluster.go @@ -176,12 +176,12 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis_cluster session exist by sid -func (rp *Provider) SessionExist(sid string) (bool, error) { +func (rp *Provider) SessionExist(sid string) bool { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { - return false, err + return false } - return true, nil + return true } // SessionRegenerate generate new sid for redis_cluster session diff --git a/session/redis_sentinel/sess_redis_sentinel.go b/session/redis_sentinel/sess_redis_sentinel.go index eead7a74..6ecb2977 100644 --- a/session/redis_sentinel/sess_redis_sentinel.go +++ b/session/redis_sentinel/sess_redis_sentinel.go @@ -189,12 +189,12 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis_sentinel session exist by sid -func (rp *Provider) SessionExist(sid string) (bool, error) { +func (rp *Provider) SessionExist(sid string) bool { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { - return false, err + return false } - return true, nil + return true } // SessionRegenerate generate new sid for redis_sentinel session diff --git a/session/sess_cookie.go b/session/sess_cookie.go index 30a7032e..6ad5debc 100644 --- a/session/sess_cookie.go +++ b/session/sess_cookie.go @@ -147,8 +147,8 @@ func (pder *CookieProvider) SessionRead(sid string) (Store, error) { } // SessionExist Cookie session is always existed -func (pder *CookieProvider) SessionExist(sid string) (bool, error) { - return true, nil +func (pder *CookieProvider) SessionExist(sid string) bool { + return true } // SessionRegenerate Implement method, no used. diff --git a/session/sess_file.go b/session/sess_file.go index 3345d5d0..47ad54a7 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -176,20 +176,17 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { // SessionExist Check file session exist. // it checks the file named from sid exist or not. -func (fp *FileProvider) SessionExist(sid string) (bool, error) { +func (fp *FileProvider) SessionExist(sid string) bool { filepder.lock.Lock() defer filepder.lock.Unlock() if len(sid) < 2 { SLogger.Println("min length of session id is 2", sid) - return false, nil + return false } _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - if err != nil { - return false, nil - } - return true, nil + return err == nil } // SessionDestroy Remove all files in this save path diff --git a/session/sess_file_test.go b/session/sess_file_test.go index 1e155f91..021c43fc 100644 --- a/session/sess_file_test.go +++ b/session/sess_file_test.go @@ -56,24 +56,16 @@ func TestFileProvider_SessionExist(t *testing.T) { _ = fp.SessionInit(180, sessionPath) - exists, err := fp.SessionExist(sid) - if err != nil{ - t.Error(err) - } - if exists { + if fp.SessionExist(sid) { t.Error() } - _, err = fp.SessionRead(sid) + _, err := fp.SessionRead(sid) if err != nil { t.Error(err) } - exists, err = fp.SessionExist(sid) - if err != nil { - t.Error(err) - } - if !exists { + if !fp.SessionExist(sid) { t.Error() } } @@ -87,27 +79,15 @@ func TestFileProvider_SessionExist2(t *testing.T) { _ = fp.SessionInit(180, sessionPath) - exists, err := fp.SessionExist(sid) - if err != nil { - t.Error(err) - } - if exists { + if fp.SessionExist(sid) { t.Error() } - exists, err = fp.SessionExist("") - if err != nil { - t.Error(err) - } - if exists { + if fp.SessionExist("") { t.Error() } - exists, err = fp.SessionExist("1") - if err != nil { - t.Error(err) - } - if exists { + if fp.SessionExist("1") { t.Error() } } @@ -191,11 +171,7 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { t.Error(err) } - exists, err := fp.SessionExist(sid) - if err != nil { - t.Error(err) - } - if !exists { + if !fp.SessionExist(sid) { t.Error() } @@ -204,19 +180,11 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { t.Error(err) } - exists, err = fp.SessionExist(sid) - if err != nil { - t.Error(err) - } - if exists { + if fp.SessionExist(sid) { t.Error() } - exists, err = fp.SessionExist(sidNew) - if err != nil { - t.Error(err) - } - if !exists { + if !fp.SessionExist(sidNew) { t.Error() } } @@ -235,11 +203,7 @@ func TestFileProvider_SessionDestroy(t *testing.T) { t.Error(err) } - exists, err := fp.SessionExist(sid) - if err != nil { - t.Error(err) - } - if !exists { + if !fp.SessionExist(sid) { t.Error() } @@ -248,11 +212,7 @@ func TestFileProvider_SessionDestroy(t *testing.T) { t.Error(err) } - exists, err = fp.SessionExist(sid) - if err != nil { - t.Error(err) - } - if exists { + if fp.SessionExist(sid) { t.Error() } } diff --git a/session/sess_mem.go b/session/sess_mem.go index bd69ff80..64d8b056 100644 --- a/session/sess_mem.go +++ b/session/sess_mem.go @@ -109,13 +109,13 @@ func (pder *MemProvider) SessionRead(sid string) (Store, error) { } // SessionExist check session store exist in memory session by sid -func (pder *MemProvider) SessionExist(sid string) (bool, error) { +func (pder *MemProvider) SessionExist(sid string) bool { pder.lock.RLock() defer pder.lock.RUnlock() if _, ok := pder.sessions[sid]; ok { - return true, nil + return true } - return false, nil + return false } // SessionRegenerate generate new sid for session store in memory session diff --git a/session/session.go b/session/session.go index 92e35de4..eb85360a 100644 --- a/session/session.go +++ b/session/session.go @@ -56,7 +56,7 @@ type Store interface { type Provider interface { SessionInit(gclifetime int64, config string) error SessionRead(sid string) (Store, error) - SessionExist(sid string) (bool, error) + SessionExist(sid string) bool SessionRegenerate(oldsid, sid string) (Store, error) SessionDestroy(sid string) error SessionAll() int //get all active session @@ -211,14 +211,8 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se return nil, errs } - if sid != "" { - exists, err := manager.provider.SessionExist(sid) - if err != nil { - return nil, err - } - if exists { - return manager.provider.SessionRead(sid) - } + if sid != "" && manager.provider.SessionExist(sid) { + return manager.provider.SessionRead(sid) } // Generate a new session diff --git a/session/ssdb/sess_ssdb.go b/session/ssdb/sess_ssdb.go index 9b9eee94..de0c6360 100644 --- a/session/ssdb/sess_ssdb.go +++ b/session/ssdb/sess_ssdb.go @@ -68,7 +68,7 @@ func (p *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist judged whether sid is exist in session -func (p *Provider) SessionExist(sid string) (bool, error) { +func (p *Provider) SessionExist(sid string) bool { if p.client == nil { if err := p.connectInit(); err != nil { panic(err) @@ -76,12 +76,12 @@ func (p *Provider) SessionExist(sid string) (bool, error) { } value, err := p.client.Get(sid) if err != nil { - return false, err + panic(err) } if value == nil || len(value.(string)) == 0 { - return false, nil + return false } - return true, nil + return true } // SessionRegenerate regenerate session with new sid and delete oldsid From 009074725eb25fefb59df82711ab4a7edb1eaa1b Mon Sep 17 00:00:00 2001 From: Phillip Stagnet Date: Wed, 5 Aug 2020 18:32:33 +0200 Subject: [PATCH 069/207] Move interface change to pkg/session/README.md --- pkg/session/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/session/README.md b/pkg/session/README.md index 6d0a297e..a5c3bd6d 100644 --- a/pkg/session/README.md +++ b/pkg/session/README.md @@ -101,7 +101,7 @@ Maybe you will find the **memory** provider is a good example. type Provider interface { SessionInit(gclifetime int64, config string) error SessionRead(sid string) (SessionStore, error) - SessionExist(sid string) bool + SessionExist(sid string) (bool, error) SessionRegenerate(oldsid, sid string) (SessionStore, error) SessionDestroy(sid string) error SessionAll() int //get all active session From 5c8c088684f7aa2070bbcf234f4fe4e103c16157 Mon Sep 17 00:00:00 2001 From: Phillip Stagnet Date: Wed, 5 Aug 2020 18:33:17 +0200 Subject: [PATCH 070/207] Revert "Change interface in session README" This reverts commit 6f5c5bd3a65561db56aca26eae4a50abef8fa5b4. --- session/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/session/README.md b/session/README.md index a5c3bd6d..6d0a297e 100644 --- a/session/README.md +++ b/session/README.md @@ -101,7 +101,7 @@ Maybe you will find the **memory** provider is a good example. type Provider interface { SessionInit(gclifetime int64, config string) error SessionRead(sid string) (SessionStore, error) - SessionExist(sid string) (bool, error) + SessionExist(sid string) bool SessionRegenerate(oldsid, sid string) (SessionStore, error) SessionDestroy(sid string) error SessionAll() int //get all active session From 2f5683610f42fdd046e0099fc59e0fbc03280f29 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Wed, 5 Aug 2020 17:44:39 +0100 Subject: [PATCH 071/207] Minor grammar fixes --- pkg/cache/conv.go | 10 ++++---- pkg/cache/file.go | 47 +++++++++++++++++----------------- pkg/cache/memory.go | 34 ++++++++++++------------ pkg/common/kv.go | 16 ++++++------ pkg/config/env/env.go | 2 +- pkg/orm/cmd.go | 16 ++++++------ pkg/plugins/apiauth/apiauth.go | 28 ++++++++++---------- 7 files changed, 77 insertions(+), 76 deletions(-) diff --git a/pkg/cache/conv.go b/pkg/cache/conv.go index 87800586..158f7f41 100644 --- a/pkg/cache/conv.go +++ b/pkg/cache/conv.go @@ -19,7 +19,7 @@ import ( "strconv" ) -// GetString convert interface to string. +// GetString converts interface to string. func GetString(v interface{}) string { switch result := v.(type) { case string: @@ -34,7 +34,7 @@ func GetString(v interface{}) string { return "" } -// GetInt convert interface to int. +// GetInt converts interface to int. func GetInt(v interface{}) int { switch result := v.(type) { case int: @@ -52,7 +52,7 @@ func GetInt(v interface{}) int { return 0 } -// GetInt64 convert interface to int64. +// GetInt64 converts interface to int64. func GetInt64(v interface{}) int64 { switch result := v.(type) { case int: @@ -71,7 +71,7 @@ func GetInt64(v interface{}) int64 { return 0 } -// GetFloat64 convert interface to float64. +// GetFloat64 converts interface to float64. func GetFloat64(v interface{}) float64 { switch result := v.(type) { case float64: @@ -85,7 +85,7 @@ func GetFloat64(v interface{}) float64 { return 0 } -// GetBool convert interface to bool. +// GetBool converts interface to bool. func GetBool(v interface{}) bool { switch result := v.(type) { case bool: diff --git a/pkg/cache/file.go b/pkg/cache/file.go index 6f12d3ee..dcc60bc0 100644 --- a/pkg/cache/file.go +++ b/pkg/cache/file.go @@ -30,8 +30,8 @@ import ( "time" ) -// FileCacheItem is basic unit of file cache adapter. -// it contains data and expire time. +// FileCacheItem is basic unit of file cache adapter which +// contains data and expire time. type FileCacheItem struct { Data interface{} Lastaccess time.Time @@ -54,15 +54,15 @@ type FileCache struct { EmbedExpiry int } -// NewFileCache Create new file cache with no config. -// the level and expiry need set in method StartAndGC as config string. +// NewFileCache cerates a new file cache with no config. +// The level and expiry need to be set in the method StartAndGC as config string. func NewFileCache() Cache { // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} return &FileCache{} } -// StartAndGC will start and begin gc for file cache. -// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"} +// StartAndGC starts gc for file cache. +// config must be in the format {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"} func (fc *FileCache) StartAndGC(config string) error { cfg := make(map[string]string) @@ -91,14 +91,14 @@ func (fc *FileCache) StartAndGC(config string) error { return nil } -// Init will make new dir for file cache if not exist. +// Init makes new a dir for file cache if it does not already exist func (fc *FileCache) Init() { if ok, _ := exists(fc.CachePath); !ok { // todo : error handle _ = os.MkdirAll(fc.CachePath, os.ModePerm) // todo : error handle } } -// get cached file name. it's md5 encoded. +// getCachedFilename returns an md5 encoded file name. func (fc *FileCache) getCacheFileName(key string) string { m := md5.New() io.WriteString(m, key) @@ -119,7 +119,7 @@ func (fc *FileCache) getCacheFileName(key string) string { } // Get value from file cache. -// if non-exist or expired, return empty string. +// if nonexistent or expired return an empty string. func (fc *FileCache) Get(key string) interface{} { fileData, err := FileGetContents(fc.getCacheFileName(key)) if err != nil { @@ -134,7 +134,7 @@ func (fc *FileCache) Get(key string) interface{} { } // GetMulti gets values from file cache. -// if non-exist or expired, return empty string. +// if nonexistent or expired return an empty string. func (fc *FileCache) GetMulti(keys []string) []interface{} { var rc []interface{} for _, key := range keys { @@ -144,7 +144,7 @@ func (fc *FileCache) GetMulti(keys []string) []interface{} { } // Put value into file cache. -// timeout means how long to keep this file, unit of ms. +// timeout: how long this file should be kept in ms // if timeout equals fc.EmbedExpiry(default is 0), cache this item forever. func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error { gob.Register(val) @@ -172,8 +172,8 @@ func (fc *FileCache) Delete(key string) error { return nil } -// Incr will increase cached int value. -// fc value is saving forever unless Delete. +// Incr increases cached int value. +// fc value is saved forever unless deleted. func (fc *FileCache) Incr(key string) error { data := fc.Get(key) var incr int @@ -186,7 +186,7 @@ func (fc *FileCache) Incr(key string) error { return nil } -// Decr will decrease cached int value. +// Decr decreases cached int value. func (fc *FileCache) Decr(key string) error { data := fc.Get(key) var decr int @@ -199,19 +199,18 @@ func (fc *FileCache) Decr(key string) error { return nil } -// IsExist check value is exist. +// IsExist checks if value exists. func (fc *FileCache) IsExist(key string) bool { ret, _ := exists(fc.getCacheFileName(key)) return ret } -// ClearAll will clean cached files. -// not implemented. +// ClearAll cleans cached files (not implemented) func (fc *FileCache) ClearAll() error { return nil } -// check file exist. +// Check if a file exists func exists(path string) (bool, error) { _, err := os.Stat(path) if err == nil { @@ -223,19 +222,19 @@ func exists(path string) (bool, error) { return false, err } -// FileGetContents Get bytes to file. -// if non-exist, create this file. +// FileGetContents Reads bytes from a file. +// if non-existent, create this file. func FileGetContents(filename string) (data []byte, e error) { return ioutil.ReadFile(filename) } -// FilePutContents Put bytes to file. -// if non-exist, create this file. +// FilePutContents puts bytes into a file. +// if non-existent, create this file. func FilePutContents(filename string, content []byte) error { return ioutil.WriteFile(filename, content, os.ModePerm) } -// GobEncode Gob encodes file cache item. +// GobEncode Gob encodes a file cache item. func GobEncode(data interface{}) ([]byte, error) { buf := bytes.NewBuffer(nil) enc := gob.NewEncoder(buf) @@ -246,7 +245,7 @@ func GobEncode(data interface{}) ([]byte, error) { return buf.Bytes(), err } -// GobDecode Gob decodes file cache item. +// GobDecode Gob decodes a file cache item. func GobDecode(data []byte, to *FileCacheItem) error { buf := bytes.NewBuffer(data) dec := gob.NewDecoder(buf) diff --git a/pkg/cache/memory.go b/pkg/cache/memory.go index d8314e3c..792a628a 100644 --- a/pkg/cache/memory.go +++ b/pkg/cache/memory.go @@ -22,11 +22,11 @@ import ( ) var ( - // DefaultEvery means the clock time of recycling the expired cache items in memory. + // Recycle the expired cache items in memory (in seconds) DefaultEvery = 60 // 1 minute ) -// MemoryItem store memory cache item. +// MemoryItem stores memory cache item. type MemoryItem struct { val interface{} createdTime time.Time @@ -41,8 +41,8 @@ func (mi *MemoryItem) isExpire() bool { return time.Now().Sub(mi.createdTime) > mi.lifespan } -// MemoryCache is Memory cache adapter. -// it contains a RW locker for safe map storage. +// MemoryCache is a memory cache adapter. +// Contains a RW locker for safe map storage. type MemoryCache struct { sync.RWMutex dur time.Duration @@ -56,8 +56,8 @@ func NewMemoryCache() Cache { return &cache } -// Get cache from memory. -// if non-existed or expired, return nil. +// Get returns cache from memory. +// If non-existent or expired, return nil. func (bc *MemoryCache) Get(name string) interface{} { bc.RLock() defer bc.RUnlock() @@ -71,7 +71,7 @@ func (bc *MemoryCache) Get(name string) interface{} { } // GetMulti gets caches from memory. -// if non-existed or expired, return nil. +// If non-existent or expired, return nil. func (bc *MemoryCache) GetMulti(names []string) []interface{} { var rc []interface{} for _, name := range names { @@ -80,8 +80,8 @@ func (bc *MemoryCache) GetMulti(names []string) []interface{} { return rc } -// Put cache to memory. -// if lifespan is 0, it will be forever till restart. +// Put puts cache into memory. +// If lifespan is 0, it will never overwrite this value func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error { bc.Lock() defer bc.Unlock() @@ -107,8 +107,8 @@ func (bc *MemoryCache) Delete(name string) error { return nil } -// Incr increase cache counter in memory. -// it supports int,int32,int64,uint,uint32,uint64. +// Incr increases cache counter in memory. +// Supports int,int32,int64,uint,uint32,uint64. func (bc *MemoryCache) Incr(key string) error { bc.Lock() defer bc.Unlock() @@ -135,7 +135,7 @@ func (bc *MemoryCache) Incr(key string) error { return nil } -// Decr decrease counter in memory. +// Decr decreases counter in memory. func (bc *MemoryCache) Decr(key string) error { bc.Lock() defer bc.Unlock() @@ -174,7 +174,7 @@ func (bc *MemoryCache) Decr(key string) error { return nil } -// IsExist check cache exist in memory. +// IsExist checks if cache exists in memory. func (bc *MemoryCache) IsExist(name string) bool { bc.RLock() defer bc.RUnlock() @@ -184,7 +184,7 @@ func (bc *MemoryCache) IsExist(name string) bool { return false } -// ClearAll will delete all cache in memory. +// ClearAll deletes all cache in memory. func (bc *MemoryCache) ClearAll() error { bc.Lock() defer bc.Unlock() @@ -192,7 +192,7 @@ func (bc *MemoryCache) ClearAll() error { return nil } -// StartAndGC start memory cache. it will check expiration in every clock time. +// StartAndGC starts memory cache. Checks expiration in every clock time. func (bc *MemoryCache) StartAndGC(config string) error { var cf map[string]int json.Unmarshal([]byte(config), &cf) @@ -230,7 +230,7 @@ func (bc *MemoryCache) vacuum() { } } -// expiredKeys returns key list which are expired. +// expiredKeys returns keys list which are expired. func (bc *MemoryCache) expiredKeys() (keys []string) { bc.RLock() defer bc.RUnlock() @@ -242,7 +242,7 @@ func (bc *MemoryCache) expiredKeys() (keys []string) { return } -// clearItems removes all the items which key in keys. +// ClearItems removes all items who's key is in keys func (bc *MemoryCache) clearItems(keys []string) { bc.Lock() defer bc.Unlock() diff --git a/pkg/common/kv.go b/pkg/common/kv.go index 86a50132..8468f4fe 100644 --- a/pkg/common/kv.go +++ b/pkg/common/kv.go @@ -19,8 +19,8 @@ type KV interface { GetValue() interface{} } -// SimpleKV is common structure to store key-value data. -// when you need something like Pair, you can use this +// SimpleKV is common structure to store key-value pairs. +// When you need something like Pair, you can use this type SimpleKV struct { Key interface{} Value interface{} @@ -41,8 +41,8 @@ type KVs struct { kvs map[interface{}]interface{} } -// GetValueOr check whether this contains the key, -// if the key not found, the default value will be return +// GetValueOr returns the value for a given key, if non-existant +// it returns defValue func (kvs *KVs) GetValueOr(key interface{}, defValue interface{}) interface{} { v, ok := kvs.kvs[key] if ok { @@ -51,13 +51,13 @@ func (kvs *KVs) GetValueOr(key interface{}, defValue interface{}) interface{} { return defValue } -// Contains will check whether contains the key +// Contains checks if a key exists func (kvs *KVs) Contains(key interface{}) bool { _, ok := kvs.kvs[key] return ok } -// IfContains is a functional API that if the key is in KVs, the action will be invoked +// IfContains invokes the action on a key if it exists func (kvs *KVs) IfContains(key interface{}, action func(value interface{})) *KVs { v, ok := kvs.kvs[key] if ok { @@ -66,13 +66,13 @@ func (kvs *KVs) IfContains(key interface{}, action func(value interface{})) *KVs return kvs } -// Put store the value +// Put stores the value func (kvs *KVs) Put(key interface{}, value interface{}) *KVs { kvs.kvs[key] = value return kvs } -// NewKVs will create the *KVs instance +// NewKVs creates the *KVs instance func NewKVs(kvs ...KV) *KVs { res := &KVs{ kvs: make(map[interface{}]interface{}, len(kvs)), diff --git a/pkg/config/env/env.go b/pkg/config/env/env.go index 7c729780..5d8e47de 100644 --- a/pkg/config/env/env.go +++ b/pkg/config/env/env.go @@ -34,7 +34,7 @@ func init() { } } -// Get returns a value by key. +// Get returns a value for a given key. // If the key does not exist, the default value will be returned. func Get(key string, defVal string) string { if val := env.Get(key); val != nil { diff --git a/pkg/orm/cmd.go b/pkg/orm/cmd.go index 0ff4dc40..f03382e9 100644 --- a/pkg/orm/cmd.go +++ b/pkg/orm/cmd.go @@ -46,7 +46,7 @@ func printHelp(errs ...string) { os.Exit(2) } -// RunCommand listen for orm command and then run it if command arguments passed. +// RunCommand listens for orm command and runs if command arguments have been passed. func RunCommand() { if len(os.Args) < 2 || os.Args[1] != "orm" { return @@ -83,7 +83,7 @@ type commandSyncDb struct { rtOnError bool } -// parse orm command line arguments. +// Parse the orm command line arguments. func (d *commandSyncDb) Parse(args []string) { var name string @@ -96,7 +96,7 @@ func (d *commandSyncDb) Parse(args []string) { d.al = getDbAlias(name) } -// run orm line command. +// Run orm line command. func (d *commandSyncDb) Run() error { var drops []string if d.force { @@ -232,7 +232,7 @@ type commandSQLAll struct { al *alias } -// parse orm command line arguments. +// Parse orm command line arguments. func (d *commandSQLAll) Parse(args []string) { var name string @@ -243,7 +243,7 @@ func (d *commandSQLAll) Parse(args []string) { d.al = getDbAlias(name) } -// run orm line command. +// Run orm line command. func (d *commandSQLAll) Run() error { sqls, indexes := getDbCreateSQL(d.al) var all []string @@ -266,9 +266,9 @@ func init() { } // RunSyncdb run syncdb command line. -// name means table's alias name. default is "default". -// force means run next sql if the current is error. -// verbose means show all info when running command or not. +// name: Table's alias name (default is "default") +// force: Run the next sql command even if the current gave an error +// verbose: Print all information, useful for debugging func RunSyncdb(name string, force bool, verbose bool) error { BootStrap() diff --git a/pkg/plugins/apiauth/apiauth.go b/pkg/plugins/apiauth/apiauth.go index 90360aba..7b1d4405 100644 --- a/pkg/plugins/apiauth/apiauth.go +++ b/pkg/plugins/apiauth/apiauth.go @@ -65,14 +65,14 @@ import ( "sort" "time" - "github.com/astaxie/beego/pkg" + beego "github.com/astaxie/beego/pkg" "github.com/astaxie/beego/pkg/context" ) -// AppIDToAppSecret is used to get appsecret throw appid +// AppIDToAppSecret gets appsecret through appid type AppIDToAppSecret func(string) string -// APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret +// APIBasicAuth uses the basic appid/appkey as the AppIdToAppSecret func APIBasicAuth(appid, appkey string) beego.FilterFunc { ft := func(aid string) string { if aid == appid { @@ -83,56 +83,58 @@ func APIBasicAuth(appid, appkey string) beego.FilterFunc { return APISecretAuth(ft, 300) } -// APIBaiscAuth calls APIBasicAuth for previous callers +// APIBasicAuth calls APIBasicAuth for previous callers func APIBaiscAuth(appid, appkey string) beego.FilterFunc { return APIBasicAuth(appid, appkey) } -// APISecretAuth use AppIdToAppSecret verify and +// APISecretAuth uses AppIdToAppSecret verify and func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { return func(ctx *context.Context) { if ctx.Input.Query("appid") == "" { ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("miss query param: appid") + ctx.WriteString("missing query parameter: appid") return } appsecret := f(ctx.Input.Query("appid")) if appsecret == "" { ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("not exist this appid") + ctx.WriteString("appid query parameter missing") return } if ctx.Input.Query("signature") == "" { ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("miss query param: signature") + ctx.WriteString("missing query parameter: signature") + return } if ctx.Input.Query("timestamp") == "" { ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("miss query param: timestamp") + ctx.WriteString("missing query parameter: timestamp") return } u, err := time.Parse("2006-01-02 15:04:05", ctx.Input.Query("timestamp")) if err != nil { ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("timestamp format is error, should 2006-01-02 15:04:05") + ctx.WriteString("incorrect timestamp format. Should be in the form 2006-01-02 15:04:05") + return } t := time.Now() if t.Sub(u).Seconds() > float64(timeout) { ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("timeout! the request time is long ago, please try again") + ctx.WriteString("request timer timeout exceeded. Please try again") return } if ctx.Input.Query("signature") != Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.URL()) { ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("auth failed") + ctx.WriteString("authentication failed") } } } -// Signature used to generate signature with the appsecret/method/params/RequestURI +// Signature generates signature with appsecret/method/params/RequestURI func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) { var b bytes.Buffer keys := make([]string, len(params)) From e7d8bab5d981ebf633f87b3ecdeda7525b13820b Mon Sep 17 00:00:00 2001 From: IamCathal Date: Wed, 5 Aug 2020 17:56:11 +0100 Subject: [PATCH 072/207] Improved definition of DefaultEvery --- pkg/cache/memory.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/cache/memory.go b/pkg/cache/memory.go index 792a628a..c0e35c6c 100644 --- a/pkg/cache/memory.go +++ b/pkg/cache/memory.go @@ -22,7 +22,7 @@ import ( ) var ( - // Recycle the expired cache items in memory (in seconds) + // Timer for how often to recycle the expired cache items in memory (in seconds) DefaultEvery = 60 // 1 minute ) @@ -81,7 +81,7 @@ func (bc *MemoryCache) GetMulti(names []string) []interface{} { } // Put puts cache into memory. -// If lifespan is 0, it will never overwrite this value +// If lifespan is 0, it will never overwrite this value unless restarted func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error { bc.Lock() defer bc.Unlock() From ec55edfbc411b8fce417da774feb86ce458a91aa Mon Sep 17 00:00:00 2001 From: Phillip Stagnet Date: Thu, 6 Aug 2020 11:14:36 +0200 Subject: [PATCH 073/207] Add additional options to redis session prov Adding option for frequency of checking timed out connections as well as an option to specify retries. These changes make redis provider more stable since connection problems are becoming fewer. Since redigo does not have this options and since redis_sentinel and redis_cluster are using go-redis as a client, this commit changes from redigo to go-redis for redis session provider. Added tests for redis session provider as well. --- go.mod | 4 +- go.sum | 50 +++-------- pkg/session/redis/sess_redis.go | 89 +++++++++---------- pkg/session/redis/sess_redis_test.go | 88 ++++++++++++++++++ pkg/session/redis_cluster/redis_cluster.go | 26 +++++- .../redis_sentinel/sess_redis_sentinel.go | 26 +++++- 6 files changed, 189 insertions(+), 94 deletions(-) create mode 100644 pkg/session/redis/sess_redis_test.go diff --git a/go.mod b/go.mod index a6c27488..1951d76d 100644 --- a/go.mod +++ b/go.mod @@ -12,8 +12,8 @@ require ( github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a // indirect github.com/elastic/go-elasticsearch/v6 v6.8.5 github.com/elazarl/go-bindata-assetfs v1.0.0 - github.com/go-kit/kit v0.9.0 github.com/go-redis/redis v6.14.2+incompatible + github.com/go-redis/redis/v7 v7.4.0 github.com/go-sql-driver/mysql v1.5.0 github.com/gogo/protobuf v1.1.1 github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect @@ -31,8 +31,6 @@ require ( github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c // indirect github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b // indirect golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 - golang.org/x/net v0.0.0-20190620200207-3b0461eec859 // indirect - google.golang.org/grpc v1.31.0 // indirect gopkg.in/yaml.v2 v2.2.8 ) diff --git a/go.sum b/go.sum index 12b76333..6c1bfe10 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,3 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Knetic/govaluate v3.0.0+incompatible h1:7o6+MAPhYTCF0+fdvoz1xDedhRb4f6s9Tn1Tt7/WTEg= @@ -21,13 +20,10 @@ github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 h1:rRISKWyXfVx github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= github.com/casbin/casbin v1.7.0 h1:PuzlE8w0JBg/DhIqnkF1Dewf3z+qmUZMVN07PonvVUQ= github.com/casbin/casbin v1.7.0/go.mod h1:c67qKN6Oum3UF5Q1+BByfFxkwKvhwW57ITjqwtzR1KE= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 h1:F1EaeKL/ta07PY/k9Os/UFtwERei2/XzGemhpGnBKNg= github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80= -github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/couchbase/go-couchbase v0.0.0-20200519150804-63f3cdb75e0d h1:OMrhQqj1QCyDT2sxHCDjE+k8aMdn2ngTCGG7g4wrdLo= github.com/couchbase/go-couchbase v0.0.0-20200519150804-63f3cdb75e0d/go.mod h1:TWI8EKQMs5u5jLKW/tsb9VwauIrMIxQG1r5fMsswK5U= github.com/couchbase/gomemcached v0.0.0-20200526233749-ec430f949808 h1:8s2l8TVUwMXl6tZMe3+hPCRJ25nQXiA3d1x622JtOqc= @@ -45,9 +41,6 @@ github.com/elastic/go-elasticsearch/v6 v6.8.5 h1:U2HtkBseC1FNBmDr0TR2tKltL6FxoY+ github.com/elastic/go-elasticsearch/v6 v6.8.5/go.mod h1:UwaDJsD3rWLM5rKNFzv9hgox93HoX8utj1kxD9aFUcI= github.com/elazarl/go-bindata-assetfs v1.0.0 h1:G/bYguwHIzWq9ZoyUQqrjTmJbbYn3j3CKKpKinvZLFk= github.com/elazarl/go-bindata-assetfs v1.0.0/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/glendc/gopher-json v0.0.0-20170414221815-dc4743023d0c/go.mod h1:Gja1A+xZ9BoviGJNA2E9vFkPjjsl+CoJxSXiQM1UXtw= @@ -59,6 +52,8 @@ github.com/go-logfmt/logfmt v0.4.0 h1:MP4Eh7ZCb31lleYCFuwm0oe4/YGak+5l1vA2NOE80n github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-redis/redis v6.14.2+incompatible h1:UE9pLhzmWf+xHNmZsoccjXosPicuiNaInPgym8nzfg0= github.com/go-redis/redis v6.14.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= +github.com/go-redis/redis/v7 v7.4.0 h1:7obg6wUoj05T0EpY0o8B59S9w5yeMWql7sw2kwNW1x4= +github.com/go-redis/redis/v7 v7.4.0/go.mod h1:JDNMw23GTyLNC4GZu9njt15ctBQVn7xjRfnwdHj/Dcg= github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= @@ -67,12 +62,9 @@ github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d h1:xy93KVe+KrIIwWDEAf github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= @@ -85,7 +77,6 @@ github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pO github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= @@ -121,8 +112,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.0 h1:Iw5WCbBcaAAd0fpRb1c9r5YCylv4XDoCSigm1zLevwU= github.com/onsi/ginkgo v1.12.0/go.mod h1:oUhWkIvk5aDxtKvDDuw8gItl8pKl42LzjC9KZE0HfGg= +github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.1 h1:K0jcRCwNQM3vFGh1ppMtDh/+7ApJrjldlX8fA0jDTLQ= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= @@ -143,7 +136,6 @@ github.com/prometheus/client_golang v1.7.0 h1:wCi7urQOGBsYcQROHqpUUX4ct84xp40t9R github.com/prometheus/client_golang v1.7.0/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= @@ -181,55 +173,35 @@ golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnf golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a h1:gOpx8G595UYyvj8UK4+OFyY4rx037g3fmfhe5SasG3U= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859 h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g= +golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.31.0 h1:T7P4R73V3SSDPhH7WW7ATbfViLtmamH0DKrP3f9AuDI= -google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -252,5 +224,3 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/pkg/session/redis/sess_redis.go b/pkg/session/redis/sess_redis.go index cf2dddf4..7e991ef5 100644 --- a/pkg/session/redis/sess_redis.go +++ b/pkg/session/redis/sess_redis.go @@ -41,7 +41,7 @@ import ( "github.com/astaxie/beego/pkg/session" - "github.com/gomodule/redigo/redis" + "github.com/go-redis/redis/v7" ) var redispder = &Provider{} @@ -51,7 +51,7 @@ var MaxPoolSize = 100 // SessionStore redis session store type SessionStore struct { - p *redis.Pool + p *redis.Client sid string lock sync.RWMutex values map[interface{}]interface{} @@ -103,9 +103,8 @@ func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { if err != nil { return } - c := rs.p.Get() - defer c.Close() - c.Do("SETEX", rs.sid, rs.maxlifetime, string(b)) + c := rs.p + c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) } // Provider redis session provider @@ -115,7 +114,7 @@ type Provider struct { poolsize int password string dbNum int - poollist *redis.Pool + poollist *redis.Client } // SessionInit init redis session @@ -157,45 +156,40 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { idleTimeout = time.Duration(timeout) * time.Second } } - rp.poollist = &redis.Pool{ - Dial: func() (redis.Conn, error) { - c, err := redis.Dial("tcp", rp.savePath) - if err != nil { - return nil, err - } - if rp.password != "" { - if _, err = c.Do("AUTH", rp.password); err != nil { - c.Close() - return nil, err - } - } - // some redis proxy such as twemproxy is not support select command - if rp.dbNum > 0 { - _, err = c.Do("SELECT", rp.dbNum) - if err != nil { - c.Close() - return nil, err - } - } - return c, err - }, - MaxIdle: rp.poolsize, + var idleCheckFrequency time.Duration = 0 + if len(configs) > 5 { + checkFrequency, err := strconv.Atoi(configs[5]) + if err == nil && checkFrequency > 0 { + idleCheckFrequency = time.Duration(checkFrequency) * time.Second + } + } + var maxRetries = 0 + if len(configs) > 6 { + retries, err := strconv.Atoi(configs[6]) + if err == nil && retries > 0 { + maxRetries = retries + } } - rp.poollist.IdleTimeout = idleTimeout + rp.poollist = redis.NewClient(&redis.Options{ + Addr: rp.savePath, + Password: rp.password, + PoolSize: rp.poolsize, + DB: rp.dbNum, + IdleTimeout: idleTimeout, + IdleCheckFrequency: idleCheckFrequency, + MaxRetries: maxRetries, + }) - return rp.poollist.Get().Err() + return rp.poollist.Ping().Err() } // SessionRead read redis session by sid func (rp *Provider) SessionRead(sid string) (session.Store, error) { - c := rp.poollist.Get() - defer c.Close() - var kv map[interface{}]interface{} - kvs, err := redis.String(c.Do("GET", sid)) - if err != nil && err != redis.ErrNil { + kvs, err := rp.poollist.Get(sid).Result() + if err != nil && err != redis.Nil { return nil, err } if len(kvs) == 0 { @@ -212,10 +206,9 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { // SessionExist check redis session exist by sid func (rp *Provider) SessionExist(sid string) (bool, error) { - c := rp.poollist.Get() - defer c.Close() + c := rp.poollist - if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 { + if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { return false, err } return true, nil @@ -223,27 +216,24 @@ func (rp *Provider) SessionExist(sid string) (bool, error) { // SessionRegenerate generate new sid for redis session func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - c := rp.poollist.Get() - defer c.Close() - - if existed, _ := redis.Int(c.Do("EXISTS", oldsid)); existed == 0 { + c := rp.poollist + if existed, _ := c.Exists(oldsid).Result(); existed == 0 { // oldsid doesn't exists, set the new sid directly // ignore error here, since if it return error // the existed value will be 0 - c.Do("SET", sid, "", "EX", rp.maxlifetime) + c.Do(c.Context(), "SET", sid, "", "EX", rp.maxlifetime) } else { - c.Do("RENAME", oldsid, sid) - c.Do("EXPIRE", sid, rp.maxlifetime) + c.Rename(oldsid, sid) + c.Expire(sid, time.Duration(rp.maxlifetime)) } return rp.SessionRead(sid) } // SessionDestroy delete redis session by id func (rp *Provider) SessionDestroy(sid string) error { - c := rp.poollist.Get() - defer c.Close() + c := rp.poollist - c.Do("DEL", sid) + c.Del(sid) return nil } @@ -259,3 +249,4 @@ func (rp *Provider) SessionAll() int { func init() { session.Register("redis", redispder) } + diff --git a/pkg/session/redis/sess_redis_test.go b/pkg/session/redis/sess_redis_test.go new file mode 100644 index 00000000..db5bb2c7 --- /dev/null +++ b/pkg/session/redis/sess_redis_test.go @@ -0,0 +1,88 @@ +package redis + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/astaxie/beego/pkg/session" +) + +func TestRedis(t *testing.T) { + sessionConfig := &session.ManagerConfig{ + CookieName: "gosessionid", + EnableSetCookie: true, + Gclifetime: 3600, + Maxlifetime: 3600, + Secure: false, + CookieLifeTime: 3600, + ProviderConfig: "127.0.0.1:6379,100,,0,30", + } + globalSession, err := session.NewManager("redis", sessionConfig) + if err != nil { + t.Fatal("could not create manager:", err) + } + + go globalSession.GC() + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + sess, err := globalSession.SessionStart(w, r) + if err != nil { + t.Fatal("session start failed:", err) + } + defer sess.SessionRelease(w) + + // SET AND GET + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set username failed:", err) + } + username := sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + + // DELETE + err = sess.Delete("username") + if err != nil { + t.Fatal("delete username failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("delete username failed") + } + + // FLUSH + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set failed:", err) + } + err = sess.Set("password", "1qaz2wsx") + if err != nil { + t.Fatal("set failed:", err) + } + username = sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + password := sess.Get("password") + if password != "1qaz2wsx" { + t.Fatal("get password failed") + } + err = sess.Flush() + if err != nil { + t.Fatal("flush failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("flush failed") + } + password = sess.Get("password") + if password != nil { + t.Fatal("flush failed") + } + + sess.SessionRelease(w) +} diff --git a/pkg/session/redis_cluster/redis_cluster.go b/pkg/session/redis_cluster/redis_cluster.go index 8b20ab19..75dc0e63 100644 --- a/pkg/session/redis_cluster/redis_cluster.go +++ b/pkg/session/redis_cluster/redis_cluster.go @@ -34,7 +34,7 @@ package redis_cluster import ( "github.com/astaxie/beego/pkg/session" - rediss "github.com/go-redis/redis" + rediss "github.com/go-redis/redis/v7" "net/http" "strconv" "strings" @@ -147,11 +147,35 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } else { rp.dbNum = 0 } + var idleTimeout time.Duration = 0 + if len(configs) > 4 { + timeout, err := strconv.Atoi(configs[4]) + if err == nil && timeout > 0 { + idleTimeout = time.Duration(timeout) * time.Second + } + } + var idleCheckFrequency time.Duration = 0 + if len(configs) > 5 { + checkFrequency, err := strconv.Atoi(configs[5]) + if err == nil && checkFrequency > 0 { + idleCheckFrequency = time.Duration(checkFrequency) * time.Second + } + } + var maxRetries = 0 + if len(configs) > 6 { + retries, err := strconv.Atoi(configs[6]) + if err == nil && retries > 0 { + maxRetries = retries + } + } rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ Addrs: strings.Split(rp.savePath, ";"), Password: rp.password, PoolSize: rp.poolsize, + IdleTimeout: idleTimeout, + IdleCheckFrequency: idleCheckFrequency, + MaxRetries: maxRetries, }) return rp.poollist.Ping().Err() } diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel.go b/pkg/session/redis_sentinel/sess_redis_sentinel.go index f78486be..da287b8d 100644 --- a/pkg/session/redis_sentinel/sess_redis_sentinel.go +++ b/pkg/session/redis_sentinel/sess_redis_sentinel.go @@ -34,7 +34,7 @@ package redis_sentinel import ( "github.com/astaxie/beego/pkg/session" - "github.com/go-redis/redis" + "github.com/go-redis/redis/v7" "net/http" "strconv" "strings" @@ -157,6 +157,27 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } else { rp.masterName = "mymaster" } + var idleTimeout time.Duration = 0 + if len(configs) > 5 { + timeout, err := strconv.Atoi(configs[4]) + if err == nil && timeout > 0 { + idleTimeout = time.Duration(timeout) * time.Second + } + } + var idleCheckFrequency time.Duration = 0 + if len(configs) > 6 { + checkFrequency, err := strconv.Atoi(configs[5]) + if err == nil && checkFrequency > 0 { + idleCheckFrequency = time.Duration(checkFrequency) * time.Second + } + } + var maxRetries = 0 + if len(configs) > 7 { + retries, err := strconv.Atoi(configs[6]) + if err == nil && retries > 0 { + maxRetries = retries + } + } rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ SentinelAddrs: strings.Split(rp.savePath, ";"), @@ -164,6 +185,9 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { PoolSize: rp.poolsize, DB: rp.dbNum, MasterName: rp.masterName, + IdleTimeout: idleTimeout, + IdleCheckFrequency: idleCheckFrequency, + MaxRetries: maxRetries, }) return rp.poollist.Ping().Err() From 1b4bb43df02eb5ab3fe0280891f4467bb9c9aa9a Mon Sep 17 00:00:00 2001 From: IamCathal Date: Thu, 6 Aug 2020 16:07:18 +0100 Subject: [PATCH 074/207] More minor grammar fixes --- pkg/cache/cache.go | 22 +++++------ pkg/cache/file.go | 2 +- pkg/cache/memcache/memcache.go | 22 +++++------ pkg/cache/redis/redis.go | 31 ++++++++------- pkg/cache/ssdb/ssdb.go | 27 ++++++------- pkg/config/ini/ini.go | 2 +- pkg/config/json/json.go | 2 +- pkg/config/xml/xml.go | 2 +- pkg/config/yaml/yaml.go | 2 +- pkg/context/acceptencoder.go | 30 +++++++-------- pkg/context/context.go | 47 +++++++++++------------ pkg/context/input.go | 40 ++++++++++---------- pkg/context/output.go | 50 ++++++++++++------------ pkg/context/param/methodparams.go | 4 +- pkg/context/renderer.go | 2 +- pkg/context/response.go | 8 ++-- pkg/grace/server.go | 4 +- pkg/httplib/httplib.go | 63 ++++++++++++++++--------------- pkg/logs/accesslog.go | 2 +- pkg/logs/alils/alils.go | 14 +++---- pkg/logs/alils/config.go | 4 +- pkg/logs/alils/log.pb.go | 53 ++++++++++++++------------ pkg/logs/alils/log_config.go | 6 +-- pkg/logs/alils/log_project.go | 2 +- pkg/logs/alils/log_store.go | 6 +-- pkg/logs/alils/machine_group.go | 10 ++--- pkg/logs/conn.go | 12 +++--- pkg/logs/console.go | 10 ++--- pkg/logs/es/es.go | 4 +- pkg/logs/file.go | 10 ++--- pkg/logs/jianliao.go | 6 +-- pkg/logs/log.go | 16 ++++---- pkg/logs/slack.go | 4 +- pkg/logs/smtp.go | 6 +-- 34 files changed, 263 insertions(+), 262 deletions(-) diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 82585c4e..049fb758 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -47,23 +47,23 @@ import ( // c.Incr("counter") // now is 2 // count := c.Get("counter").(int) type Cache interface { - // get cached value by key. + // Get a cached value by key. Get(key string) interface{} // GetMulti is a batch version of Get. GetMulti(keys []string) []interface{} - // set cached value with key and expire time. + // Set a cached value with key and expire time. Put(key string, val interface{}, timeout time.Duration) error - // delete cached value by key. + // Delete cached value by key. Delete(key string) error - // increase cached int value by key, as a counter. + // Increment a cached int value by key, as a counter. Incr(key string) error - // decrease cached int value by key, as a counter. + // Decrement a cached int value by key, as a counter. Decr(key string) error - // check if cached value exists or not. + // Check if a cached value exists or not. IsExist(key string) bool - // clear all cache. + // Clear all cache. ClearAll() error - // start gc routine based on config string settings. + // Start gc routine based on config string settings. StartAndGC(config string) error } @@ -85,9 +85,9 @@ func Register(name string, adapter Instance) { adapters[name] = adapter } -// NewCache Create a new cache driver by adapter name and config string. -// config need to be correct JSON as string: {"interval":360}. -// it will start gc automatically. +// NewCache creates a new cache driver by adapter name and config string. +// config: must be in JSON format such as {"interval":360}. +// Starts gc automatically. func NewCache(adapterName, config string) (adapter Cache, err error) { instanceFunc, ok := adapters[adapterName] if !ok { diff --git a/pkg/cache/file.go b/pkg/cache/file.go index dcc60bc0..0e5c44be 100644 --- a/pkg/cache/file.go +++ b/pkg/cache/file.go @@ -54,7 +54,7 @@ type FileCache struct { EmbedExpiry int } -// NewFileCache cerates a new file cache with no config. +// NewFileCache creates a new file cache with no config. // The level and expiry need to be set in the method StartAndGC as config string. func NewFileCache() Cache { // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} diff --git a/pkg/cache/memcache/memcache.go b/pkg/cache/memcache/memcache.go index b08596eb..94fc61dc 100644 --- a/pkg/cache/memcache/memcache.go +++ b/pkg/cache/memcache/memcache.go @@ -46,7 +46,7 @@ type Cache struct { conninfo []string } -// NewMemCache create new memcache adapter. +// NewMemCache creates a new memcache adapter. func NewMemCache() cache.Cache { return &Cache{} } @@ -64,7 +64,7 @@ func (rc *Cache) Get(key string) interface{} { return nil } -// GetMulti get value from memcache. +// GetMulti gets a value from a key in memcache. func (rc *Cache) GetMulti(keys []string) []interface{} { size := len(keys) var rv []interface{} @@ -89,7 +89,7 @@ func (rc *Cache) GetMulti(keys []string) []interface{} { return rv } -// Put put value to memcache. +// Put puts a value into memcache. func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { @@ -107,7 +107,7 @@ func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { return rc.conn.Set(&item) } -// Delete delete value in memcache. +// Delete deletes a value in memcache. func (rc *Cache) Delete(key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { @@ -117,7 +117,7 @@ func (rc *Cache) Delete(key string) error { return rc.conn.Delete(key) } -// Incr increase counter. +// Incr increases counter. func (rc *Cache) Incr(key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { @@ -128,7 +128,7 @@ func (rc *Cache) Incr(key string) error { return err } -// Decr decrease counter. +// Decr decreases counter. func (rc *Cache) Decr(key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { @@ -139,7 +139,7 @@ func (rc *Cache) Decr(key string) error { return err } -// IsExist check value exists in memcache. +// IsExist checks if a value exists in memcache. func (rc *Cache) IsExist(key string) bool { if rc.conn == nil { if err := rc.connectInit(); err != nil { @@ -150,7 +150,7 @@ func (rc *Cache) IsExist(key string) bool { return err == nil } -// ClearAll clear all cached in memcache. +// ClearAll clears all cache in memcache. func (rc *Cache) ClearAll() error { if rc.conn == nil { if err := rc.connectInit(); err != nil { @@ -160,9 +160,9 @@ func (rc *Cache) ClearAll() error { return rc.conn.FlushAll() } -// StartAndGC start memcache adapter. -// config string is like {"conn":"connection info"}. -// if connecting error, return. +// StartAndGC starts the memcache adapter. +// config: must be in the format {"conn":"connection info"}. +// If an error occurs during connecting, an error is returned func (rc *Cache) StartAndGC(config string) error { var cf map[string]string json.Unmarshal([]byte(config), &cf) diff --git a/pkg/cache/redis/redis.go b/pkg/cache/redis/redis.go index a5fec591..68a934bf 100644 --- a/pkg/cache/redis/redis.go +++ b/pkg/cache/redis/redis.go @@ -43,7 +43,7 @@ import ( ) var ( - // DefaultKey the collection name of redis for cache adapter. + // The collection name of redis for the cache adapter. DefaultKey = "beecacheRedis" ) @@ -56,16 +56,16 @@ type Cache struct { password string maxIdle int - // the timeout to a value less than the redis server's timeout. + // Timeout value (less than the redis server's timeout value) timeout time.Duration } -// NewRedisCache create new redis cache with default collection name. +// NewRedisCache creates a new redis cache with default collection name. func NewRedisCache() cache.Cache { return &Cache{key: DefaultKey} } -// actually do the redis cmds, args[0] must be the key name. +// Execute the redis commands. args[0] must be the key name func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) { if len(args) < 1 { return nil, errors.New("missing required arguments") @@ -90,7 +90,7 @@ func (rc *Cache) Get(key string) interface{} { return nil } -// GetMulti get cache from redis. +// GetMulti gets cache from redis. func (rc *Cache) GetMulti(keys []string) []interface{} { c := rc.p.Get() defer c.Close() @@ -105,19 +105,19 @@ func (rc *Cache) GetMulti(keys []string) []interface{} { return values } -// Put put cache to redis. +// Put puts cache into redis. func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { _, err := rc.do("SETEX", key, int64(timeout/time.Second), val) return err } -// Delete delete cache in redis. +// Delete deletes a key's cache in redis. func (rc *Cache) Delete(key string) error { _, err := rc.do("DEL", key) return err } -// IsExist check cache's existence in redis. +// IsExist checks cache's existence in redis. func (rc *Cache) IsExist(key string) bool { v, err := redis.Bool(rc.do("EXISTS", key)) if err != nil { @@ -126,19 +126,19 @@ func (rc *Cache) IsExist(key string) bool { return v } -// Incr increase counter in redis. +// Incr increases a key's counter in redis. func (rc *Cache) Incr(key string) error { _, err := redis.Bool(rc.do("INCRBY", key, 1)) return err } -// Decr decrease counter in redis. +// Decr decreases a key's counter in redis. func (rc *Cache) Decr(key string) error { _, err := redis.Bool(rc.do("INCRBY", key, -1)) return err } -// ClearAll clean all cache in redis. delete this redis collection. +// ClearAll deletes all cache in the redis collection func (rc *Cache) ClearAll() error { cachedKeys, err := rc.Scan(rc.key + ":*") if err != nil { @@ -154,7 +154,7 @@ func (rc *Cache) ClearAll() error { return err } -// Scan scan all keys matching the pattern. a better choice than `keys` +// Scan scans all keys matching a given pattern. func (rc *Cache) Scan(pattern string) (keys []string, err error) { c := rc.p.Get() defer c.Close() @@ -183,10 +183,9 @@ func (rc *Cache) Scan(pattern string) (keys []string, err error) { } } -// StartAndGC start redis cache adapter. -// config is like {"key":"collection key","conn":"connection info","dbNum":"0"} -// the cache item in redis are stored forever, -// so no gc operation. +// StartAndGC starts the redis cache adapter. +// config: must be in this format {"key":"collection key","conn":"connection info","dbNum":"0"} +// Cached items in redis are stored forever, no garbage collection happens func (rc *Cache) StartAndGC(config string) error { var cf map[string]string json.Unmarshal([]byte(config), &cf) diff --git a/pkg/cache/ssdb/ssdb.go b/pkg/cache/ssdb/ssdb.go index 62a63c60..038b2ebe 100644 --- a/pkg/cache/ssdb/ssdb.go +++ b/pkg/cache/ssdb/ssdb.go @@ -18,12 +18,12 @@ type Cache struct { conninfo []string } -//NewSsdbCache create new ssdb adapter. +//NewSsdbCache creates new ssdb adapter. func NewSsdbCache() cache.Cache { return &Cache{} } -// Get get value from memcache. +// Get gets a key's value from memcache. func (rc *Cache) Get(key string) interface{} { if rc.conn == nil { if err := rc.connectInit(); err != nil { @@ -37,7 +37,7 @@ func (rc *Cache) Get(key string) interface{} { return nil } -// GetMulti get value from memcache. +// GetMulti gets one or keys values from memcache. func (rc *Cache) GetMulti(keys []string) []interface{} { size := len(keys) var values []interface{} @@ -63,7 +63,7 @@ func (rc *Cache) GetMulti(keys []string) []interface{} { return values } -// DelMulti get value from memcache. +// DelMulti deletes one or more keys from memcache func (rc *Cache) DelMulti(keys []string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { @@ -74,7 +74,8 @@ func (rc *Cache) DelMulti(keys []string) error { return err } -// Put put value to memcache. only support string. +// Put puts value into memcache. +// value: must be of type string func (rc *Cache) Put(key string, value interface{}, timeout time.Duration) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { @@ -102,7 +103,7 @@ func (rc *Cache) Put(key string, value interface{}, timeout time.Duration) error return errors.New("bad response") } -// Delete delete value in memcache. +// Delete deletes a value in memcache. func (rc *Cache) Delete(key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { @@ -113,7 +114,7 @@ func (rc *Cache) Delete(key string) error { return err } -// Incr increase counter. +// Incr increases a key's counter. func (rc *Cache) Incr(key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { @@ -124,7 +125,7 @@ func (rc *Cache) Incr(key string) error { return err } -// Decr decrease counter. +// Decr decrements a key's counter. func (rc *Cache) Decr(key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { @@ -135,7 +136,7 @@ func (rc *Cache) Decr(key string) error { return err } -// IsExist check value exists in memcache. +// IsExist checks if a key exists in memcache. func (rc *Cache) IsExist(key string) bool { if rc.conn == nil { if err := rc.connectInit(); err != nil { @@ -153,7 +154,7 @@ func (rc *Cache) IsExist(key string) bool { } -// ClearAll clear all cached in memcache. +// ClearAll clears all cached items in memcache. func (rc *Cache) ClearAll() error { if rc.conn == nil { if err := rc.connectInit(); err != nil { @@ -195,9 +196,9 @@ func (rc *Cache) Scan(keyStart string, keyEnd string, limit int) ([]string, erro return resp, nil } -// StartAndGC start memcache adapter. -// config string is like {"conn":"connection info"}. -// if connecting error, return. +// StartAndGC starts the memcache adapter. +// config: must be in the format {"conn":"connection info"}. +// If an error occurs during connection, an error is returned func (rc *Cache) StartAndGC(config string) error { var cf map[string]string json.Unmarshal([]byte(config), &cf) diff --git a/pkg/config/ini/ini.go b/pkg/config/ini/ini.go index a3c6462d..17408d85 100644 --- a/pkg/config/ini/ini.go +++ b/pkg/config/ini/ini.go @@ -224,7 +224,7 @@ func (ini *IniConfig) ParseData(data []byte) (config.Configer, error) { return ini.parseData(dir, data) } -// IniConfigContainer A Config represents the ini configuration. +// IniConfigContainer is a config which represents the ini configuration. // When set and get value, support key as section:name type. type IniConfigContainer struct { data map[string]map[string]string // section=> key:val diff --git a/pkg/config/json/json.go b/pkg/config/json/json.go index 49bd38ff..ede3cce5 100644 --- a/pkg/config/json/json.go +++ b/pkg/config/json/json.go @@ -66,7 +66,7 @@ func (js *JSONConfig) ParseData(data []byte) (config.Configer, error) { return x, nil } -// JSONConfigContainer A Config represents the json configuration. +// JSONConfigContainer is a config which represents the json configuration. // Only when get value, support key as section:name type. type JSONConfigContainer struct { data map[string]interface{} diff --git a/pkg/config/xml/xml.go b/pkg/config/xml/xml.go index b1cce5c8..d8c018e6 100644 --- a/pkg/config/xml/xml.go +++ b/pkg/config/xml/xml.go @@ -72,7 +72,7 @@ func (xc *Config) ParseData(data []byte) (config.Configer, error) { return x, nil } -// ConfigContainer A Config represents the xml configuration. +// ConfigContainer is a Config which represents the xml configuration. type ConfigContainer struct { data map[string]interface{} sync.Mutex diff --git a/pkg/config/yaml/yaml.go b/pkg/config/yaml/yaml.go index 3dcb45fd..63a30208 100644 --- a/pkg/config/yaml/yaml.go +++ b/pkg/config/yaml/yaml.go @@ -116,7 +116,7 @@ func parseYML(buf []byte) (cnf map[string]interface{}, err error) { return } -// ConfigContainer A Config represents the yaml configuration. +// ConfigContainer is a config which represents the yaml configuration. type ConfigContainer struct { data map[string]interface{} sync.RWMutex diff --git a/pkg/context/acceptencoder.go b/pkg/context/acceptencoder.go index b4e2492c..8ed6a853 100644 --- a/pkg/context/acceptencoder.go +++ b/pkg/context/acceptencoder.go @@ -28,18 +28,18 @@ import ( ) var ( - //Default size==20B same as nginx + // Default size==20B same as nginx defaultGzipMinLength = 20 - //Content will only be compressed if content length is either unknown or greater than gzipMinLength. + // Content will only be compressed if content length is either unknown or greater than gzipMinLength. gzipMinLength = defaultGzipMinLength - //The compression level used for deflate compression. (0-9). + // Compression level used for deflate compression. (0-9). gzipCompressLevel int - //List of HTTP methods to compress. If not set, only GET requests are compressed. + // List of HTTP methods to compress. If not set, only GET requests are compressed. includedMethods map[string]bool getMethodOnly bool ) -// InitGzip init the gzipcompress +// InitGzip initializes the gzipcompress func InitGzip(minLength, compressLevel int, methods []string) { if minLength >= 0 { gzipMinLength = minLength @@ -98,9 +98,9 @@ func (ac acceptEncoder) put(wr resetWriter, level int) { } wr.Reset(nil) - //notice - //compressionLevel==BestCompression DOES NOT MATTER - //sync.Pool will not memory leak + // notice + // compressionLevel==BestCompression DOES NOT MATTER + // sync.Pool will not memory leak switch level { case gzipCompressLevel: @@ -119,10 +119,10 @@ var ( bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }}, } - //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed - //deflate - //The "zlib" format defined in RFC 1950 [31] in combination with - //the "deflate" compression mechanism described in RFC 1951 [29]. + // According to: http://tools.ietf.org/html/rfc2616#section-3.5 the deflate compress in http is zlib indeed + // deflate + // The "zlib" format defined in RFC 1950 [31] in combination with + // the "deflate" compression mechanism described in RFC 1951 [29]. deflateCompressEncoder = acceptEncoder{ name: "deflate", levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, @@ -145,7 +145,7 @@ func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, return writeLevel(encoding, writer, file, flate.BestCompression) } -// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) +// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { if encoding == "" || len(content) < gzipMinLength { _, err := writer.Write(content) @@ -154,8 +154,8 @@ func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel) } -// writeLevel reads from reader,writes to writer by specific encoding and compress level -// the compress level is defined by deflate package +// writeLevel reads from reader and writes to writer by specific encoding and compress level. +// The compress level is defined by deflate package func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) { var outputWriter resetWriter var err error diff --git a/pkg/context/context.go b/pkg/context/context.go index 9f974551..f7b325a9 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -38,7 +38,7 @@ import ( "github.com/astaxie/beego/pkg/utils" ) -//commonly used mime-types +// Commonly used mime-types const ( ApplicationJSON = "application/json" ApplicationXML = "application/xml" @@ -55,7 +55,7 @@ func NewContext() *Context { } // Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. -// BeegoInput and BeegoOutput provides some api to operate request and response more easily. +// BeegoInput and BeegoOutput provides an api to operate request and response more easily. type Context struct { Input *BeegoInput Output *BeegoOutput @@ -64,7 +64,7 @@ type Context struct { _xsrfToken string } -// Reset init Context, BeegoInput and BeegoOutput +// Reset initializes Context, BeegoInput and BeegoOutput func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { ctx.Request = r if ctx.ResponseWriter == nil { @@ -76,37 +76,36 @@ func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { ctx._xsrfToken = "" } -// Redirect does redirection to localurl with http header status code. +// Redirect redirects to localurl with http header status code. func (ctx *Context) Redirect(status int, localurl string) { http.Redirect(ctx.ResponseWriter, ctx.Request, localurl, status) } -// Abort stops this request. -// if beego.ErrorMaps exists, panic body. +// Abort stops the request. +// If beego.ErrorMaps exists, panic body. func (ctx *Context) Abort(status int, body string) { ctx.Output.SetStatus(status) panic(body) } -// WriteString Write string to response body. -// it sends response body. +// WriteString writes a string to response body. func (ctx *Context) WriteString(content string) { ctx.ResponseWriter.Write([]byte(content)) } -// GetCookie Get cookie from request by a given key. -// It's alias of BeegoInput.Cookie. +// GetCookie gets a cookie from a request for a given key. +// (Alias of BeegoInput.Cookie) func (ctx *Context) GetCookie(key string) string { return ctx.Input.Cookie(key) } -// SetCookie Set cookie for response. -// It's alias of BeegoOutput.Cookie. +// SetCookie sets a cookie for a response. +// (Alias of BeegoOutput.Cookie) func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { ctx.Output.Cookie(name, value, others...) } -// GetSecureCookie Get secure cookie from request by a given key. +// GetSecureCookie gets a secure cookie from a request for a given key. func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { val := ctx.Input.Cookie(key) if val == "" { @@ -133,7 +132,7 @@ func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { return string(res), true } -// SetSecureCookie Set Secure cookie for response. +// SetSecureCookie sets a secure cookie for a response. func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { vs := base64.URLEncoding.EncodeToString([]byte(value)) timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) @@ -144,7 +143,7 @@ func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interf ctx.Output.Cookie(name, cookie, others...) } -// XSRFToken creates a xsrf token string and returns. +// XSRFToken creates and returns an xsrf token string func (ctx *Context) XSRFToken(key string, expire int64) string { if ctx._xsrfToken == "" { token, ok := ctx.GetSecureCookie(key, "_xsrf") @@ -157,8 +156,8 @@ func (ctx *Context) XSRFToken(key string, expire int64) string { return ctx._xsrfToken } -// CheckXSRFCookie checks xsrf token in this request is valid or not. -// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// CheckXSRFCookie checks if the XSRF token in this request is valid or not. +// The token can be provided in the request header in the form "X-Xsrftoken" or "X-CsrfToken" // or in form field value named as "_xsrf". func (ctx *Context) CheckXSRFCookie() bool { token := ctx.Input.Query("_xsrf") @@ -195,8 +194,8 @@ func (ctx *Context) RenderMethodResult(result interface{}) { } } -//Response is a wrapper for the http.ResponseWriter -//started set to true if response was written to then don't execute other handler +// Response is a wrapper for the http.ResponseWriter +// Started: if true, response was already written to so the other handler will not be executed type Response struct { http.ResponseWriter Started bool @@ -210,16 +209,16 @@ func (r *Response) reset(rw http.ResponseWriter) { r.Started = false } -// Write writes the data to the connection as part of an HTTP reply, -// and sets `started` to true. -// started means the response has sent out. +// Write writes the data to the connection as part of a HTTP reply, +// and sets `Started` to true. +// Started: if true, the response was already sent func (r *Response) Write(p []byte) (int, error) { r.Started = true return r.ResponseWriter.Write(p) } -// WriteHeader sends an HTTP response header with status code, -// and sets `started` to true. +// WriteHeader sends a HTTP response header with status code, +// and sets `Started` to true. func (r *Response) WriteHeader(code int) { if r.Status > 0 { //prevent multiple response.WriteHeader calls diff --git a/pkg/context/input.go b/pkg/context/input.go index 04347e04..5ff85f43 100644 --- a/pkg/context/input.go +++ b/pkg/context/input.go @@ -43,7 +43,7 @@ var ( ) // BeegoInput operates the http request header, data, cookie and body. -// it also contains router params and current session. +// Contains router params and current session. type BeegoInput struct { Context *Context CruSession session.Store @@ -56,7 +56,7 @@ type BeegoInput struct { RunController reflect.Type } -// NewInput return BeegoInput generated by Context. +// NewInput returns the BeegoInput generated by context. func NewInput() *BeegoInput { return &BeegoInput{ pnames: make([]string, 0, maxParam), @@ -65,7 +65,7 @@ func NewInput() *BeegoInput { } } -// Reset init the BeegoInput +// Reset initializes the BeegoInput func (input *BeegoInput) Reset(ctx *Context) { input.Context = ctx input.CruSession = nil @@ -77,27 +77,27 @@ func (input *BeegoInput) Reset(ctx *Context) { input.RequestBody = []byte{} } -// Protocol returns request protocol name, such as HTTP/1.1 . +// Protocol returns the request protocol name, such as HTTP/1.1 . func (input *BeegoInput) Protocol() string { return input.Context.Request.Proto } -// URI returns full request url with query string, fragment. +// URI returns the full request url with query, string and fragment. func (input *BeegoInput) URI() string { return input.Context.Request.RequestURI } -// URL returns request url path (without query string, fragment). +// URL returns the request url path (without query, string and fragment). func (input *BeegoInput) URL() string { return input.Context.Request.URL.EscapedPath() } -// Site returns base site url as scheme://domain type. +// Site returns the base site url as scheme://domain type. func (input *BeegoInput) Site() string { return input.Scheme() + "://" + input.Domain() } -// Scheme returns request scheme as "http" or "https". +// Scheme returns the request scheme as "http" or "https". func (input *BeegoInput) Scheme() string { if scheme := input.Header("X-Forwarded-Proto"); scheme != "" { return scheme @@ -111,14 +111,13 @@ func (input *BeegoInput) Scheme() string { return "https" } -// Domain returns host name. -// Alias of Host method. +// Domain returns the host name (alias of host method) func (input *BeegoInput) Domain() string { return input.Host() } -// Host returns host name. -// if no host info in request, return localhost. +// Host returns the host name. +// If no host info in request, return localhost. func (input *BeegoInput) Host() string { if input.Context.Request.Host != "" { if hostPart, _, err := net.SplitHostPort(input.Context.Request.Host); err == nil { @@ -134,7 +133,7 @@ func (input *BeegoInput) Method() string { return input.Context.Request.Method } -// Is returns boolean of this request is on given method, such as Is("POST"). +// Is returns the boolean value of this request is on given method, such as Is("POST"). func (input *BeegoInput) Is(method string) bool { return input.Method() == method } @@ -174,7 +173,7 @@ func (input *BeegoInput) IsPatch() bool { return input.Is("PATCH") } -// IsAjax returns boolean of this request is generated by ajax. +// IsAjax returns boolean of is this request generated by ajax. func (input *BeegoInput) IsAjax() bool { return input.Header("X-Requested-With") == "XMLHttpRequest" } @@ -251,7 +250,7 @@ func (input *BeegoInput) Refer() string { } // SubDomains returns sub domain string. -// if aa.bb.domain.com, returns aa.bb . +// if aa.bb.domain.com, returns aa.bb func (input *BeegoInput) SubDomains() string { parts := strings.Split(input.Host(), ".") if len(parts) >= 3 { @@ -306,7 +305,7 @@ func (input *BeegoInput) Params() map[string]string { return m } -// SetParam will set the param with key and value +// SetParam sets the param with key and value func (input *BeegoInput) SetParam(key, val string) { // check if already exists for i, v := range input.pnames { @@ -319,9 +318,8 @@ func (input *BeegoInput) SetParam(key, val string) { input.pnames = append(input.pnames, key) } -// ResetParams clears any of the input's Params -// This function is used to clear parameters so they may be reset between filter -// passes. +// ResetParams clears any of the input's params +// Used to clear parameters so they may be reset between filter passes. func (input *BeegoInput) ResetParams() { input.pnames = input.pnames[:0] input.pvalues = input.pvalues[:0] @@ -391,7 +389,7 @@ func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { return requestbody } -// Data return the implicit data in the input +// Data returns the implicit data in the input func (input *BeegoInput) Data() map[interface{}]interface{} { input.dataLock.Lock() defer input.dataLock.Unlock() @@ -412,7 +410,7 @@ func (input *BeegoInput) GetData(key interface{}) interface{} { } // SetData stores data with given key in this context. -// This data are only available in this context. +// This data is only available in this context. func (input *BeegoInput) SetData(key, val interface{}) { input.dataLock.Lock() defer input.dataLock.Unlock() diff --git a/pkg/context/output.go b/pkg/context/output.go index 238dcf45..0a530244 100644 --- a/pkg/context/output.go +++ b/pkg/context/output.go @@ -42,12 +42,12 @@ type BeegoOutput struct { } // NewOutput returns new BeegoOutput. -// it contains nothing now. +// Empty when initialized func NewOutput() *BeegoOutput { return &BeegoOutput{} } -// Reset init BeegoOutput +// Reset initializes BeegoOutput func (output *BeegoOutput) Reset(ctx *Context) { output.Context = ctx output.Status = 0 @@ -58,9 +58,9 @@ func (output *BeegoOutput) Header(key, val string) { output.Context.ResponseWriter.Header().Set(key, val) } -// Body sets response body content. -// if EnableGzip, compress content string. -// it sends out response body directly. +// Body sets the response body content. +// if EnableGzip, content is compressed. +// Sends out response body directly. func (output *BeegoOutput) Body(content []byte) error { var encoding string var buf = &bytes.Buffer{} @@ -85,13 +85,13 @@ func (output *BeegoOutput) Body(content []byte) error { return nil } -// Cookie sets cookie value via given key. -// others are ordered as cookie's max age time, path,domain, secure and httponly. +// Cookie sets a cookie value via given key. +// others: used to set a cookie's max age time, path,domain, secure and httponly. func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) { var b bytes.Buffer fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value)) - //fix cookie not work in IE + // fix cookie not work in IE if len(others) > 0 { var maxAge int64 @@ -183,7 +183,7 @@ func errorRenderer(err error) Renderer { }) } -// JSON writes json to response body. +// JSON writes json to the response body. // if encoding is true, it converts utf-8 to \u0000 type. func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error { output.Header("Content-Type", "application/json; charset=utf-8") @@ -204,7 +204,7 @@ func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) return output.Body(content) } -// YAML writes yaml to response body. +// YAML writes yaml to the response body. func (output *BeegoOutput) YAML(data interface{}) error { output.Header("Content-Type", "application/x-yaml; charset=utf-8") var content []byte @@ -217,7 +217,7 @@ func (output *BeegoOutput) YAML(data interface{}) error { return output.Body(content) } -// JSONP writes jsonp to response body. +// JSONP writes jsonp to the response body. func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { output.Header("Content-Type", "application/javascript; charset=utf-8") var content []byte @@ -243,7 +243,7 @@ func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { return output.Body(callbackContent.Bytes()) } -// XML writes xml string to response body. +// XML writes xml string to the response body. func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { output.Header("Content-Type", "application/xml; charset=utf-8") var content []byte @@ -260,7 +260,7 @@ func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { return output.Body(content) } -// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +// ServeFormatted serves YAML, XML or JSON, depending on the value of the Accept header func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) { accept := output.Context.Input.Header("Accept") switch accept { @@ -274,7 +274,7 @@ func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasE } // Download forces response for download file. -// it prepares the download response header automatically. +// Prepares the download response header automatically. func (output *BeegoOutput) Download(file string, filename ...string) { // check get file error, file not found or other error. if _, err := os.Stat(file); err != nil { @@ -323,61 +323,61 @@ func (output *BeegoOutput) ContentType(ext string) { } } -// SetStatus sets response status code. -// It writes response header directly. +// SetStatus sets the response status code. +// Writes response header directly. func (output *BeegoOutput) SetStatus(status int) { output.Status = status } -// IsCachable returns boolean of this request is cached. +// IsCachable returns boolean of if this request is cached. // HTTP 304 means cached. func (output *BeegoOutput) IsCachable() bool { return output.Status >= 200 && output.Status < 300 || output.Status == 304 } -// IsEmpty returns boolean of this request is empty. +// IsEmpty returns boolean of if this request is empty. // HTTP 201,204 and 304 means empty. func (output *BeegoOutput) IsEmpty() bool { return output.Status == 201 || output.Status == 204 || output.Status == 304 } -// IsOk returns boolean of this request runs well. +// IsOk returns boolean of if this request was ok. // HTTP 200 means ok. func (output *BeegoOutput) IsOk() bool { return output.Status == 200 } -// IsSuccessful returns boolean of this request runs successfully. +// IsSuccessful returns boolean of this request was successful. // HTTP 2xx means ok. func (output *BeegoOutput) IsSuccessful() bool { return output.Status >= 200 && output.Status < 300 } -// IsRedirect returns boolean of this request is redirection header. +// IsRedirect returns boolean of if this request is redirected. // HTTP 301,302,307 means redirection. func (output *BeegoOutput) IsRedirect() bool { return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307 } -// IsForbidden returns boolean of this request is forbidden. +// IsForbidden returns boolean of if this request is forbidden. // HTTP 403 means forbidden. func (output *BeegoOutput) IsForbidden() bool { return output.Status == 403 } -// IsNotFound returns boolean of this request is not found. +// IsNotFound returns boolean of if this request is not found. // HTTP 404 means not found. func (output *BeegoOutput) IsNotFound() bool { return output.Status == 404 } -// IsClientError returns boolean of this request client sends error data. +// IsClientError returns boolean of if this request client sends error data. // HTTP 4xx means client error. func (output *BeegoOutput) IsClientError() bool { return output.Status >= 400 && output.Status < 500 } -// IsServerError returns boolean of this server handler errors. +// IsServerError returns boolean of if this server handler errors. // HTTP 5xx means server internal error. func (output *BeegoOutput) IsServerError() bool { return output.Status >= 500 && output.Status < 600 diff --git a/pkg/context/param/methodparams.go b/pkg/context/param/methodparams.go index cd6708a2..b5ccbdd0 100644 --- a/pkg/context/param/methodparams.go +++ b/pkg/context/param/methodparams.go @@ -22,7 +22,7 @@ const ( header ) -//New creates a new MethodParam with name and specific options +// New creates a new MethodParam with name and specific options func New(name string, opts ...MethodParamOption) *MethodParam { return newParam(name, nil, opts) } @@ -35,7 +35,7 @@ func newParam(name string, parser paramParser, opts []MethodParamOption) (param return } -//Make creates an array of MethodParmas or an empty array +// Make creates an array of MethodParmas or an empty array func Make(list ...*MethodParam) []*MethodParam { if len(list) > 0 { return list diff --git a/pkg/context/renderer.go b/pkg/context/renderer.go index 36a7cb53..5a078332 100644 --- a/pkg/context/renderer.go +++ b/pkg/context/renderer.go @@ -1,6 +1,6 @@ package context -// Renderer defines an http response renderer +// Renderer defines a http response renderer type Renderer interface { Render(ctx *Context) } diff --git a/pkg/context/response.go b/pkg/context/response.go index 9c3c715a..d80cfe89 100644 --- a/pkg/context/response.go +++ b/pkg/context/response.go @@ -7,21 +7,21 @@ import ( ) const ( - //BadRequest indicates http error 400 + //BadRequest indicates HTTP error 400 BadRequest StatusCode = http.StatusBadRequest - //NotFound indicates http error 404 + //NotFound indicates HTTP error 404 NotFound StatusCode = http.StatusNotFound ) -// StatusCode sets the http response status code +// StatusCode sets the HTTP response status code type StatusCode int func (s StatusCode) Error() string { return strconv.Itoa(int(s)) } -// Render sets the http status code +// Render sets the HTTP status code func (s StatusCode) Render(ctx *Context) { ctx.Output.SetStatus(int(s)) } diff --git a/pkg/grace/server.go b/pkg/grace/server.go index 008a6171..13fa6e34 100644 --- a/pkg/grace/server.go +++ b/pkg/grace/server.go @@ -29,8 +29,8 @@ type Server struct { terminalChan chan error } -// Serve accepts incoming connections on the Listener l, -// creating a new service goroutine for each. +// Serve accepts incoming connections on the Listener l +// and creates a new service goroutine for each. // The service goroutines read requests and then call srv.Handler to reply to them. func (srv *Server) Serve() (err error) { srv.state = StateRunning diff --git a/pkg/httplib/httplib.go b/pkg/httplib/httplib.go index 60aa4e8b..1438a881 100644 --- a/pkg/httplib/httplib.go +++ b/pkg/httplib/httplib.go @@ -73,14 +73,14 @@ func createDefaultCookie() { defaultCookieJar, _ = cookiejar.New(nil) } -// SetDefaultSetting Overwrite default settings +// SetDefaultSetting overwrites default settings func SetDefaultSetting(setting BeegoHTTPSettings) { settingMutex.Lock() defer settingMutex.Unlock() defaultSetting = setting } -// NewBeegoRequest return *BeegoHttpRequest with specific method +// NewBeegoRequest returns *BeegoHttpRequest with specific method func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { var resp http.Response u, err := url.Parse(rawurl) @@ -147,7 +147,7 @@ type BeegoHTTPSettings struct { RetryDelay time.Duration } -// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. +// BeegoHTTPRequest provides more useful methods than http.Request for requesting a url. type BeegoHTTPRequest struct { url string req *http.Request @@ -159,12 +159,12 @@ type BeegoHTTPRequest struct { dump []byte } -// GetRequest return the request object +// GetRequest returns the request object func (b *BeegoHTTPRequest) GetRequest() *http.Request { return b.req } -// Setting Change request settings +// Setting changes request settings func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest { b.setting = setting return b @@ -195,26 +195,27 @@ func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { } // Retries sets Retries times. -// default is 0 means no retried. -// -1 means retried forever. -// others means retried times. +// default is 0 (never retries) +// -1 retry indefinitely (forever) +// Other numbers specify the exact retry amount func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { b.setting.Retries = times return b } +// RetryDelay sets the time to sleep between reconnection attempts func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { b.setting.RetryDelay = delay return b } -// DumpBody setting whether need to Dump the Body. +// DumpBody sets the DumbBody field func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { b.setting.DumpBody = isdump return b } -// DumpRequest return the DumpRequest +// DumpRequest returns the DumpRequest func (b *BeegoHTTPRequest) DumpRequest() []byte { return b.dump } @@ -226,13 +227,13 @@ func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Dura return b } -// SetTLSClientConfig sets tls connection configurations if visiting https url. +// SetTLSClientConfig sets TLS connection configuration if visiting HTTPS url. func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest { b.setting.TLSClientConfig = config return b } -// Header add header item string in request. +// Header adds header item string in request. func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest { b.req.Header.Set(key, value) return b @@ -244,7 +245,7 @@ func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { return b } -// SetProtocolVersion Set the protocol version for incoming requests. +// SetProtocolVersion sets the protocol version for incoming requests. // Client requests always use HTTP/1.1. func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { if len(vers) == 0 { @@ -261,19 +262,19 @@ func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { return b } -// SetCookie add cookie into request. +// SetCookie adds a cookie to the request. func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest { b.req.Header.Add("Cookie", cookie.String()) return b } -// SetTransport set the setting transport +// SetTransport sets the transport field func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest { b.setting.Transport = transport return b } -// SetProxy set the http proxy +// SetProxy sets the HTTP proxy // example: // // func(req *http.Request) (*url.URL, error) { @@ -305,14 +306,14 @@ func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { return b } -// PostFile add a post file to the request +// PostFile adds a post file to the request func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest { b.files[formname] = filename return b } // Body adds request raw body. -// it supports string and []byte. +// Supports string and []byte. func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { switch t := data.(type) { case string: @@ -327,7 +328,7 @@ func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { return b } -// XMLBody adds request raw body encoding by XML. +// XMLBody adds the request raw body encoded in XML. func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { byts, err := xml.Marshal(obj) @@ -341,7 +342,7 @@ func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { return b, nil } -// YAMLBody adds request raw body encoding by YAML. +// YAMLBody adds the request raw body encoded in YAML. func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { byts, err := yaml.Marshal(obj) @@ -355,7 +356,7 @@ func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) return b, nil } -// JSONBody adds request raw body encoding by JSON. +// JSONBody adds the request raw body encoded in JSON. func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { byts, err := json.Marshal(obj) @@ -437,7 +438,7 @@ func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { return resp, nil } -// DoRequest will do the client.Do +// DoRequest executes client.Do func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { var paramBody string if len(b.params) > 0 { @@ -530,7 +531,7 @@ func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { } // String returns the body string in response. -// it calls Response inner. +// Calls Response inner. func (b *BeegoHTTPRequest) String() (string, error) { data, err := b.Bytes() if err != nil { @@ -541,7 +542,7 @@ func (b *BeegoHTTPRequest) String() (string, error) { } // Bytes returns the body []byte in response. -// it calls Response inner. +// Calls Response inner. func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { if b.body != nil { return b.body, nil @@ -567,7 +568,7 @@ func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { } // ToFile saves the body data in response to one file. -// it calls Response inner. +// Calls Response inner. func (b *BeegoHTTPRequest) ToFile(filename string) error { resp, err := b.getResponse() if err != nil { @@ -590,7 +591,7 @@ func (b *BeegoHTTPRequest) ToFile(filename string) error { return err } -//Check that the file directory exists, there is no automatically created +// Check if the file directory exists. If it doesn't then it's created func pathExistAndMkdir(filename string) (err error) { filename = path.Dir(filename) _, err = os.Stat(filename) @@ -606,8 +607,8 @@ func pathExistAndMkdir(filename string) (err error) { return err } -// ToJSON returns the map that marshals from the body bytes as json in response . -// it calls Response inner. +// ToJSON returns the map that marshals from the body bytes as json in response. +// Calls Response inner. func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { data, err := b.Bytes() if err != nil { @@ -617,7 +618,7 @@ func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { } // ToXML returns the map that marshals from the body bytes as xml in response . -// it calls Response inner. +// Calls Response inner. func (b *BeegoHTTPRequest) ToXML(v interface{}) error { data, err := b.Bytes() if err != nil { @@ -627,7 +628,7 @@ func (b *BeegoHTTPRequest) ToXML(v interface{}) error { } // ToYAML returns the map that marshals from the body bytes as yaml in response . -// it calls Response inner. +// Calls Response inner. func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { data, err := b.Bytes() if err != nil { @@ -636,7 +637,7 @@ func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { return yaml.Unmarshal(data, v) } -// Response executes request client gets response mannually. +// Response executes request client gets response manually. func (b *BeegoHTTPRequest) Response() (*http.Response, error) { return b.getResponse() } diff --git a/pkg/logs/accesslog.go b/pkg/logs/accesslog.go index 9011b602..e380c54a 100644 --- a/pkg/logs/accesslog.go +++ b/pkg/logs/accesslog.go @@ -28,7 +28,7 @@ const ( jsonFormat = "JSON_FORMAT" ) -// AccessLogRecord struct for holding access log data. +// AccessLogRecord is astruct for holding access log data. type AccessLogRecord struct { RemoteAddr string `json:"remote_addr"` RequestTime time.Time `json:"request_time"` diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go index 8397b3da..fd1a4e28 100644 --- a/pkg/logs/alils/alils.go +++ b/pkg/logs/alils/alils.go @@ -11,9 +11,9 @@ import ( ) const ( - // CacheSize set the flush size + // CacheSize sets the flush size CacheSize int = 64 - // Delimiter define the topic delimiter + // Delimiter defines the topic delimiter Delimiter string = "##" ) @@ -31,7 +31,7 @@ type Config struct { } // aliLSWriter implements LoggerInterface. -// it writes messages in keep-live tcp connection. +// Writes messages in keep-live tcp connection. type aliLSWriter struct { store *LogStore group []*LogGroup @@ -41,14 +41,14 @@ type aliLSWriter struct { Config } -// NewAliLS create a new Logger +// NewAliLS creates a new Logger func NewAliLS() logs.Logger { alils := new(aliLSWriter) alils.Level = logs.LevelTrace return alils } -// Init parse config and init struct +// Init parses config and initializes struct func (c *aliLSWriter) Init(jsonConfig string) (err error) { json.Unmarshal([]byte(jsonConfig), c) @@ -101,8 +101,8 @@ func (c *aliLSWriter) Init(jsonConfig string) (err error) { return nil } -// WriteMsg write message in connection. -// if connection is down, try to re-connect. +// WriteMsg writes a message in connection. +// If connection is down, try to re-connect. func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) { if level > c.Level { diff --git a/pkg/logs/alils/config.go b/pkg/logs/alils/config.go index e8c24448..d0b67c24 100755 --- a/pkg/logs/alils/config.go +++ b/pkg/logs/alils/config.go @@ -4,10 +4,10 @@ const ( version = "0.5.0" // SDK version signatureMethod = "hmac-sha1" // Signature method - // OffsetNewest stands for the log head offset, i.e. the offset that will be + // OffsetNewest is the log head offset, i.e. the offset that will be // assigned to the next message that will be produced to the shard. OffsetNewest = "end" - // OffsetOldest stands for the oldest offset available on the logstore for a + // OffsetOldest is the the oldest offset available on the logstore for a // shard. OffsetOldest = "begin" ) diff --git a/pkg/logs/alils/log.pb.go b/pkg/logs/alils/log.pb.go index 601b0d78..b18fb9b7 100755 --- a/pkg/logs/alils/log.pb.go +++ b/pkg/logs/alils/log.pb.go @@ -31,13 +31,13 @@ type Log struct { // Reset the Log func (m *Log) Reset() { *m = Log{} } -// String return the Compact Log +// String returns the Compact Log func (m *Log) String() string { return proto.CompactTextString(m) } // ProtoMessage not implemented func (*Log) ProtoMessage() {} -// GetTime return the Log's Time +// GetTime returns the Log's Time func (m *Log) GetTime() uint32 { if m != nil && m.Time != nil { return *m.Time @@ -45,7 +45,7 @@ func (m *Log) GetTime() uint32 { return 0 } -// GetContents return the Log's Contents +// GetContents returns the Log's Contents func (m *Log) GetContents() []*LogContent { if m != nil { return m.Contents @@ -53,7 +53,7 @@ func (m *Log) GetContents() []*LogContent { return nil } -// LogContent define the Log content struct +// LogContent defines the Log content struct type LogContent struct { Key *string `protobuf:"bytes,1,req,name=Key" json:"Key,omitempty"` Value *string `protobuf:"bytes,2,req,name=Value" json:"Value,omitempty"` @@ -63,13 +63,13 @@ type LogContent struct { // Reset LogContent func (m *LogContent) Reset() { *m = LogContent{} } -// String return the compact text +// String returns the compact text func (m *LogContent) String() string { return proto.CompactTextString(m) } // ProtoMessage not implemented func (*LogContent) ProtoMessage() {} -// GetKey return the Key +// GetKey returns the key func (m *LogContent) GetKey() string { if m != nil && m.Key != nil { return *m.Key @@ -77,7 +77,7 @@ func (m *LogContent) GetKey() string { return "" } -// GetValue return the Value +// GetValue returns the value func (m *LogContent) GetValue() string { if m != nil && m.Value != nil { return *m.Value @@ -85,7 +85,7 @@ func (m *LogContent) GetValue() string { return "" } -// LogGroup define the logs struct +// LogGroup defines the logs struct type LogGroup struct { Logs []*Log `protobuf:"bytes,1,rep,name=Logs" json:"Logs,omitempty"` Reserved *string `protobuf:"bytes,2,opt,name=Reserved" json:"Reserved,omitempty"` @@ -97,13 +97,13 @@ type LogGroup struct { // Reset LogGroup func (m *LogGroup) Reset() { *m = LogGroup{} } -// String return the compact text +// String returns the compact text func (m *LogGroup) String() string { return proto.CompactTextString(m) } // ProtoMessage not implemented func (*LogGroup) ProtoMessage() {} -// GetLogs return the loggroup logs +// GetLogs returns the loggroup logs func (m *LogGroup) GetLogs() []*Log { if m != nil { return m.Logs @@ -111,7 +111,8 @@ func (m *LogGroup) GetLogs() []*Log { return nil } -// GetReserved return Reserved +// GetReserved returns Reserved. An empty string is returned +// if an error occurs func (m *LogGroup) GetReserved() string { if m != nil && m.Reserved != nil { return *m.Reserved @@ -119,7 +120,8 @@ func (m *LogGroup) GetReserved() string { return "" } -// GetTopic return Topic +// GetTopic returns Topic. An empty string is returned +// if an error occurs func (m *LogGroup) GetTopic() string { if m != nil && m.Topic != nil { return *m.Topic @@ -127,7 +129,8 @@ func (m *LogGroup) GetTopic() string { return "" } -// GetSource return Source +// GetSource returns source. An empty string is returned +// if an error occurs func (m *LogGroup) GetSource() string { if m != nil && m.Source != nil { return *m.Source @@ -135,7 +138,7 @@ func (m *LogGroup) GetSource() string { return "" } -// LogGroupList define the LogGroups +// LogGroupList defines the LogGroups type LogGroupList struct { LogGroups []*LogGroup `protobuf:"bytes,1,rep,name=logGroups" json:"logGroups,omitempty"` XXXUnrecognized []byte `json:"-"` @@ -144,13 +147,13 @@ type LogGroupList struct { // Reset LogGroupList func (m *LogGroupList) Reset() { *m = LogGroupList{} } -// String return compact text +// String returns compact text func (m *LogGroupList) String() string { return proto.CompactTextString(m) } // ProtoMessage not implemented func (*LogGroupList) ProtoMessage() {} -// GetLogGroups return the LogGroups +// GetLogGroups returns the LogGroups func (m *LogGroupList) GetLogGroups() []*LogGroup { if m != nil { return m.LogGroups @@ -158,7 +161,7 @@ func (m *LogGroupList) GetLogGroups() []*LogGroup { return nil } -// Marshal the logs to byte slice +// Marshal marshals the logs to byte slice func (m *Log) Marshal() (data []byte, err error) { size := m.Size() data = make([]byte, size) @@ -353,7 +356,7 @@ func encodeVarintLog(data []byte, offset int, v uint64) int { return offset + 1 } -// Size return the log's size +// Size returns the log's size func (m *Log) Size() (n int) { var l int _ = l @@ -372,7 +375,7 @@ func (m *Log) Size() (n int) { return n } -// Size return LogContent size based on Key and Value +// Size returns LogContent size based on Key and Value func (m *LogContent) Size() (n int) { var l int _ = l @@ -390,7 +393,7 @@ func (m *LogContent) Size() (n int) { return n } -// Size return LogGroup size based on Logs +// Size returns LogGroup size based on Logs func (m *LogGroup) Size() (n int) { var l int _ = l @@ -418,7 +421,7 @@ func (m *LogGroup) Size() (n int) { return n } -// Size return LogGroupList size +// Size returns LogGroupList size func (m *LogGroupList) Size() (n int) { var l int _ = l @@ -448,7 +451,7 @@ func sozLog(x uint64) (n int) { return sovLog((x << 1) ^ (x >> 63)) } -// Unmarshal data to log +// Unmarshal unmarshals data to log func (m *Log) Unmarshal(data []byte) error { var hasFields [1]uint64 l := len(data) @@ -557,7 +560,7 @@ func (m *Log) Unmarshal(data []byte) error { return nil } -// Unmarshal data to LogContent +// Unmarshal unmarshals data to LogContent func (m *LogContent) Unmarshal(data []byte) error { var hasFields [1]uint64 l := len(data) @@ -679,7 +682,7 @@ func (m *LogContent) Unmarshal(data []byte) error { return nil } -// Unmarshal data to LogGroup +// Unmarshal unmarshals data to LogGroup func (m *LogGroup) Unmarshal(data []byte) error { l := len(data) iNdEx := 0 @@ -853,7 +856,7 @@ func (m *LogGroup) Unmarshal(data []byte) error { return nil } -// Unmarshal data to LogGroupList +// Unmarshal unmarshals data to LogGroupList func (m *LogGroupList) Unmarshal(data []byte) error { l := len(data) iNdEx := 0 diff --git a/pkg/logs/alils/log_config.go b/pkg/logs/alils/log_config.go index e8564efb..7daeb864 100755 --- a/pkg/logs/alils/log_config.go +++ b/pkg/logs/alils/log_config.go @@ -1,6 +1,6 @@ package alils -// InputDetail define log detail +// InputDetail defines log detail type InputDetail struct { LogType string `json:"logType"` LogPath string `json:"logPath"` @@ -15,13 +15,13 @@ type InputDetail struct { TopicFormat string `json:"topicFormat"` } -// OutputDetail define the output detail +// OutputDetail defines the output detail type OutputDetail struct { Endpoint string `json:"endpoint"` LogStoreName string `json:"logstoreName"` } -// LogConfig define Log Config +// LogConfig defines Log Config type LogConfig struct { Name string `json:"configName"` InputType string `json:"inputType"` diff --git a/pkg/logs/alils/log_project.go b/pkg/logs/alils/log_project.go index 59db8cbf..7ede3fef 100755 --- a/pkg/logs/alils/log_project.go +++ b/pkg/logs/alils/log_project.go @@ -20,7 +20,7 @@ type errorMessage struct { Message string `json:"errorMessage"` } -// LogProject Define the Ali Project detail +// LogProject defines the Ali Project detail type LogProject struct { Name string // Project name Endpoint string // IP or hostname of SLS endpoint diff --git a/pkg/logs/alils/log_store.go b/pkg/logs/alils/log_store.go index fa502736..d5ff25e2 100755 --- a/pkg/logs/alils/log_store.go +++ b/pkg/logs/alils/log_store.go @@ -12,7 +12,7 @@ import ( "github.com/gogo/protobuf/proto" ) -// LogStore Store the logs +// LogStore stores the logs type LogStore struct { Name string `json:"logstoreName"` TTL int @@ -24,7 +24,7 @@ type LogStore struct { project *LogProject } -// Shard define the Log Shard +// Shard defines the Log Shard type Shard struct { ShardID int `json:"shardID"` } @@ -71,7 +71,7 @@ func (s *LogStore) ListShards() (shardIDs []int, err error) { return } -// PutLogs put logs into logstore. +// PutLogs puts logs into logstore. // The callers should transform user logs into LogGroup. func (s *LogStore) PutLogs(lg *LogGroup) (err error) { body, err := proto.Marshal(lg) diff --git a/pkg/logs/alils/machine_group.go b/pkg/logs/alils/machine_group.go index b6c69a14..101faeb4 100755 --- a/pkg/logs/alils/machine_group.go +++ b/pkg/logs/alils/machine_group.go @@ -8,13 +8,13 @@ import ( "net/http/httputil" ) -// MachineGroupAttribute define the Attribute +// MachineGroupAttribute defines the Attribute type MachineGroupAttribute struct { ExternalName string `json:"externalName"` TopicName string `json:"groupTopic"` } -// MachineGroup define the machine Group +// MachineGroup defines the machine Group type MachineGroup struct { Name string `json:"groupName"` Type string `json:"groupType"` @@ -29,20 +29,20 @@ type MachineGroup struct { project *LogProject } -// Machine define the Machine +// Machine defines the Machine type Machine struct { IP string UniqueID string `json:"machine-uniqueid"` UserdefinedID string `json:"userdefined-id"` } -// MachineList define the Machine List +// MachineList defines the Machine List type MachineList struct { Total int Machines []*Machine } -// ListMachines returns machine list of this machine group. +// ListMachines returns the machine list of this machine group. func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) { h := map[string]string{ "x-sls-bodyrawsize": "0", diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go index 74c458ab..8b55bde7 100644 --- a/pkg/logs/conn.go +++ b/pkg/logs/conn.go @@ -22,7 +22,7 @@ import ( ) // connWriter implements LoggerInterface. -// it writes messages in keep-live tcp connection. +// Writes messages in keep-live tcp connection. type connWriter struct { lg *logWriter innerWriter io.WriteCloser @@ -33,21 +33,21 @@ type connWriter struct { Level int `json:"level"` } -// NewConn create new ConnWrite returning as LoggerInterface. +// NewConn creates new ConnWrite returning as LoggerInterface. func NewConn() Logger { conn := new(connWriter) conn.Level = LevelTrace return conn } -// Init init connection writer with json config. -// json config only need key "level". +// Init initializes a connection writer with json config. +// json config only needs they "level" key func (c *connWriter) Init(jsonConfig string) error { return json.Unmarshal([]byte(jsonConfig), c) } -// WriteMsg write message in connection. -// if connection is down, try to re-connect. +// WriteMsg writes message in connection. +// If connection is down, try to re-connect. func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { if level > c.Level { return nil diff --git a/pkg/logs/console.go b/pkg/logs/console.go index 3dcaee1d..b2cc2907 100644 --- a/pkg/logs/console.go +++ b/pkg/logs/console.go @@ -26,7 +26,7 @@ import ( // brush is a color join function type brush func(string) string -// newBrush return a fix color Brush +// newBrush returns a fix color Brush func newBrush(color string) brush { pre := "\033[" reset := "\033[0m" @@ -53,7 +53,7 @@ type consoleWriter struct { Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color } -// NewConsole create ConsoleWriter returning as LoggerInterface. +// NewConsole creates ConsoleWriter returning as LoggerInterface. func NewConsole() Logger { cw := &consoleWriter{ lg: newLogWriter(ansicolor.NewAnsiColorWriter(os.Stdout)), @@ -63,8 +63,8 @@ func NewConsole() Logger { return cw } -// Init init console logger. -// jsonConfig like '{"level":LevelTrace}'. +// Init initianlizes the console logger. +// jsonConfig must be in the format '{"level":LevelTrace}' func (c *consoleWriter) Init(jsonConfig string) error { if len(jsonConfig) == 0 { return nil @@ -72,7 +72,7 @@ func (c *consoleWriter) Init(jsonConfig string) error { return json.Unmarshal([]byte(jsonConfig), c) } -// WriteMsg write message in console. +// WriteMsg writes message in console. func (c *consoleWriter) WriteMsg(when time.Time, msg string, level int) error { if level > c.Level { return nil diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go index af6a7892..7542b577 100644 --- a/pkg/logs/es/es.go +++ b/pkg/logs/es/es.go @@ -15,7 +15,7 @@ import ( "github.com/astaxie/beego/pkg/logs" ) -// NewES return a LoggerInterface +// NewES returns a LoggerInterface func NewES() logs.Logger { cw := &esLogger{ Level: logs.LevelDebug, @@ -59,7 +59,7 @@ func (el *esLogger) Init(jsonconfig string) error { return nil } -// WriteMsg will write the msg and level into es +// WriteMsg writes the msg and level into es func (el *esLogger) WriteMsg(when time.Time, msg string, level int) error { if level > el.Level { return nil diff --git a/pkg/logs/file.go b/pkg/logs/file.go index 40a3572a..fbe10b55 100644 --- a/pkg/logs/file.go +++ b/pkg/logs/file.go @@ -30,7 +30,7 @@ import ( ) // fileLogWriter implements LoggerInterface. -// It writes messages by lines limit, file size limit, or time frequency. +// Writes messages by lines limit, file size limit, or time frequency. type fileLogWriter struct { sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize // The opened file @@ -71,7 +71,7 @@ type fileLogWriter struct { fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix } -// newFileWriter create a FileLogWriter returning as LoggerInterface. +// newFileWriter creates a FileLogWriter returning as LoggerInterface. func newFileWriter() Logger { w := &fileLogWriter{ Daily: true, @@ -143,7 +143,7 @@ func (w *fileLogWriter) needRotateHourly(size int, hour int) bool { } -// WriteMsg write logger message into file. +// WriteMsg writes logger message into file. func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { if level > w.Level { return nil @@ -286,7 +286,7 @@ func (w *fileLogWriter) lines() (int, error) { return count, nil } -// DoRotate means it need to write file in new file. +// DoRotate means it needs to write logs into a new file. // new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size) func (w *fileLogWriter) doRotate(logTime time.Time) error { // file exists @@ -397,7 +397,7 @@ func (w *fileLogWriter) Destroy() { w.fileWriter.Close() } -// Flush flush file logger. +// Flush flushes file logger. // there are no buffering messages in file logger in memory. // flush file means sync file from disk. func (w *fileLogWriter) Flush() { diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go index 88ba0f9a..71e7e2bf 100644 --- a/pkg/logs/jianliao.go +++ b/pkg/logs/jianliao.go @@ -18,7 +18,7 @@ type JLWriter struct { Level int `json:"level"` } -// newJLWriter create jiaoliao writer. +// newJLWriter creates jiaoliao writer. func newJLWriter() Logger { return &JLWriter{Level: LevelTrace} } @@ -28,8 +28,8 @@ func (s *JLWriter) Init(jsonconfig string) error { return json.Unmarshal([]byte(jsonconfig), s) } -// WriteMsg write message in smtp writer. -// it will send an email with subject and only this message. +// WriteMsg writes message in smtp writer. +// Sends an email with subject and only this message. func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error { if level > s.Level { return nil diff --git a/pkg/logs/log.go b/pkg/logs/log.go index 39c006d2..4824918b 100644 --- a/pkg/logs/log.go +++ b/pkg/logs/log.go @@ -108,7 +108,7 @@ func Register(name string, log newLoggerFunc) { } // BeeLogger is default logger in beego application. -// it can contain several providers and log message into all providers. +// Can contain several providers and log message into all providers. type BeeLogger struct { lock sync.Mutex level int @@ -140,7 +140,7 @@ type logMsg struct { var logMsgPool *sync.Pool // NewLogger returns a new BeeLogger. -// channelLen means the number of messages in chan(used where asynchronous is true). +// channelLen: the number of messages in chan(used where asynchronous is true). // if the buffering chan is full, logger adapters write to file or other way. func NewLogger(channelLens ...int64) *BeeLogger { bl := new(BeeLogger) @@ -155,7 +155,7 @@ func NewLogger(channelLens ...int64) *BeeLogger { return bl } -// Async set the log to asynchronous and start the goroutine +// Async sets the log to asynchronous and start the goroutine func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { bl.lock.Lock() defer bl.lock.Unlock() @@ -178,7 +178,7 @@ func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { } // SetLogger provides a given logger adapter into BeeLogger with config string. -// config need to be correct JSON as string: {"interval":360}. +// config must in in JSON format like {"interval":360}} func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { config := append(configs, "{}")[0] for _, l := range bl.outputs { @@ -203,7 +203,7 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { } // SetLogger provides a given logger adapter into BeeLogger with config string. -// config need to be correct JSON as string: {"interval":360}. +// config must in in JSON format like {"interval":360}} func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error { bl.lock.Lock() defer bl.lock.Unlock() @@ -214,7 +214,7 @@ func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error { return bl.setLogger(adapterName, configs...) } -// DelLogger remove a logger adapter in BeeLogger. +// DelLogger removes a logger adapter in BeeLogger. func (bl *BeeLogger) DelLogger(adapterName string) error { bl.lock.Lock() defer bl.lock.Unlock() @@ -306,9 +306,9 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error return nil } -// SetLevel Set log message level. +// SetLevel sets log message level. // If message level (such as LevelDebug) is higher than logger level (such as LevelWarning), -// log providers will not even be sent the message. +// log providers will not be sent the message. func (bl *BeeLogger) SetLevel(l int) { bl.level = l } diff --git a/pkg/logs/slack.go b/pkg/logs/slack.go index 1cd2e5ae..e78eeab6 100644 --- a/pkg/logs/slack.go +++ b/pkg/logs/slack.go @@ -14,7 +14,7 @@ type SLACKWriter struct { Level int `json:"level"` } -// newSLACKWriter create jiaoliao writer. +// newSLACKWriter creates jiaoliao writer. func newSLACKWriter() Logger { return &SLACKWriter{Level: LevelTrace} } @@ -25,7 +25,7 @@ func (s *SLACKWriter) Init(jsonconfig string) error { } // WriteMsg write message in smtp writer. -// it will send an email with subject and only this message. +// Sends an email with subject and only this message. func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error { if level > s.Level { return nil diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go index 6208d7b8..720c2d25 100644 --- a/pkg/logs/smtp.go +++ b/pkg/logs/smtp.go @@ -35,7 +35,7 @@ type SMTPWriter struct { Level int `json:"level"` } -// NewSMTPWriter create smtp writer. +// NewSMTPWriter creates the smtp writer. func newSMTPWriter() Logger { return &SMTPWriter{Level: LevelTrace} } @@ -115,8 +115,8 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd return client.Quit() } -// WriteMsg write message in smtp writer. -// it will send an email with subject and only this message. +// WriteMsg writes message in smtp writer. +// Sends an email with subject and only this message. func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error { if level > s.Level { return nil From 63b3fc4a996ebd40e91c8d00c2e0ba5562c6e1e1 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Thu, 6 Aug 2020 16:09:06 +0100 Subject: [PATCH 075/207] Fix retry amount comment --- pkg/httplib/httplib.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/httplib/httplib.go b/pkg/httplib/httplib.go index 1438a881..7255b2ca 100644 --- a/pkg/httplib/httplib.go +++ b/pkg/httplib/httplib.go @@ -195,7 +195,7 @@ func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { } // Retries sets Retries times. -// default is 0 (never retries) +// default is 0 (never retry) // -1 retry indefinitely (forever) // Other numbers specify the exact retry amount func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { From 08cec9178fb8f2ee2827a4e37c869edeff3ffdf6 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Fri, 7 Aug 2020 13:45:24 +0000 Subject: [PATCH 076/207] Orm filter support --- pkg/orm/do_nothing_omr_test.go | 134 +++++++ pkg/orm/do_nothing_orm.go | 178 +++++++++ pkg/orm/filter.go | 32 ++ pkg/orm/filter_orm_decorator.go | 519 +++++++++++++++++++++++++++ pkg/orm/filter_orm_decorator_test.go | 432 ++++++++++++++++++++++ pkg/orm/filter_test.go | 31 ++ pkg/orm/invocation.go | 48 +++ pkg/orm/models.go | 9 + pkg/orm/orm_test.go | 2 + pkg/orm/types.go | 10 +- 10 files changed, 1391 insertions(+), 4 deletions(-) create mode 100644 pkg/orm/do_nothing_omr_test.go create mode 100644 pkg/orm/do_nothing_orm.go create mode 100644 pkg/orm/filter.go create mode 100644 pkg/orm/filter_orm_decorator.go create mode 100644 pkg/orm/filter_orm_decorator_test.go create mode 100644 pkg/orm/filter_test.go create mode 100644 pkg/orm/invocation.go diff --git a/pkg/orm/do_nothing_omr_test.go b/pkg/orm/do_nothing_omr_test.go new file mode 100644 index 00000000..92cde38b --- /dev/null +++ b/pkg/orm/do_nothing_omr_test.go @@ -0,0 +1,134 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDoNothingOrm(t *testing.T) { + o := &DoNothingOrm{} + err := o.DoTxWithCtxAndOpts(nil, nil, nil) + assert.Nil(t, err) + + err = o.DoTxWithCtx(nil, nil) + assert.Nil(t, err) + + err = o.DoTx(nil) + assert.Nil(t, err) + + err = o.DoTxWithOpts(nil, nil) + assert.Nil(t, err) + + assert.Nil(t, o.Driver()) + + assert.Nil(t, o.QueryM2MWithCtx(nil, nil, "")) + assert.Nil(t, o.QueryM2M(nil, "")) + assert.Nil(t, o.ReadWithCtx(nil, nil)) + assert.Nil(t, o.Read(nil)) + + txOrm, err := o.BeginWithCtxAndOpts(nil, nil) + assert.Nil(t, err) + assert.Nil(t, txOrm) + + txOrm, err = o.BeginWithCtx(nil) + assert.Nil(t, err) + assert.Nil(t, txOrm) + + txOrm, err = o.BeginWithOpts(nil) + assert.Nil(t, err) + assert.Nil(t, txOrm) + + txOrm, err = o.Begin() + assert.Nil(t, err) + assert.Nil(t, txOrm) + + assert.Nil(t, o.RawWithCtx(nil, "")) + assert.Nil(t, o.Raw("")) + + i, err := o.InsertMulti(0, nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.Insert(nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.InsertWithCtx(nil, nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.InsertOrUpdateWithCtx(nil, nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.InsertOrUpdate(nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.InsertMultiWithCtx(nil, 0, nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.LoadRelatedWithCtx(nil, nil, "") + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.LoadRelated(nil, "") + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + assert.Nil(t, o.QueryTableWithCtx(nil, nil)) + assert.Nil(t, o.QueryTable(nil)) + + assert.Nil(t, o.Read(nil)) + assert.Nil(t, o.ReadWithCtx(nil, nil)) + assert.Nil(t, o.ReadForUpdateWithCtx(nil, nil)) + assert.Nil(t, o.ReadForUpdate(nil)) + + ok, i, err := o.ReadOrCreate(nil, "") + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + assert.False(t, ok) + + ok, i, err = o.ReadOrCreateWithCtx(nil, nil, "") + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + assert.False(t, ok) + + i, err = o.Delete(nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.DeleteWithCtx(nil, nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.Update(nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.UpdateWithCtx(nil, nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + assert.Nil(t, o.DBStats()) + + to := &DoNothingTxOrm{} + assert.Nil(t, to.Commit()) + assert.Nil(t, to.Rollback()) +} diff --git a/pkg/orm/do_nothing_orm.go b/pkg/orm/do_nothing_orm.go new file mode 100644 index 00000000..87b0a2ae --- /dev/null +++ b/pkg/orm/do_nothing_orm.go @@ -0,0 +1,178 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" +) + +// DoNothingOrm won't do anything, usually you use this to custom your mock Ormer implementation +// I think golang mocking interface is hard to use +// this may help you to integrate with Ormer + +var _ Ormer = new(DoNothingOrm) + +type DoNothingOrm struct { +} + +func (d *DoNothingOrm) Read(md interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingOrm) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingOrm) ReadForUpdate(md interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingOrm) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingOrm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + return false, 0, nil +} + +func (d *DoNothingOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { + return false, 0, nil +} + +func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer { + return nil +} + +func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { + return nil +} + +func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { + return nil +} + +func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { + return nil +} + +func (d *DoNothingOrm) DBStats() *sql.DBStats { + return nil +} + +func (d *DoNothingOrm) Insert(md interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) InsertMulti(bulk int, mds interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) Update(md interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) Delete(md interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) Raw(query string, args ...interface{}) RawSeter { + return nil +} + +func (d *DoNothingOrm) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter { + return nil +} + +func (d *DoNothingOrm) Driver() Driver { + return nil +} + +func (d *DoNothingOrm) Begin() (TxOrmer, error) { + return nil, nil +} + +func (d *DoNothingOrm) BeginWithCtx(ctx context.Context) (TxOrmer, error) { + return nil, nil +} + +func (d *DoNothingOrm) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) { + return nil, nil +} + +func (d *DoNothingOrm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) { + return nil, nil +} + +func (d *DoNothingOrm) DoTx(task func(txOrm TxOrmer) error) error { + return nil +} + +func (d *DoNothingOrm) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error { + return nil +} + +func (d *DoNothingOrm) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { + return nil +} + +func (d *DoNothingOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { + return nil +} + +// DoNothingTxOrm is similar with DoNothingOrm, usually you use it to test +type DoNothingTxOrm struct { + DoNothingOrm +} + +func (d *DoNothingTxOrm) Commit() error { + return nil +} + +func (d *DoNothingTxOrm) Rollback() error { + return nil +} diff --git a/pkg/orm/filter.go b/pkg/orm/filter.go new file mode 100644 index 00000000..9676e4af --- /dev/null +++ b/pkg/orm/filter.go @@ -0,0 +1,32 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" +) + +type FilterChain func(next Filter) Filter + +type Filter func(ctx context.Context, inv *Invocation) + +var globalFilterChains = make([]FilterChain, 0, 4) + +// AddGlobalFilterChain adds a new FilterChain +// All orm instances built after this invocation will use this filterChain, +// but instances built before this invocation will not be affected +func AddGlobalFilterChain(filterChain FilterChain) { + globalFilterChains = append(globalFilterChains, filterChain) +} \ No newline at end of file diff --git a/pkg/orm/filter_orm_decorator.go b/pkg/orm/filter_orm_decorator.go new file mode 100644 index 00000000..eb26ea68 --- /dev/null +++ b/pkg/orm/filter_orm_decorator.go @@ -0,0 +1,519 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "reflect" + "time" +) + +const TxNameKey = "TxName" + +type filterOrmDecorator struct { + ormer + TxBeginner + TxCommitter + + root Filter + + insideTx bool + txStartTime time.Time + txName string +} + +func NewFilterOrmDecorator(delegate Ormer, filterChains ...FilterChain) Ormer { + res := &filterOrmDecorator{ + ormer: delegate, + TxBeginner: delegate, + root: func(ctx context.Context, inv *Invocation) { + inv.execute() + }, + } + + for i := len(filterChains) - 1; i >= 0; i-- { + node := filterChains[i] + res.root = node(res.root) + } + return res +} + +func NewFilterTxOrmDecorator(delegate TxOrmer, root Filter, txName string) TxOrmer { + res := &filterOrmDecorator{ + ormer: delegate, + TxCommitter: delegate, + root: root, + insideTx: true, + txStartTime: time.Now(), + txName: txName, + } + return res +} + +func (f *filterOrmDecorator) Read(md interface{}, cols ...string) error { + return f.ReadWithCtx(context.Background(), md, cols...) +} + +func (f *filterOrmDecorator) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) (err error) { + mi, _ := modelCache.getByMd(md) + inv := &Invocation{ + Method: "ReadWithCtx", + Args: []interface{}{md, cols}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func() { + err = f.ormer.ReadWithCtx(ctx, md, cols...) + }, + } + f.root(ctx, inv) + return err +} + +func (f *filterOrmDecorator) ReadForUpdate(md interface{}, cols ...string) error { + return f.ReadForUpdateWithCtx(context.Background(), md, cols...) +} + +func (f *filterOrmDecorator) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { + var err error + mi, _ := modelCache.getByMd(md) + inv := &Invocation{ + Method: "ReadForUpdateWithCtx", + Args: []interface{}{md, cols}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func() { + err = f.ormer.ReadForUpdateWithCtx(ctx, md, cols...) + }, + } + f.root(ctx, inv) + return err +} + +func (f *filterOrmDecorator) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + return f.ReadOrCreateWithCtx(context.Background(), md, col1, cols...) +} + +func (f *filterOrmDecorator) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { + var ( + ok bool + res int64 + err error + ) + + mi, _ := modelCache.getByMd(md) + inv := &Invocation{ + Method: "ReadOrCreateWithCtx", + Args: []interface{}{md, col1, cols}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func() { + ok, res, err = f.ormer.ReadOrCreateWithCtx(ctx, md, col1, cols...) + }, + } + f.root(ctx, inv) + return ok, res, err +} + +func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { + return f.LoadRelatedWithCtx(context.Background(), md, name, args...) +} + +func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { + var ( + res int64 + err error + ) + + mi, _ := modelCache.getByMd(md) + inv := &Invocation{ + Method: "LoadRelatedWithCtx", + Args: []interface{}{md, name, args}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func() { + res, err = f.ormer.LoadRelatedWithCtx(ctx, md, name, args...) + }, + } + f.root(ctx, inv) + return res, err +} + +func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer { + return f.QueryM2MWithCtx(context.Background(), md, name) +} + +func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { + var ( + res QueryM2Mer + ) + + mi, _ := modelCache.getByMd(md) + inv := &Invocation{ + Method: "QueryM2MWithCtx", + Args: []interface{}{md, name}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func() { + res = f.ormer.QueryM2MWithCtx(ctx, md, name) + }, + } + f.root(ctx, inv) + return res +} + +func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { + return f.QueryTableWithCtx(context.Background(), ptrStructOrTableName) +} + +func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { + var ( + res QuerySeter + name string + md interface{} + mi *modelInfo + ) + + if table, ok := ptrStructOrTableName.(string); ok { + name = table + } else { + name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) + md = ptrStructOrTableName + } + + if m, ok := modelCache.getByFullName(name); ok { + mi = m + } + + inv := &Invocation{ + Method: "QueryTableWithCtx", + Args: []interface{}{ptrStructOrTableName}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + Md: md, + mi: mi, + f: func() { + res = f.ormer.QueryTableWithCtx(ctx, ptrStructOrTableName) + }, + } + f.root(ctx, inv) + return res +} + +func (f *filterOrmDecorator) DBStats() *sql.DBStats { + var ( + res *sql.DBStats + ) + inv := &Invocation{ + Method: "DBStats", + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func() { + res = f.ormer.DBStats() + }, + } + f.root(context.Background(), inv) + return res +} + +func (f *filterOrmDecorator) Insert(md interface{}) (int64, error) { + return f.InsertWithCtx(context.Background(), md) +} + +func (f *filterOrmDecorator) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { + var ( + res int64 + err error + ) + mi, _ := modelCache.getByMd(md) + inv := &Invocation{ + Method: "InsertWithCtx", + Args: []interface{}{md}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func() { + res, err = f.ormer.InsertWithCtx(ctx, md) + }, + } + f.root(ctx, inv) + return res, err +} + +func (f *filterOrmDecorator) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { + return f.InsertOrUpdateWithCtx(context.Background(), md, colConflitAndArgs...) +} + +func (f *filterOrmDecorator) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { + var ( + res int64 + err error + ) + mi, _ := modelCache.getByMd(md) + inv := &Invocation{ + Method: "InsertOrUpdateWithCtx", + Args: []interface{}{md, colConflitAndArgs}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func() { + res, err = f.ormer.InsertOrUpdateWithCtx(ctx, md, colConflitAndArgs...) + }, + } + f.root(ctx, inv) + return res, err +} + +func (f *filterOrmDecorator) InsertMulti(bulk int, mds interface{}) (int64, error) { + return f.InsertMultiWithCtx(context.Background(), bulk, mds) +} + +// InsertMultiWithCtx uses the first element's model info +func (f *filterOrmDecorator) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { + var ( + res int64 + err error + md interface{} + mi *modelInfo + ) + + sind := reflect.Indirect(reflect.ValueOf(mds)) + + if (sind.Kind() == reflect.Array || sind.Kind() == reflect.Slice) && sind.Len() > 0 { + ind := reflect.Indirect(sind.Index(0)) + md = ind.Interface() + mi, _ = modelCache.getByMd(md) + } + + inv := &Invocation{ + Method: "InsertMultiWithCtx", + Args: []interface{}{bulk, mds}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func() { + res, err = f.ormer.InsertMultiWithCtx(ctx, bulk, mds) + }, + } + f.root(ctx, inv) + return res, err +} + +func (f *filterOrmDecorator) Update(md interface{}, cols ...string) (int64, error) { + return f.UpdateWithCtx(context.Background(), md, cols...) +} + +func (f *filterOrmDecorator) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { + var ( + res int64 + err error + ) + mi, _ := modelCache.getByMd(md) + inv := &Invocation{ + Method: "UpdateWithCtx", + Args: []interface{}{md, cols}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func() { + res, err = f.ormer.UpdateWithCtx(ctx, md, cols...) + }, + } + f.root(ctx, inv) + return res, err +} + +func (f *filterOrmDecorator) Delete(md interface{}, cols ...string) (int64, error) { + return f.DeleteWithCtx(context.Background(), md, cols...) +} + +func (f *filterOrmDecorator) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { + var ( + res int64 + err error + ) + mi, _ := modelCache.getByMd(md) + inv := &Invocation{ + Method: "DeleteWithCtx", + Args: []interface{}{md, cols}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func() { + res, err = f.ormer.DeleteWithCtx(ctx, md, cols...) + }, + } + f.root(ctx, inv) + return res, err +} + +func (f *filterOrmDecorator) Raw(query string, args ...interface{}) RawSeter { + return f.RawWithCtx(context.Background(), query, args...) +} + +func (f *filterOrmDecorator) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter { + var ( + res RawSeter + ) + inv := &Invocation{ + Method: "RawWithCtx", + Args: []interface{}{query, args}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func() { + res = f.ormer.RawWithCtx(ctx, query, args...) + }, + } + f.root(ctx, inv) + return res +} + +func (f *filterOrmDecorator) Driver() Driver { + var ( + res Driver + ) + inv := &Invocation{ + Method: "Driver", + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func() { + res = f.ormer.Driver() + }, + } + f.root(context.Background(), inv) + return res +} + +func (f *filterOrmDecorator) Begin() (TxOrmer, error) { + return f.BeginWithCtxAndOpts(context.Background(), nil) +} + +func (f *filterOrmDecorator) BeginWithCtx(ctx context.Context) (TxOrmer, error) { + return f.BeginWithCtxAndOpts(ctx, nil) +} + +func (f *filterOrmDecorator) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) { + return f.BeginWithCtxAndOpts(context.Background(), opts) +} + +func (f *filterOrmDecorator) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) { + var ( + res TxOrmer + err error + ) + inv := &Invocation{ + Method: "BeginWithCtxAndOpts", + Args: []interface{}{opts}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func() { + res, err = f.TxBeginner.BeginWithCtxAndOpts(ctx, opts) + res = NewFilterTxOrmDecorator(res, f.root, getTxNameFromCtx(ctx)) + }, + } + f.root(ctx, inv) + return res, err +} + +func (f *filterOrmDecorator) DoTx(task func(txOrm TxOrmer) error) error { + return f.DoTxWithCtxAndOpts(context.Background(), nil, task) +} + +func (f *filterOrmDecorator) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error { + return f.DoTxWithCtxAndOpts(ctx, nil, task) +} + +func (f *filterOrmDecorator) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { + return f.DoTxWithCtxAndOpts(context.Background(), opts, task) +} + +func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { + var ( + err error + ) + + inv := &Invocation{ + Method: "DoTxWithCtxAndOpts", + Args: []interface{}{opts, task}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + TxName: getTxNameFromCtx(ctx), + f: func() { + err = f.TxBeginner.DoTxWithCtxAndOpts(ctx, opts, task) + }, + } + f.root(ctx, inv) + return err +} + +func (f *filterOrmDecorator) Commit() error { + var ( + err error + ) + inv := &Invocation{ + Method: "Commit", + Args: []interface{}{}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + TxName: f.txName, + f: func() { + err = f.TxCommitter.Commit() + }, + } + f.root(context.Background(), inv) + return err +} + +func (f *filterOrmDecorator) Rollback() error { + var ( + err error + ) + inv := &Invocation{ + Method: "Rollback", + Args: []interface{}{}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + TxName: f.txName, + f: func() { + err = f.TxCommitter.Rollback() + }, + } + f.root(context.Background(), inv) + return err +} + +func getTxNameFromCtx(ctx context.Context) string { + txName := "" + if n, ok := ctx.Value(TxNameKey).(string); ok { + txName = n + } + return txName +} + diff --git a/pkg/orm/filter_orm_decorator_test.go b/pkg/orm/filter_orm_decorator_test.go new file mode 100644 index 00000000..d1099eaf --- /dev/null +++ b/pkg/orm/filter_orm_decorator_test.go @@ -0,0 +1,432 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "errors" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFilterOrmDecorator_Read(t *testing.T) { + + register() + + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "ReadWithCtx", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + next(ctx, inv) + } + }) + + fte := &FilterTestEntity{} + err := od.Read(fte) + assert.NotNil(t, err) + assert.Equal(t, "read error", err.Error()) +} + +func TestFilterOrmDecorator_BeginTx(t *testing.T) { + register() + + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + if inv.Method == "BeginWithCtxAndOpts" { + assert.Equal(t, 1, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + assert.False(t, inv.InsideTx) + } else if inv.Method == "Commit" { + assert.Equal(t, 0, len(inv.Args)) + assert.Equal(t, "Commit_tx", inv.TxName) + assert.Equal(t, "", inv.GetTableName()) + assert.True(t, inv.InsideTx) + } else if inv.Method == "Rollback" { + assert.Equal(t, 0, len(inv.Args)) + assert.Equal(t, "Rollback_tx", inv.TxName) + assert.Equal(t, "", inv.GetTableName()) + assert.True(t, inv.InsideTx) + } else { + t.Fail() + } + + next(ctx, inv) + } + }) + to, err := od.Begin() + assert.True(t, validateBeginResult(t, to, err)) + + to, err = od.BeginWithOpts(nil) + assert.True(t, validateBeginResult(t, to, err)) + + ctx := context.WithValue(context.Background(), TxNameKey, "Commit_tx") + to, err = od.BeginWithCtx(ctx) + assert.True(t, validateBeginResult(t, to, err)) + + err = to.Commit() + assert.NotNil(t, err) + assert.Equal(t, "commit", err.Error()) + + ctx = context.WithValue(context.Background(), TxNameKey, "Rollback_tx") + to, err = od.BeginWithCtxAndOpts(ctx, nil) + assert.True(t, validateBeginResult(t, to, err)) + + err = to.Rollback() + assert.NotNil(t, err) + assert.Equal(t, "rollback", err.Error()) +} + +func TestFilterOrmDecorator_DBStats(t *testing.T) { + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "DBStats", inv.Method) + assert.Equal(t, 0, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + next(ctx, inv) + } + }) + res := od.DBStats() + assert.NotNil(t, res) + assert.Equal(t, -1, res.MaxOpenConnections) +} + +func TestFilterOrmDecorator_Delete(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "DeleteWithCtx", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + next(ctx, inv) + } + }) + res, err := od.Delete(&FilterTestEntity{}) + assert.NotNil(t, err) + assert.Equal(t, "delete error", err.Error()) + assert.Equal(t, int64(-2), res) +} + +func TestFilterOrmDecorator_DoTx(t *testing.T) { + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "DoTxWithCtxAndOpts", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + assert.False(t, inv.InsideTx) + next(ctx, inv) + } + }) + + err := od.DoTx(func(txOrm TxOrmer) error { + return errors.New("tx error") + }) + assert.NotNil(t, err) + assert.Equal(t, "tx error", err.Error()) + + err = od.DoTxWithCtx(context.Background(), func(txOrm TxOrmer) error { + return errors.New("tx ctx error") + }) + assert.NotNil(t, err) + assert.Equal(t, "tx ctx error", err.Error()) + + err = od.DoTxWithOpts(nil, func(txOrm TxOrmer) error { + return errors.New("tx opts error") + }) + assert.NotNil(t, err) + assert.Equal(t, "tx opts error", err.Error()) + + od = NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "DoTxWithCtxAndOpts", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + assert.Equal(t, "do tx name", inv.TxName) + assert.False(t, inv.InsideTx) + next(ctx, inv) + } + }) + + ctx := context.WithValue(context.Background(), TxNameKey, "do tx name") + err = od.DoTxWithCtxAndOpts(ctx, nil, func(txOrm TxOrmer) error { + return errors.New("tx ctx opts error") + }) + assert.NotNil(t, err) + assert.Equal(t, "tx ctx opts error", err.Error()) +} + +func TestFilterOrmDecorator_Driver(t *testing.T) { + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "Driver", inv.Method) + assert.Equal(t, 0, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + assert.False(t, inv.InsideTx) + next(ctx, inv) + } + }) + res := od.Driver() + assert.Nil(t, res) +} + +func TestFilterOrmDecorator_Insert(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "InsertWithCtx", inv.Method) + assert.Equal(t, 1, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + next(ctx, inv) + } + }) + + i, err := od.Insert(&FilterTestEntity{}) + assert.NotNil(t, err) + assert.Equal(t, "insert error", err.Error()) + assert.Equal(t, int64(100), i) +} + +func TestFilterOrmDecorator_InsertMulti(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "InsertMultiWithCtx", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + next(ctx, inv) + } + }) + + bulk := []*FilterTestEntity{&FilterTestEntity{}, &FilterTestEntity{}} + i, err := od.InsertMulti(2, bulk) + assert.NotNil(t, err) + assert.Equal(t, "insert multi error", err.Error()) + assert.Equal(t, int64(2), i) +} + +func TestFilterOrmDecorator_InsertOrUpdate(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "InsertOrUpdateWithCtx", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + next(ctx, inv) + } + }) + i, err := od.InsertOrUpdate(&FilterTestEntity{}) + assert.NotNil(t, err) + assert.Equal(t, "insert or update error", err.Error()) + assert.Equal(t, int64(1), i) +} + +func TestFilterOrmDecorator_LoadRelated(t *testing.T) { + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "LoadRelatedWithCtx", inv.Method) + assert.Equal(t, 3, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + next(ctx, inv) + } + }) + i, err := od.LoadRelated(&FilterTestEntity{}, "hello") + assert.NotNil(t, err) + assert.Equal(t, "load related error", err.Error()) + assert.Equal(t, int64(99), i) +} + +func TestFilterOrmDecorator_QueryM2M(t *testing.T) { + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "QueryM2MWithCtx", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + next(ctx, inv) + } + }) + res := od.QueryM2M(&FilterTestEntity{}, "hello") + assert.Nil(t, res) +} + +func TestFilterOrmDecorator_QueryTable(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "QueryTableWithCtx", inv.Method) + assert.Equal(t, 1, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + next(ctx, inv) + } + }) + res := od.QueryTable(&FilterTestEntity{}) + assert.Nil(t, res) +} + +func TestFilterOrmDecorator_Raw(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "RawWithCtx", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + assert.False(t, inv.InsideTx) + next(ctx, inv) + } + }) + res := od.Raw("hh") + assert.Nil(t, res) +} + +func TestFilterOrmDecorator_ReadForUpdate(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "ReadForUpdateWithCtx", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + next(ctx, inv) + } + }) + err := od.ReadForUpdate(&FilterTestEntity{}) + assert.NotNil(t, err) + assert.Equal(t, "read for update error", err.Error()) +} + +func TestFilterOrmDecorator_ReadOrCreate(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + assert.Equal(t, "ReadOrCreateWithCtx", inv.Method) + assert.Equal(t, 3, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + next(ctx, inv) + } + }) + ok, i, err := od.ReadOrCreate(&FilterTestEntity{}, "name") + assert.NotNil(t, err) + assert.Equal(t, "read or create error", err.Error()) + assert.True(t, ok) + assert.Equal(t, int64(13), i) +} + +// filterMockOrm is only used in this test file +type filterMockOrm struct { + DoNothingOrm +} + +func (f *filterMockOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { + return true, 13, errors.New("read or create error") +} + +func (f *filterMockOrm) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { + return errors.New("read for update error") +} + +func (f *filterMockOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { + return 99, errors.New("load related error") +} + +func (f *filterMockOrm) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { + return 1, errors.New("insert or update error") +} + +func (f *filterMockOrm) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { + return 2, errors.New("insert multi error") +} + +func (f *filterMockOrm) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { + return 100, errors.New("insert error") +} + +func (f *filterMockOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { + return task(nil) +} + +func (f *filterMockOrm) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { + return -2, errors.New("delete error") +} + +func (f *filterMockOrm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) { + return &filterMockOrm{}, errors.New("begin tx") +} + +func (f *filterMockOrm) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { + return errors.New("read error") +} + +func (f *filterMockOrm) Commit() error { + return errors.New("commit") +} + +func (f *filterMockOrm) Rollback() error { + return errors.New("rollback") +} + +func (f *filterMockOrm) DBStats() *sql.DBStats { + return &sql.DBStats{ + MaxOpenConnections: -1, + } +} + +func validateBeginResult(t *testing.T, to TxOrmer, err error) bool { + assert.NotNil(t, err) + assert.Equal(t, "begin tx", err.Error()) + _, ok := to.(*filterOrmDecorator).TxCommitter.(*filterMockOrm) + assert.True(t, ok) + return true +} + +var filterTestEntityRegisterOnce sync.Once + +type FilterTestEntity struct { + ID int + Name string +} + +func register() { + filterTestEntityRegisterOnce.Do(func() { + RegisterModel(&FilterTestEntity{}) + }) +} + +func (f *FilterTestEntity) TableName() string { + return "FILTER_TEST" +} diff --git a/pkg/orm/filter_test.go b/pkg/orm/filter_test.go new file mode 100644 index 00000000..0f2944c7 --- /dev/null +++ b/pkg/orm/filter_test.go @@ -0,0 +1,31 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAddGlobalFilterChain(t *testing.T) { + AddGlobalFilterChain(func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) { + + } + }) + assert.Equal(t, 1, len(globalFilterChains)) +} diff --git a/pkg/orm/invocation.go b/pkg/orm/invocation.go new file mode 100644 index 00000000..1c9fee09 --- /dev/null +++ b/pkg/orm/invocation.go @@ -0,0 +1,48 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "time" +) + +// Invocation represents an "Orm" invocation +type Invocation struct { + Method string + // Md may be nil in some cases. It depends on method + Md interface{} + // the args are all arguments except context.Context + Args []interface{} + + mi *modelInfo + // f is the Orm operation + f func() + + // insideTx indicates whether this is inside a transaction + InsideTx bool + TxStartTime time.Time + TxName string +} + +func (inv *Invocation) GetTableName() string { + if inv.mi != nil{ + return inv.mi.table + } + return "" +} + +func (inv *Invocation) execute() { + inv.f() +} diff --git a/pkg/orm/models.go b/pkg/orm/models.go index 4776bcba..c8fbcced 100644 --- a/pkg/orm/models.go +++ b/pkg/orm/models.go @@ -15,6 +15,7 @@ package orm import ( + "reflect" "sync" ) @@ -73,6 +74,14 @@ func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) { return } +func (mc *_modelCache) getByMd(md interface{}) (*modelInfo, bool) { + val := reflect.ValueOf(md) + ind := reflect.Indirect(val) + typ := ind.Type() + name := getFullName(typ) + return mc.getByFullName(name) +} + // set model info to collection func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { mii := mc.cache[table] diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index e3dafecd..f5242a46 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -2486,3 +2486,5 @@ func TestInsertOrUpdate(t *testing.T) { throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) } } + + diff --git a/pkg/orm/types.go b/pkg/orm/types.go index cb0f97cc..9624fd94 100644 --- a/pkg/orm/types.go +++ b/pkg/orm/types.go @@ -204,17 +204,19 @@ type DriverGetter interface { Driver() Driver } -type Ormer interface { +type ormer interface { DQL DML DriverGetter +} + +type Ormer interface { + ormer TxBeginner } type TxOrmer interface { - DQL - DML - DriverGetter + ormer TxCommitter } From 2fd65a469c25edf2531f58e239c6e92a636862f6 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Fri, 7 Aug 2020 14:14:07 +0000 Subject: [PATCH 077/207] Support prometheus --- pkg/orm/filter/prometheus/filter.go | 84 ++++++++++++++++++++++++ pkg/orm/filter/prometheus/filter_test.go | 60 +++++++++++++++++ 2 files changed, 144 insertions(+) create mode 100644 pkg/orm/filter/prometheus/filter.go create mode 100644 pkg/orm/filter/prometheus/filter_test.go diff --git a/pkg/orm/filter/prometheus/filter.go b/pkg/orm/filter/prometheus/filter.go new file mode 100644 index 00000000..9f177deb --- /dev/null +++ b/pkg/orm/filter/prometheus/filter.go @@ -0,0 +1,84 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "context" + "strconv" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" + + beego "github.com/astaxie/beego/pkg" + "github.com/astaxie/beego/pkg/orm" +) + +// FilterChainBuilder is an extension point, +// when we want to support some configuration, +// please use this structure +type FilterChainBuilder struct { + summaryVec prometheus.ObserverVec +} + +func NewFilterChainBuilder() *FilterChainBuilder { + summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "beego", + Subsystem: "orm_operation", + ConstLabels: map[string]string{ + "server": beego.BConfig.ServerName, + "env": beego.BConfig.RunMode, + "appname": beego.BConfig.AppName, + }, + Help: "The statics info for orm operation", + }, []string{"method", "name", "duration", "insideTx", "txName"}) + + prometheus.MustRegister(summaryVec) + return &FilterChainBuilder{ + summaryVec: summaryVec, + } +} + +func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { + return func(ctx context.Context, inv *orm.Invocation) { + startTime := time.Now() + next(ctx, inv) + endTime := time.Now() + dur := (endTime.Sub(startTime)) / time.Millisecond + + // if the TPS is too large, here may be some problem + // thinking about using goroutine pool + go builder.report(ctx, inv, dur) + } +} + +func (builder *FilterChainBuilder) report(ctx context.Context, inv *orm.Invocation, dur time.Duration) { + // start a transaction, we don't record it + if strings.HasPrefix(inv.Method, "Begin") { + return + } + if inv.Method == "Commit" || inv.Method == "Rollback" { + builder.reportTxn(ctx, inv) + return + } + builder.summaryVec.WithLabelValues(inv.Method, inv.GetTableName(), strconv.Itoa(int(dur)), + strconv.FormatBool(inv.InsideTx), inv.TxName) +} + +func (builder *FilterChainBuilder) reportTxn(ctx context.Context, inv *orm.Invocation) { + dur := time.Now().Sub(inv.TxStartTime) / time.Millisecond + builder.summaryVec.WithLabelValues(inv.Method, inv.TxName, strconv.Itoa(int(dur)), + strconv.FormatBool(inv.InsideTx), inv.TxName) +} diff --git a/pkg/orm/filter/prometheus/filter_test.go b/pkg/orm/filter/prometheus/filter_test.go new file mode 100644 index 00000000..a71e8f50 --- /dev/null +++ b/pkg/orm/filter/prometheus/filter_test.go @@ -0,0 +1,60 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/orm" +) + +func TestFilterChainBuilder_FilterChain(t *testing.T) { + builder := NewFilterChainBuilder() + assert.NotNil(t, builder.summaryVec) + + filter := builder.FilterChain(func(ctx context.Context, inv *orm.Invocation) { + inv.Method = "coming" + }) + assert.NotNil(t, filter) + + inv := &orm.Invocation{} + filter(context.Background(), inv) + assert.Equal(t, "coming", inv.Method) + + inv = &orm.Invocation{ + Method: "Hello", + TxStartTime: time.Now(), + } + builder.reportTxn(context.Background(), inv) + + inv = &orm.Invocation{ + Method: "Begin", + } + + ctx := context.Background() + // it will be ignored + builder.report(ctx, inv, time.Second) + + inv.Method = "Commit" + builder.report(ctx, inv, time.Second) + + inv.Method = "Update" + builder.report(ctx, inv, time.Second) + +} From 993ccac2bd41c20ff7a0266c28175c5783dd553f Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Thu, 6 Aug 2020 09:39:12 +0800 Subject: [PATCH 078/207] fix comment router generate issue --- go.mod | 1 + pkg/beego.go | 1 + pkg/config.go | 2 ++ pkg/hooks.go | 10 ++++++++++ pkg/parser.go | 33 ++++++++++++++++++++------------- pkg/router.go | 41 ----------------------------------------- 6 files changed, 34 insertions(+), 54 deletions(-) diff --git a/go.mod b/go.mod index a6c27488..3ad8576a 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require ( github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b // indirect golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 golang.org/x/net v0.0.0-20190620200207-3b0461eec859 // indirect + golang.org/x/tools v0.0.0-20200117065230-39095c1d176c google.golang.org/grpc v1.31.0 // indirect gopkg.in/yaml.v2 v2.2.8 ) diff --git a/pkg/beego.go b/pkg/beego.go index 8ebe0bab..c08ae528 100644 --- a/pkg/beego.go +++ b/pkg/beego.go @@ -97,6 +97,7 @@ func initBeforeHTTPRun() { registerTemplate, registerAdmin, registerGzip, + registerCommentRouter, ) for _, hk := range hooks { diff --git a/pkg/config.go b/pkg/config.go index 2a5dec25..0cfb7a4c 100644 --- a/pkg/config.go +++ b/pkg/config.go @@ -86,6 +86,7 @@ type WebConfig struct { TemplateLeft string TemplateRight string ViewsPath string + CommentRouterPath string EnableXSRF bool XSRFKey string XSRFExpire int @@ -245,6 +246,7 @@ func newBConfig() *Config { TemplateLeft: "{{", TemplateRight: "}}", ViewsPath: "views", + CommentRouterPath: "controllers", EnableXSRF: false, XSRFKey: "beegoxsrf", XSRFExpire: 0, diff --git a/pkg/hooks.go b/pkg/hooks.go index 8c782383..f511e216 100644 --- a/pkg/hooks.go +++ b/pkg/hooks.go @@ -102,3 +102,13 @@ func registerGzip() error { } return nil } + +func registerCommentRouter() error { + if BConfig.RunMode == DEV { + if err := parserPkg(filepath.Join(WorkPath, BConfig.WebConfig.CommentRouterPath)); err != nil { + return err + } + } + + return nil +} \ No newline at end of file diff --git a/pkg/parser.go b/pkg/parser.go index 606be190..d7ab45f0 100644 --- a/pkg/parser.go +++ b/pkg/parser.go @@ -19,8 +19,7 @@ import ( "errors" "fmt" "go/ast" - "go/parser" - "go/token" + "golang.org/x/tools/go/packages" "io/ioutil" "os" "path/filepath" @@ -76,7 +75,7 @@ func init() { pkgLastupdate = make(map[string]int64) } -func parserPkg(pkgRealpath, pkgpath string) error { +func parserPkg(pkgRealpath string) error { rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_") commentFilename, _ = filepath.Rel(AppPath, pkgRealpath) commentFilename = commentPrefix + rep.Replace(commentFilename) + ".go" @@ -85,24 +84,23 @@ func parserPkg(pkgRealpath, pkgpath string) error { return nil } genInfoList = make(map[string][]ControllerComments) - fileSet := token.NewFileSet() - astPkgs, err := parser.ParseDir(fileSet, pkgRealpath, func(info os.FileInfo) bool { - name := info.Name() - return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") - }, parser.ParseComments) + pkgs, err := packages.Load(&packages.Config{ + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedSyntax, + Dir: pkgRealpath, + }, "./...") if err != nil { return err } - for _, pkg := range astPkgs { - for _, fl := range pkg.Files { + for _, pkg := range pkgs { + for _, fl := range pkg.Syntax { for _, d := range fl.Decls { switch specDecl := d.(type) { case *ast.FuncDecl: if specDecl.Recv != nil { exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser if ok { - parserComments(specDecl, fmt.Sprint(exp.X), pkgpath) + parserComments(specDecl, fmt.Sprint(exp.X), pkg.PkgPath) } } } @@ -566,8 +564,17 @@ func getpathTime(pkgRealpath string) (lastupdate int64, err error) { return lastupdate, err } for _, f := range fl { - if lastupdate < f.ModTime().UnixNano() { - lastupdate = f.ModTime().UnixNano() + var t int64 + if f.IsDir() { + t, err = getpathTime(filepath.Join(pkgRealpath, f.Name())) + if err != nil { + return lastupdate, err + } + } else { + t = f.ModTime().UnixNano() + } + if lastupdate < t { + lastupdate = t } } return lastupdate, nil diff --git a/pkg/router.go b/pkg/router.go index b0c23003..8caba94a 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -18,9 +18,7 @@ import ( "errors" "fmt" "net/http" - "os" "path" - "path/filepath" "reflect" "strconv" "strings" @@ -257,45 +255,6 @@ func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerIn // Include only when the Runmode is dev will generate router file in the router/auto.go from the controller // Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) func (p *ControllerRegister) Include(cList ...ControllerInterface) { - if BConfig.RunMode == DEV { - skip := make(map[string]bool, 10) - wgopath := utils.GetGOPATHs() - go111module := os.Getenv(`GO111MODULE`) - for _, c := range cList { - reflectVal := reflect.ValueOf(c) - t := reflect.Indirect(reflectVal).Type() - // for go modules - if go111module == `on` { - pkgpath := filepath.Join(WorkPath, "..", t.PkgPath()) - if utils.FileExists(pkgpath) { - if pkgpath != "" { - if _, ok := skip[pkgpath]; !ok { - skip[pkgpath] = true - parserPkg(pkgpath, t.PkgPath()) - } - } - } - } else { - if len(wgopath) == 0 { - panic("you are in dev mode. So please set gopath") - } - pkgpath := "" - for _, wg := range wgopath { - wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath())) - if utils.FileExists(wg) { - pkgpath = wg - break - } - } - if pkgpath != "" { - if _, ok := skip[pkgpath]; !ok { - skip[pkgpath] = true - parserPkg(pkgpath, t.PkgPath()) - } - } - } - } - } for _, c := range cList { reflectVal := reflect.ValueOf(c) t := reflect.Indirect(reflectVal).Type() From f9a3eae9d5f1ae3504482e6e2c759c5d82a8457a Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 8 Aug 2020 13:17:49 +0000 Subject: [PATCH 079/207] Move init so it will be default implementation of config --- pkg/config/{ini => }/ini.go | 14 ++++++-------- pkg/config/{ini => }/ini_test.go | 8 +++----- 2 files changed, 9 insertions(+), 13 deletions(-) rename pkg/config/{ini => }/ini.go (97%) rename pkg/config/{ini => }/ini_test.go (95%) diff --git a/pkg/config/ini/ini.go b/pkg/config/ini.go similarity index 97% rename from pkg/config/ini/ini.go rename to pkg/config/ini.go index 17408d85..f5921308 100644 --- a/pkg/config/ini/ini.go +++ b/pkg/config/ini.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package ini +package config import ( "bufio" @@ -26,8 +26,6 @@ import ( "strconv" "strings" "sync" - - "github.com/astaxie/beego/pkg/config" ) var ( @@ -47,7 +45,7 @@ type IniConfig struct { } // Parse creates a new Config and parses the file configuration from the named file. -func (ini *IniConfig) Parse(name string) (config.Configer, error) { +func (ini *IniConfig) Parse(name string) (Configer, error) { return ini.parseFile(name) } @@ -197,7 +195,7 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e val = bytes.Trim(val, `"`) } - cfg.data[section][key] = config.ExpandValueEnv(string(val)) + cfg.data[section][key] = ExpandValueEnv(string(val)) if comment.Len() > 0 { cfg.keyComment[section+"."+key] = comment.String() comment.Reset() @@ -210,7 +208,7 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e // ParseData parse ini the data // When include other.conf,other.conf is either absolute directory // or under beego in default temporary directory(/tmp/beego[-username]). -func (ini *IniConfig) ParseData(data []byte) (config.Configer, error) { +func (ini *IniConfig) ParseData(data []byte) (Configer, error) { dir := "beego" currentUser, err := user.Current() if err == nil { @@ -235,7 +233,7 @@ type IniConfigContainer struct { // Bool returns the boolean value for a given key. func (c *IniConfigContainer) Bool(key string) (bool, error) { - return config.ParseBool(c.getdata(key)) + return ParseBool(c.getdata(key)) } // DefaultBool returns the boolean value for a given key. @@ -502,5 +500,5 @@ func (c *IniConfigContainer) getdata(key string) string { } func init() { - config.Register("ini", &IniConfig{}) + Register("ini", &IniConfig{}) } diff --git a/pkg/config/ini/ini_test.go b/pkg/config/ini_test.go similarity index 95% rename from pkg/config/ini/ini_test.go rename to pkg/config/ini_test.go index 70f1091d..ffcdb294 100644 --- a/pkg/config/ini/ini_test.go +++ b/pkg/config/ini_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package ini +package config import ( "fmt" @@ -20,8 +20,6 @@ import ( "os" "strings" "testing" - - "github.com/astaxie/beego/pkg/config" ) func TestIni(t *testing.T) { @@ -94,7 +92,7 @@ password = ${GOPATH} } f.Close() defer os.Remove("testini.conf") - iniconf, err := config.NewConfig("ini", "testini.conf") + iniconf, err := NewConfig("ini", "testini.conf") if err != nil { t.Fatal(err) } @@ -167,7 +165,7 @@ httpport=8080 name=mysql ` ) - cfg, err := config.NewConfigData("ini", []byte(inicontext)) + cfg, err := NewConfigData("ini", []byte(inicontext)) if err != nil { t.Fatal(err) } From 2e192e1ed08e59ac60b9fa801ff07ccbb15cb27b Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 8 Aug 2020 13:26:30 +0000 Subject: [PATCH 080/207] Depracated config module and recommend using pkg/config --- config/config.go | 17 +++++++++++++++++ config/env/env.go | 5 +++++ config/fake.go | 33 +++++++++++++++++---------------- config/ini.go | 20 ++++++++++++++++++++ config/json.go | 21 +++++++++++++++++++++ config/xml/xml.go | 20 ++++++++++++++++++++ config/yaml/yaml.go | 21 +++++++++++++++++++++ 7 files changed, 121 insertions(+), 16 deletions(-) diff --git a/config/config.go b/config/config.go index bfd79e85..f46f862b 100644 --- a/config/config.go +++ b/config/config.go @@ -48,22 +48,39 @@ import ( ) // Configer defines how to get and set value from configuration raw data. +// Deprecated: using pkg/config, we will delete this in v2.1.0 type Configer interface { + // Deprecated: using pkg/config, we will delete this in v2.1.0 Set(key, val string) error //support section::key type in given key when using ini type. + // Deprecated: using pkg/config, we will delete this in v2.1.0 String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + // Deprecated: using pkg/config, we will delete this in v2.1.0 Strings(key string) []string //get string slice + // Deprecated: using pkg/config, we will delete this in v2.1.0 Int(key string) (int, error) + // Deprecated: using pkg/config, we will delete this in v2.1.0 Int64(key string) (int64, error) + // Deprecated: using pkg/config, we will delete this in v2.1.0 Bool(key string) (bool, error) + // Deprecated: using pkg/config, we will delete this in v2.1.0 Float(key string) (float64, error) + // Deprecated: using pkg/config, we will delete this in v2.1.0 DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + // Deprecated: using pkg/config, we will delete this in v2.1.0 DefaultStrings(key string, defaultVal []string) []string //get string slice + // Deprecated: using pkg/config, we will delete this in v2.1.0 DefaultInt(key string, defaultVal int) int + // Deprecated: using pkg/config, we will delete this in v2.1.0 DefaultInt64(key string, defaultVal int64) int64 + // Deprecated: using pkg/config, we will delete this in v2.1.0 DefaultBool(key string, defaultVal bool) bool + // Deprecated: using pkg/config, we will delete this in v2.1.0 DefaultFloat(key string, defaultVal float64) float64 + // Deprecated: using pkg/config, we will delete this in v2.1.0 DIY(key string) (interface{}, error) + // Deprecated: using pkg/config, we will delete this in v2.1.0 GetSection(section string) (map[string]string, error) + // Deprecated: using pkg/config, we will delete this in v2.1.0 SaveConfigFile(filename string) error } diff --git a/config/env/env.go b/config/env/env.go index 34f094fe..1a6c2527 100644 --- a/config/env/env.go +++ b/config/env/env.go @@ -36,6 +36,7 @@ func init() { // Get returns a value by key. // If the key does not exist, the default value will be returned. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func Get(key string, defVal string) string { if val := env.Get(key); val != nil { return val.(string) @@ -45,6 +46,7 @@ func Get(key string, defVal string) string { // MustGet returns a value by key. // If the key does not exist, it will return an error. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func MustGet(key string) (string, error) { if val := env.Get(key); val != nil { return val.(string), nil @@ -54,12 +56,14 @@ func MustGet(key string) (string, error) { // Set sets a value in the ENV copy. // This does not affect the child process environment. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func Set(key string, value string) { env.Set(key, value) } // MustSet sets a value in the ENV copy and the child process environment. // It returns an error in case the set operation failed. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func MustSet(key string, value string) error { err := os.Setenv(key, value) if err != nil { @@ -70,6 +74,7 @@ func MustSet(key string, value string) error { } // GetAll returns all keys/values in the current child process environment. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func GetAll() map[string]string { items := env.Items() envs := make(map[string]string, env.Count()) diff --git a/config/fake.go b/config/fake.go index d21ab820..07e56ce2 100644 --- a/config/fake.go +++ b/config/fake.go @@ -27,16 +27,16 @@ type fakeConfigContainer struct { func (c *fakeConfigContainer) getData(key string) string { return c.data[strings.ToLower(key)] } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Set(key, val string) error { c.data[strings.ToLower(key)] = val return nil } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) String(key string) string { return c.getData(key) } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { v := c.String(key) if v == "" { @@ -44,7 +44,7 @@ func (c *fakeConfigContainer) DefaultString(key string, defaultval string) strin } return v } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Strings(key string) []string { v := c.String(key) if v == "" { @@ -52,7 +52,7 @@ func (c *fakeConfigContainer) Strings(key string) []string { } return strings.Split(v, ";") } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { v := c.Strings(key) if v == nil { @@ -60,11 +60,11 @@ func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) [] } return v } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Int(key string) (int, error) { return strconv.Atoi(c.getData(key)) } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { v, err := c.Int(key) if err != nil { @@ -72,11 +72,11 @@ func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { } return v } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Int64(key string) (int64, error) { return strconv.ParseInt(c.getData(key), 10, 64) } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { v, err := c.Int64(key) if err != nil { @@ -84,11 +84,11 @@ func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { } return v } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Bool(key string) (bool, error) { return ParseBool(c.getData(key)) } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { v, err := c.Bool(key) if err != nil { @@ -96,11 +96,11 @@ func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { } return v } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Float(key string) (float64, error) { return strconv.ParseFloat(c.getData(key), 64) } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { v, err := c.Float(key) if err != nil { @@ -108,18 +108,18 @@ func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float } return v } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { if v, ok := c.data[strings.ToLower(key)]; ok { return v, nil } return nil, errors.New("key not find") } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) GetSection(section string) (map[string]string, error) { return nil, errors.New("not implement in the fakeConfigContainer") } - +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) SaveConfigFile(filename string) error { return errors.New("not implement in the fakeConfigContainer") } @@ -127,6 +127,7 @@ func (c *fakeConfigContainer) SaveConfigFile(filename string) error { var _ Configer = new(fakeConfigContainer) // NewFakeConfig return a fake Configer +// Deprecated: using pkg/config, we will delete this in v2.1.0 func NewFakeConfig() Configer { return &fakeConfigContainer{ data: make(map[string]string), diff --git a/config/ini.go b/config/ini.go index 002e5e05..1da293dc 100644 --- a/config/ini.go +++ b/config/ini.go @@ -41,10 +41,12 @@ var ( ) // IniConfig implements Config to parse ini file. +// Deprecated: using pkg/config, we will delete this in v2.1.0 type IniConfig struct { } // Parse creates a new Config and parses the file configuration from the named file. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (ini *IniConfig) Parse(name string) (Configer, error) { return ini.parseFile(name) } @@ -208,6 +210,7 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e // ParseData parse ini the data // When include other.conf,other.conf is either absolute directory // or under beego in default temporary directory(/tmp/beego[-username]). +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (ini *IniConfig) ParseData(data []byte) (Configer, error) { dir := "beego" currentUser, err := user.Current() @@ -224,6 +227,7 @@ func (ini *IniConfig) ParseData(data []byte) (Configer, error) { // IniConfigContainer A Config represents the ini configuration. // When set and get value, support key as section:name type. +// Deprecated: using pkg/config, we will delete this in v2.1.0 type IniConfigContainer struct { data map[string]map[string]string // section=> key:val sectionComment map[string]string // section : comment @@ -232,12 +236,14 @@ type IniConfigContainer struct { } // Bool returns the boolean value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) Bool(key string) (bool, error) { return ParseBool(c.getdata(key)) } // DefaultBool returns the boolean value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { v, err := c.Bool(key) if err != nil { @@ -247,12 +253,14 @@ func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { } // Int returns the integer value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) Int(key string) (int, error) { return strconv.Atoi(c.getdata(key)) } // DefaultInt returns the integer value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { v, err := c.Int(key) if err != nil { @@ -262,12 +270,14 @@ func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { } // Int64 returns the int64 value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) Int64(key string) (int64, error) { return strconv.ParseInt(c.getdata(key), 10, 64) } // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { v, err := c.Int64(key) if err != nil { @@ -277,12 +287,14 @@ func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { } // Float returns the float value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) Float(key string) (float64, error) { return strconv.ParseFloat(c.getdata(key), 64) } // DefaultFloat returns the float64 value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { v, err := c.Float(key) if err != nil { @@ -292,12 +304,14 @@ func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float6 } // String returns the string value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) String(key string) string { return c.getdata(key) } // DefaultString returns the string value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { v := c.String(key) if v == "" { @@ -308,6 +322,7 @@ func (c *IniConfigContainer) DefaultString(key string, defaultval string) string // Strings returns the []string value for a given key. // Return nil if config value does not exist or is empty. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) Strings(key string) []string { v := c.String(key) if v == "" { @@ -318,6 +333,7 @@ func (c *IniConfigContainer) Strings(key string) []string { // DefaultStrings returns the []string value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { v := c.Strings(key) if v == nil { @@ -327,6 +343,7 @@ func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []s } // GetSection returns map for the given section +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) { if v, ok := c.data[section]; ok { return v, nil @@ -337,6 +354,7 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro // SaveConfigFile save the config into file. // // BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Function. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) @@ -437,6 +455,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { // Set writes a new value for key. // if write to one section, the key need be "section::key". // if the section is not existed, it panics. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) Set(key, value string) error { c.Lock() defer c.Unlock() @@ -465,6 +484,7 @@ func (c *IniConfigContainer) Set(key, value string) error { } // DIY returns the raw value by a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) { if v, ok := c.data[strings.ToLower(key)]; ok { return v, nil diff --git a/config/json.go b/config/json.go index c4ef25cd..74a50d34 100644 --- a/config/json.go +++ b/config/json.go @@ -26,10 +26,12 @@ import ( ) // JSONConfig is a json config parser and implements Config interface. +// Deprecated: using pkg/config, we will delete this in v2.1.0 type JSONConfig struct { } // Parse returns a ConfigContainer with parsed json config map. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (js *JSONConfig) Parse(filename string) (Configer, error) { file, err := os.Open(filename) if err != nil { @@ -45,6 +47,7 @@ func (js *JSONConfig) Parse(filename string) (Configer, error) { } // ParseData returns a ConfigContainer with json string +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (js *JSONConfig) ParseData(data []byte) (Configer, error) { x := &JSONConfigContainer{ data: make(map[string]interface{}), @@ -66,12 +69,14 @@ func (js *JSONConfig) ParseData(data []byte) (Configer, error) { // JSONConfigContainer A Config represents the json configuration. // Only when get value, support key as section:name type. +// Deprecated: using pkg/config, we will delete this in v2.1.0 type JSONConfigContainer struct { data map[string]interface{} sync.RWMutex } // Bool returns the boolean value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) Bool(key string) (bool, error) { val := c.getData(key) if val != nil { @@ -82,6 +87,7 @@ func (c *JSONConfigContainer) Bool(key string) (bool, error) { // DefaultBool return the bool value if has no error // otherwise return the defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool { if v, err := c.Bool(key); err == nil { return v @@ -90,6 +96,7 @@ func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool { } // Int returns the integer value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) Int(key string) (int, error) { val := c.getData(key) if val != nil { @@ -105,6 +112,7 @@ func (c *JSONConfigContainer) Int(key string) (int, error) { // DefaultInt returns the integer value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int { if v, err := c.Int(key); err == nil { return v @@ -113,6 +121,7 @@ func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int { } // Int64 returns the int64 value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) Int64(key string) (int64, error) { val := c.getData(key) if val != nil { @@ -126,6 +135,7 @@ func (c *JSONConfigContainer) Int64(key string) (int64, error) { // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 { if v, err := c.Int64(key); err == nil { return v @@ -134,6 +144,7 @@ func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 { } // Float returns the float value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) Float(key string) (float64, error) { val := c.getData(key) if val != nil { @@ -147,6 +158,7 @@ func (c *JSONConfigContainer) Float(key string) (float64, error) { // DefaultFloat returns the float64 value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 { if v, err := c.Float(key); err == nil { return v @@ -155,6 +167,7 @@ func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float } // String returns the string value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) String(key string) string { val := c.getData(key) if val != nil { @@ -167,6 +180,7 @@ func (c *JSONConfigContainer) String(key string) string { // DefaultString returns the string value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { // TODO FIXME should not use "" to replace non existence if v := c.String(key); v != "" { @@ -176,6 +190,7 @@ func (c *JSONConfigContainer) DefaultString(key string, defaultval string) strin } // Strings returns the []string value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) Strings(key string) []string { stringVal := c.String(key) if stringVal == "" { @@ -186,6 +201,7 @@ func (c *JSONConfigContainer) Strings(key string) []string { // DefaultStrings returns the []string value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { if v := c.Strings(key); v != nil { return v @@ -194,6 +210,7 @@ func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) [] } // GetSection returns map for the given section +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) { if v, ok := c.data[section]; ok { return v.(map[string]string), nil @@ -202,6 +219,7 @@ func (c *JSONConfigContainer) GetSection(section string) (map[string]string, err } // SaveConfigFile save the config into file +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) @@ -218,6 +236,7 @@ func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) { } // Set writes a new value for key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) Set(key, val string) error { c.Lock() defer c.Unlock() @@ -226,6 +245,7 @@ func (c *JSONConfigContainer) Set(key, val string) error { } // DIY returns the raw value by a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) { val := c.getData(key) if val != nil { @@ -235,6 +255,7 @@ func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) { } // section.key or key +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *JSONConfigContainer) getData(key string) interface{} { if len(key) == 0 { return nil diff --git a/config/xml/xml.go b/config/xml/xml.go index 494242d3..1601561f 100644 --- a/config/xml/xml.go +++ b/config/xml/xml.go @@ -46,9 +46,11 @@ import ( // Config is a xml config parser and implements Config interface. // xml configurations should be included in tag. // only support key/value pair as value as each item. +// Deprecated: using pkg/config, we will delete this in v2.1.0 type Config struct{} // Parse returns a ConfigContainer with parsed xml config map. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (xc *Config) Parse(filename string) (config.Configer, error) { context, err := ioutil.ReadFile(filename) if err != nil { @@ -59,6 +61,7 @@ func (xc *Config) Parse(filename string) (config.Configer, error) { } // ParseData xml data +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (xc *Config) ParseData(data []byte) (config.Configer, error) { x := &ConfigContainer{data: make(map[string]interface{})} @@ -73,12 +76,14 @@ func (xc *Config) ParseData(data []byte) (config.Configer, error) { } // ConfigContainer A Config represents the xml configuration. +// Deprecated: using pkg/config, we will delete this in v2.1.0 type ConfigContainer struct { data map[string]interface{} sync.Mutex } // Bool returns the boolean value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) Bool(key string) (bool, error) { if v := c.data[key]; v != nil { return config.ParseBool(v) @@ -88,6 +93,7 @@ func (c *ConfigContainer) Bool(key string) (bool, error) { // DefaultBool return the bool value if has no error // otherwise return the defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { v, err := c.Bool(key) if err != nil { @@ -97,12 +103,14 @@ func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { } // Int returns the integer value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) Int(key string) (int, error) { return strconv.Atoi(c.data[key].(string)) } // DefaultInt returns the integer value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { v, err := c.Int(key) if err != nil { @@ -112,12 +120,14 @@ func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { } // Int64 returns the int64 value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) Int64(key string) (int64, error) { return strconv.ParseInt(c.data[key].(string), 10, 64) } // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { v, err := c.Int64(key) if err != nil { @@ -128,12 +138,14 @@ func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { } // Float returns the float value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) Float(key string) (float64, error) { return strconv.ParseFloat(c.data[key].(string), 64) } // DefaultFloat returns the float64 value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { v, err := c.Float(key) if err != nil { @@ -143,6 +155,7 @@ func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { } // String returns the string value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) String(key string) string { if v, ok := c.data[key].(string); ok { return v @@ -152,6 +165,7 @@ func (c *ConfigContainer) String(key string) string { // DefaultString returns the string value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) DefaultString(key string, defaultval string) string { v := c.String(key) if v == "" { @@ -161,6 +175,7 @@ func (c *ConfigContainer) DefaultString(key string, defaultval string) string { } // Strings returns the []string value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) Strings(key string) []string { v := c.String(key) if v == "" { @@ -171,6 +186,7 @@ func (c *ConfigContainer) Strings(key string) []string { // DefaultStrings returns the []string value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { v := c.Strings(key) if v == nil { @@ -180,6 +196,7 @@ func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []stri } // GetSection returns map for the given section +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { if v, ok := c.data[section].(map[string]interface{}); ok { mapstr := make(map[string]string) @@ -192,6 +209,7 @@ func (c *ConfigContainer) GetSection(section string) (map[string]string, error) } // SaveConfigFile save the config into file +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) @@ -208,6 +226,7 @@ func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { } // Set writes a new value for key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) Set(key, val string) error { c.Lock() defer c.Unlock() @@ -216,6 +235,7 @@ func (c *ConfigContainer) Set(key, val string) error { } // DIY returns the raw value by a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { if v, ok := c.data[key]; ok { return v, nil diff --git a/config/yaml/yaml.go b/config/yaml/yaml.go index a5644c7b..725f905b 100644 --- a/config/yaml/yaml.go +++ b/config/yaml/yaml.go @@ -45,9 +45,11 @@ import ( ) // Config is a yaml config parser and implements Config interface. +// Deprecated: using pkg/config, we will delete this in v2.1.0 type Config struct{} // Parse returns a ConfigContainer with parsed yaml config map. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (yaml *Config) Parse(filename string) (y config.Configer, err error) { cnf, err := ReadYmlReader(filename) if err != nil { @@ -60,6 +62,7 @@ func (yaml *Config) Parse(filename string) (y config.Configer, err error) { } // ParseData parse yaml data +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (yaml *Config) ParseData(data []byte) (config.Configer, error) { cnf, err := parseYML(data) if err != nil { @@ -73,6 +76,7 @@ func (yaml *Config) ParseData(data []byte) (config.Configer, error) { // ReadYmlReader Read yaml file to map. // if json like, use json package, unless goyaml2 package. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { buf, err := ioutil.ReadFile(path) if err != nil { @@ -117,12 +121,14 @@ func parseYML(buf []byte) (cnf map[string]interface{}, err error) { } // ConfigContainer A Config represents the yaml configuration. +// Deprecated: using pkg/config, we will delete this in v2.1.0 type ConfigContainer struct { data map[string]interface{} sync.RWMutex } // Bool returns the boolean value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) Bool(key string) (bool, error) { v, err := c.getData(key) if err != nil { @@ -133,6 +139,7 @@ func (c *ConfigContainer) Bool(key string) (bool, error) { // DefaultBool return the bool value if has no error // otherwise return the defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { v, err := c.Bool(key) if err != nil { @@ -142,6 +149,7 @@ func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { } // Int returns the integer value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) Int(key string) (int, error) { if v, err := c.getData(key); err != nil { return 0, err @@ -155,6 +163,7 @@ func (c *ConfigContainer) Int(key string) (int, error) { // DefaultInt returns the integer value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { v, err := c.Int(key) if err != nil { @@ -164,6 +173,7 @@ func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { } // Int64 returns the int64 value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) Int64(key string) (int64, error) { if v, err := c.getData(key); err != nil { return 0, err @@ -175,6 +185,7 @@ func (c *ConfigContainer) Int64(key string) (int64, error) { // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { v, err := c.Int64(key) if err != nil { @@ -184,6 +195,7 @@ func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { } // Float returns the float value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) Float(key string) (float64, error) { if v, err := c.getData(key); err != nil { return 0.0, err @@ -199,6 +211,7 @@ func (c *ConfigContainer) Float(key string) (float64, error) { // DefaultFloat returns the float64 value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { v, err := c.Float(key) if err != nil { @@ -208,6 +221,7 @@ func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { } // String returns the string value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) String(key string) string { if v, err := c.getData(key); err == nil { if vv, ok := v.(string); ok { @@ -219,6 +233,7 @@ func (c *ConfigContainer) String(key string) string { // DefaultString returns the string value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) DefaultString(key string, defaultval string) string { v := c.String(key) if v == "" { @@ -228,6 +243,7 @@ func (c *ConfigContainer) DefaultString(key string, defaultval string) string { } // Strings returns the []string value for a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) Strings(key string) []string { v := c.String(key) if v == "" { @@ -238,6 +254,7 @@ func (c *ConfigContainer) Strings(key string) []string { // DefaultStrings returns the []string value for a given key. // if err != nil return defaultval +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { v := c.Strings(key) if v == nil { @@ -247,6 +264,7 @@ func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []stri } // GetSection returns map for the given section +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { if v, ok := c.data[section]; ok { @@ -256,6 +274,7 @@ func (c *ConfigContainer) GetSection(section string) (map[string]string, error) } // SaveConfigFile save the config into file +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) @@ -268,6 +287,7 @@ func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { } // Set writes a new value for key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) Set(key, val string) error { c.Lock() defer c.Unlock() @@ -276,6 +296,7 @@ func (c *ConfigContainer) Set(key, val string) error { } // DIY returns the raw value by a given key. +// Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { return c.getData(key) } From dec98f004c2afba37b4ee0837e7d57c1639db6b6 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 9 Aug 2020 12:05:10 +0000 Subject: [PATCH 081/207] Support opentracing filter for Orm --- pkg/orm/filter.go | 4 ++ pkg/orm/filter/opentracing/filter.go | 59 +++++++++++++++++++++++ pkg/orm/filter/opentracing/filter_test.go | 43 +++++++++++++++++ pkg/orm/filter/prometheus/filter.go | 4 ++ pkg/web/filter/opentracing/filter.go | 31 ++++++------ 5 files changed, 127 insertions(+), 14 deletions(-) create mode 100644 pkg/orm/filter/opentracing/filter.go create mode 100644 pkg/orm/filter/opentracing/filter_test.go diff --git a/pkg/orm/filter.go b/pkg/orm/filter.go index 9676e4af..d04b8c42 100644 --- a/pkg/orm/filter.go +++ b/pkg/orm/filter.go @@ -18,8 +18,12 @@ import ( "context" ) +// FilterChain is used to build a Filter +// don't forget to call next(...) inside your Filter type FilterChain func(next Filter) Filter +// Filter's behavior is a little big strang. +// it's only be called when users call methods of Ormer type Filter func(ctx context.Context, inv *Invocation) var globalFilterChains = make([]FilterChain, 0, 4) diff --git a/pkg/orm/filter/opentracing/filter.go b/pkg/orm/filter/opentracing/filter.go new file mode 100644 index 00000000..a55ae6d2 --- /dev/null +++ b/pkg/orm/filter/opentracing/filter.go @@ -0,0 +1,59 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentracing + +import ( + "context" + + "github.com/opentracing/opentracing-go" + + "github.com/astaxie/beego/pkg/orm" +) + +// FilterChainBuilder provides an extension point +// this Filter's behavior looks a little bit strange +// for example: +// if we want to trace QuerySetter +// actually we trace invoking "QueryTable" and "QueryTableWithCtx" +type FilterChainBuilder struct { + // CustomSpanFunc users are able to custom their span + CustomSpanFunc func(span opentracing.Span, ctx context.Context, inv *orm.Invocation) +} + +func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { + return func(ctx context.Context, inv *orm.Invocation) { + operationName := builder.operationName(ctx, inv) + span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName) + defer span.Finish() + + next(spanCtx, inv) + span.SetTag("Method", inv.Method) + span.SetTag("Table", inv.GetTableName()) + span.SetTag("InsideTx", inv.InsideTx) + span.SetTag("TxName", spanCtx.Value(orm.TxNameKey)) + + if builder.CustomSpanFunc != nil { + builder.CustomSpanFunc(span, spanCtx, inv) + } + + } +} + +func (builder *FilterChainBuilder) operationName(ctx context.Context, inv *orm.Invocation) string { + if n, ok := ctx.Value(orm.TxNameKey).(string); ok { + return inv.Method + "#" + n + } + return inv.Method + "#" + inv.GetTableName() +} diff --git a/pkg/orm/filter/opentracing/filter_test.go b/pkg/orm/filter/opentracing/filter_test.go new file mode 100644 index 00000000..1428df8a --- /dev/null +++ b/pkg/orm/filter/opentracing/filter_test.go @@ -0,0 +1,43 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentracing + +import ( + "context" + "testing" + "time" + + "github.com/opentracing/opentracing-go" + + "github.com/astaxie/beego/pkg/orm" +) + +func TestFilterChainBuilder_FilterChain(t *testing.T) { + next := func(ctx context.Context, inv *orm.Invocation) { + inv.TxName = "Hello" + } + + builder := &FilterChainBuilder{ + CustomSpanFunc: func(span opentracing.Span, ctx context.Context, inv *orm.Invocation) { + span.SetTag("hello", "hell") + }, + } + + inv := &orm.Invocation{ + Method: "Hello", + TxStartTime: time.Now(), + } + builder.FilterChain(next)(context.Background(), inv) +} \ No newline at end of file diff --git a/pkg/orm/filter/prometheus/filter.go b/pkg/orm/filter/prometheus/filter.go index 9f177deb..33fdf78f 100644 --- a/pkg/orm/filter/prometheus/filter.go +++ b/pkg/orm/filter/prometheus/filter.go @@ -29,6 +29,10 @@ import ( // FilterChainBuilder is an extension point, // when we want to support some configuration, // please use this structure +// this Filter's behavior looks a little bit strange +// for example: +// if we want to records the metrics of QuerySetter +// actually we only records metrics of invoking "QueryTable" and "QueryTableWithCtx" type FilterChainBuilder struct { summaryVec prometheus.ObserverVec } diff --git a/pkg/web/filter/opentracing/filter.go b/pkg/web/filter/opentracing/filter.go index 8e332c7d..82d0f719 100644 --- a/pkg/web/filter/opentracing/filter.go +++ b/pkg/web/filter/opentracing/filter.go @@ -30,22 +30,14 @@ type FilterChainBuilder struct { func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.FilterFunc { return func(ctx *context.Context) { - span := opentracing.SpanFromContext(ctx.Request.Context()) - spanCtx := ctx.Request.Context() - if span == nil { - operationName := ctx.Input.URL() - // it means that there is not any span, so we create a span as the root span. - // TODO, if we support multiple servers, this need to be changed - route, found := beego.BeeApp.Handlers.FindRouter(ctx) - if found { - operationName = route.GetPattern() - } - span, spanCtx = opentracing.StartSpanFromContext(spanCtx, operationName) - newReq := ctx.Request.Clone(spanCtx) - ctx.Reset(ctx.ResponseWriter.ResponseWriter, newReq) - } + operationName := builder.operationName(ctx) + span, spanCtx := opentracing.StartSpanFromContext(ctx.Request.Context(), operationName) defer span.Finish() + + newReq := ctx.Request.Clone(spanCtx) + ctx.Reset(ctx.ResponseWriter.ResponseWriter, newReq) + next(ctx) // if you think we need to do more things, feel free to create an issue to tell us span.SetTag("status", ctx.Output.Status) @@ -56,3 +48,14 @@ func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.Filt } } } + +func (builder *FilterChainBuilder) operationName(ctx *context.Context) string { + operationName := ctx.Input.URL() + // it means that there is not any span, so we create a span as the root span. + // TODO, if we support multiple servers, this need to be changed + route, found := beego.BeeApp.Handlers.FindRouter(ctx) + if found { + operationName = route.GetPattern() + } + return operationName +} From 2e891152dd21792d012bb88abc98e69784147498 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 9 Aug 2020 13:41:39 +0000 Subject: [PATCH 082/207] deprecated httplib and then support prometheus for httplib --- httplib/httplib.go | 43 ++++++++++++ pkg/httplib/filter.go | 24 +++++++ pkg/httplib/filter/prometheus/filter.go | 73 ++++++++++++++++++++ pkg/httplib/filter/prometheus/filter_test.go | 41 +++++++++++ pkg/httplib/httplib.go | 36 +++++++++- 5 files changed, 216 insertions(+), 1 deletion(-) create mode 100644 pkg/httplib/filter.go create mode 100644 pkg/httplib/filter/prometheus/filter.go create mode 100644 pkg/httplib/filter/prometheus/filter_test.go diff --git a/httplib/httplib.go b/httplib/httplib.go index 60aa4e8b..8ae95641 100644 --- a/httplib/httplib.go +++ b/httplib/httplib.go @@ -74,6 +74,7 @@ func createDefaultCookie() { } // SetDefaultSetting Overwrite default settings +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func SetDefaultSetting(setting BeegoHTTPSettings) { settingMutex.Lock() defer settingMutex.Unlock() @@ -81,6 +82,7 @@ func SetDefaultSetting(setting BeegoHTTPSettings) { } // NewBeegoRequest return *BeegoHttpRequest with specific method +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { var resp http.Response u, err := url.Parse(rawurl) @@ -106,31 +108,37 @@ func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { } // Get returns *BeegoHttpRequest with GET method. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func Get(url string) *BeegoHTTPRequest { return NewBeegoRequest(url, "GET") } // Post returns *BeegoHttpRequest with POST method. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func Post(url string) *BeegoHTTPRequest { return NewBeegoRequest(url, "POST") } // Put returns *BeegoHttpRequest with PUT method. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func Put(url string) *BeegoHTTPRequest { return NewBeegoRequest(url, "PUT") } // Delete returns *BeegoHttpRequest DELETE method. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func Delete(url string) *BeegoHTTPRequest { return NewBeegoRequest(url, "DELETE") } // Head returns *BeegoHttpRequest with HEAD method. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func Head(url string) *BeegoHTTPRequest { return NewBeegoRequest(url, "HEAD") } // BeegoHTTPSettings is the http.Client setting +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 type BeegoHTTPSettings struct { ShowDebug bool UserAgent string @@ -148,6 +156,7 @@ type BeegoHTTPSettings struct { } // BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 type BeegoHTTPRequest struct { url string req *http.Request @@ -160,35 +169,41 @@ type BeegoHTTPRequest struct { } // GetRequest return the request object +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) GetRequest() *http.Request { return b.req } // Setting Change request settings +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest { b.setting = setting return b } // SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest { b.req.SetBasicAuth(username, password) return b } // SetEnableCookie sets enable/disable cookiejar +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest { b.setting.EnableCookie = enable return b } // SetUserAgent sets User-Agent header field +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { b.setting.UserAgent = useragent return b } // Debug sets show debug or not when executing request. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { b.setting.ShowDebug = isdebug return b @@ -198,28 +213,33 @@ func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { // default is 0 means no retried. // -1 means retried forever. // others means retried times. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { b.setting.Retries = times return b } +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { b.setting.RetryDelay = delay return b } // DumpBody setting whether need to Dump the Body. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { b.setting.DumpBody = isdump return b } // DumpRequest return the DumpRequest +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) DumpRequest() []byte { return b.dump } // SetTimeout sets connect time out and read-write time out for BeegoRequest. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { b.setting.ConnectTimeout = connectTimeout b.setting.ReadWriteTimeout = readWriteTimeout @@ -227,18 +247,21 @@ func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Dura } // SetTLSClientConfig sets tls connection configurations if visiting https url. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest { b.setting.TLSClientConfig = config return b } // Header add header item string in request. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest { b.req.Header.Set(key, value) return b } // SetHost set the request host +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { b.req.Host = host return b @@ -246,6 +269,7 @@ func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { // SetProtocolVersion Set the protocol version for incoming requests. // Client requests always use HTTP/1.1. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { if len(vers) == 0 { vers = "HTTP/1.1" @@ -262,12 +286,14 @@ func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { } // SetCookie add cookie into request. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest { b.req.Header.Add("Cookie", cookie.String()) return b } // SetTransport set the setting transport +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest { b.setting.Transport = transport return b @@ -280,6 +306,7 @@ func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPR // u, _ := url.ParseRequestURI("http://127.0.0.1:8118") // return u, nil // } +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest { b.setting.Proxy = proxy return b @@ -289,6 +316,7 @@ func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) // // If CheckRedirect is nil, the Client uses its default policy, // which is to stop after 10 consecutive requests. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest { b.setting.CheckRedirect = redirect return b @@ -296,6 +324,7 @@ func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via // Param adds query param in to request. // params build query string as ?key1=value1&key2=value2... +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { if param, ok := b.params[key]; ok { b.params[key] = append(param, value) @@ -306,6 +335,7 @@ func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { } // PostFile add a post file to the request +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest { b.files[formname] = filename return b @@ -313,6 +343,7 @@ func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest // Body adds request raw body. // it supports string and []byte. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { switch t := data.(type) { case string: @@ -328,6 +359,7 @@ func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { } // XMLBody adds request raw body encoding by XML. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { byts, err := xml.Marshal(obj) @@ -342,6 +374,7 @@ func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { } // YAMLBody adds request raw body encoding by YAML. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { byts, err := yaml.Marshal(obj) @@ -356,6 +389,7 @@ func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) } // JSONBody adds request raw body encoding by JSON. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { byts, err := json.Marshal(obj) @@ -438,6 +472,7 @@ func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { } // DoRequest will do the client.Do +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { var paramBody string if len(b.params) > 0 { @@ -531,6 +566,7 @@ func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { // String returns the body string in response. // it calls Response inner. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) String() (string, error) { data, err := b.Bytes() if err != nil { @@ -542,6 +578,7 @@ func (b *BeegoHTTPRequest) String() (string, error) { // Bytes returns the body []byte in response. // it calls Response inner. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { if b.body != nil { return b.body, nil @@ -568,6 +605,7 @@ func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { // ToFile saves the body data in response to one file. // it calls Response inner. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) ToFile(filename string) error { resp, err := b.getResponse() if err != nil { @@ -608,6 +646,7 @@ func pathExistAndMkdir(filename string) (err error) { // ToJSON returns the map that marshals from the body bytes as json in response . // it calls Response inner. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { data, err := b.Bytes() if err != nil { @@ -618,6 +657,7 @@ func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { // ToXML returns the map that marshals from the body bytes as xml in response . // it calls Response inner. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) ToXML(v interface{}) error { data, err := b.Bytes() if err != nil { @@ -628,6 +668,7 @@ func (b *BeegoHTTPRequest) ToXML(v interface{}) error { // ToYAML returns the map that marshals from the body bytes as yaml in response . // it calls Response inner. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { data, err := b.Bytes() if err != nil { @@ -637,11 +678,13 @@ func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { } // Response executes request client gets response mannually. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func (b *BeegoHTTPRequest) Response() (*http.Response, error) { return b.getResponse() } // TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. +// Deprecated: using pkg/httplib, we will delete this in v2.1.0 func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { return func(netw, addr string) (net.Conn, error) { conn, err := net.DialTimeout(netw, addr, cTimeout) diff --git a/pkg/httplib/filter.go b/pkg/httplib/filter.go new file mode 100644 index 00000000..72a497d0 --- /dev/null +++ b/pkg/httplib/filter.go @@ -0,0 +1,24 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "context" + "net/http" +) + +type FilterChain func(next Filter) Filter + +type Filter func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) diff --git a/pkg/httplib/filter/prometheus/filter.go b/pkg/httplib/filter/prometheus/filter.go new file mode 100644 index 00000000..a0b24d67 --- /dev/null +++ b/pkg/httplib/filter/prometheus/filter.go @@ -0,0 +1,73 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "context" + "net/http" + "strconv" + "time" + + "github.com/prometheus/client_golang/prometheus" + + beego "github.com/astaxie/beego/pkg" + "github.com/astaxie/beego/pkg/httplib" +) + +type FilterChainBuilder struct { + summaryVec prometheus.ObserverVec +} + +func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter { + + builder.summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "beego", + Subsystem: "remote_http_request", + ConstLabels: map[string]string{ + "server": beego.BConfig.ServerName, + "env": beego.BConfig.RunMode, + "appname": beego.BConfig.AppName, + }, + Help: "The statics info for remote http requests", + }, []string{"proto", "scheme", "method", "host", "path", "status", "duration", "isError"}) + + return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + startTime := time.Now() + resp, err := next(ctx, req) + endTime := time.Now() + go builder.report(startTime, endTime, ctx, req, resp, err) + return resp, err + } +} + +func (builder *FilterChainBuilder) report(startTime time.Time, endTime time.Time, + ctx context.Context, req *httplib.BeegoHTTPRequest, resp *http.Response, err error) { + + proto := req.GetRequest().Proto + + scheme := req.GetRequest().URL.Scheme + method := req.GetRequest().Method + + host := req.GetRequest().URL.Host + path := req.GetRequest().URL.Path + + status := resp.StatusCode + + dur := int(endTime.Sub(startTime) / time.Millisecond) + + + builder.summaryVec.WithLabelValues(proto, scheme, method, host, path, + strconv.Itoa(status), strconv.Itoa(dur), strconv.FormatBool(err == nil)) +} diff --git a/pkg/httplib/filter/prometheus/filter_test.go b/pkg/httplib/filter/prometheus/filter_test.go new file mode 100644 index 00000000..e15d82e5 --- /dev/null +++ b/pkg/httplib/filter/prometheus/filter_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/httplib" +) + +func TestFilterChainBuilder_FilterChain(t *testing.T) { + next := func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + time.Sleep(100 * time.Millisecond) + return &http.Response{ + StatusCode: 404, + }, nil + } + builder := &FilterChainBuilder{} + filter := builder.FilterChain(next) + req := httplib.Get("https://github.com/notifications?query=repo%3Aastaxie%2Fbeego") + resp, err := filter(context.Background(), req) + assert.NotNil(t, resp) + assert.Nil(t, err) +} diff --git a/pkg/httplib/httplib.go b/pkg/httplib/httplib.go index 7255b2ca..f8ab80a1 100644 --- a/pkg/httplib/httplib.go +++ b/pkg/httplib/httplib.go @@ -34,6 +34,7 @@ package httplib import ( "bytes" "compress/gzip" + "context" "crypto/tls" "encoding/json" "encoding/xml" @@ -66,6 +67,11 @@ var defaultSetting = BeegoHTTPSettings{ var defaultCookieJar http.CookieJar var settingMutex sync.Mutex +// it will be the last filter and execute request.Do +var doRequestFilter = func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return req.doRequest(ctx) +} + // createDefaultCookie creates a global cookiejar to store cookies. func createDefaultCookie() { settingMutex.Lock() @@ -145,6 +151,7 @@ type BeegoHTTPSettings struct { DumpBody bool Retries int // if set to -1 means will retry forever RetryDelay time.Duration + FilterChains []FilterChain } // BeegoHTTPRequest provides more useful methods than http.Request for requesting a url. @@ -295,6 +302,18 @@ func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via return b } +// SetFilters will use the filter as the invocation filters +func (b *BeegoHTTPRequest) SetFilters(fcs ...FilterChain) *BeegoHTTPRequest { + b.setting.FilterChains = fcs + return b +} + +// AddFilters adds filter +func (b *BeegoHTTPRequest) AddFilters(fcs ...FilterChain) *BeegoHTTPRequest { + b.setting.FilterChains = append(b.setting.FilterChains, fcs...) + return b +} + // Param adds query param in to request. // params build query string as ?key1=value1&key2=value2... func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { @@ -397,7 +416,7 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) { if err != nil { log.Println("Httplib:", err) } - //iocopy + // iocopy _, err = io.Copy(fileWriter, fh) fh.Close() if err != nil { @@ -440,6 +459,21 @@ func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { // DoRequest executes client.Do func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { + return b.DoRequestWithCtx(context.Background()) +} + +func (b *BeegoHTTPRequest) DoRequestWithCtx(ctx context.Context) (resp *http.Response, err error) { + + root := doRequestFilter + if len(b.setting.FilterChains) > 0 { + for i := len(b.setting.FilterChains) - 1; i >= 0; i-- { + root = b.setting.FilterChains[i](root) + } + } + return root(ctx, b) +} + +func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response, err error) { var paramBody string if len(b.params) > 0 { var buf bytes.Buffer From 75107f735ee8e6a15e03657c32c3ada4ce63585c Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 9 Aug 2020 14:59:08 +0000 Subject: [PATCH 083/207] Support opentracing filter --- pkg/httplib/filter/opentracing/filter.go | 77 +++++++++++++++++++ pkg/httplib/filter/opentracing/filter_test.go | 42 ++++++++++ pkg/web/filter/opentracing/filter.go | 25 ++++-- 3 files changed, 139 insertions(+), 5 deletions(-) create mode 100644 pkg/httplib/filter/opentracing/filter.go create mode 100644 pkg/httplib/filter/opentracing/filter_test.go diff --git a/pkg/httplib/filter/opentracing/filter.go b/pkg/httplib/filter/opentracing/filter.go new file mode 100644 index 00000000..5f409c63 --- /dev/null +++ b/pkg/httplib/filter/opentracing/filter.go @@ -0,0 +1,77 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentracing + +import ( + "context" + "net/http" + "strconv" + + logKit "github.com/go-kit/kit/log" + opentracingKit "github.com/go-kit/kit/tracing/opentracing" + "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/log" + + "github.com/astaxie/beego/pkg/httplib" +) + +type FilterChainBuilder struct { + // CustomSpanFunc users are able to custom their span + CustomSpanFunc func(span opentracing.Span, ctx context.Context, + req *httplib.BeegoHTTPRequest, resp *http.Response, err error) +} + +func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter { + + return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + + method := req.GetRequest().Method + host := req.GetRequest().URL.Host + path := req.GetRequest().URL.Path + + proto := req.GetRequest().Proto + + scheme := req.GetRequest().URL.Scheme + + operationName := host + path + "#" + method + span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName) + defer span.Finish() + + inject := opentracingKit.ContextToHTTP(opentracing.GlobalTracer(), logKit.NewNopLogger()) + inject(spanCtx, req.GetRequest()) + resp, err := next(spanCtx, req) + + if resp != nil { + span.SetTag("status", strconv.Itoa(resp.StatusCode)) + } + + span.SetTag("method", method) + span.SetTag("host", host) + span.SetTag("path", path) + span.SetTag("proto", proto) + span.SetTag("scheme", scheme) + + span.LogFields(log.String("url", req.GetRequest().URL.String())) + + if err != nil { + span.LogFields(log.String("error", err.Error())) + } + + if builder.CustomSpanFunc != nil { + builder.CustomSpanFunc(span, ctx, req, resp, err) + } + return resp, err + } +} diff --git a/pkg/httplib/filter/opentracing/filter_test.go b/pkg/httplib/filter/opentracing/filter_test.go new file mode 100644 index 00000000..aa687541 --- /dev/null +++ b/pkg/httplib/filter/opentracing/filter_test.go @@ -0,0 +1,42 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentracing + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/httplib" +) + +func TestFilterChainBuilder_FilterChain(t *testing.T) { + next := func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + time.Sleep(100 * time.Millisecond) + return &http.Response{ + StatusCode: 404, + }, errors.New("hello") + } + builder := &FilterChainBuilder{} + filter := builder.FilterChain(next) + req := httplib.Get("https://github.com/notifications?query=repo%3Aastaxie%2Fbeego") + resp, err := filter(context.Background(), req) + assert.NotNil(t, resp) + assert.NotNil(t, err) +} diff --git a/pkg/web/filter/opentracing/filter.go b/pkg/web/filter/opentracing/filter.go index 82d0f719..822d5e4d 100644 --- a/pkg/web/filter/opentracing/filter.go +++ b/pkg/web/filter/opentracing/filter.go @@ -15,24 +15,39 @@ package opentracing import ( + "context" + + logKit "github.com/go-kit/kit/log" + opentracingKit "github.com/go-kit/kit/tracing/opentracing" "github.com/opentracing/opentracing-go" beego "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/context" + beegoCtx "github.com/astaxie/beego/pkg/context" ) // FilterChainBuilder provides an extension point that we can support more configurations if necessary type FilterChainBuilder struct { // CustomSpanFunc makes users to custom the span. - CustomSpanFunc func(span opentracing.Span, ctx *context.Context) + CustomSpanFunc func(span opentracing.Span, ctx *beegoCtx.Context) } func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.FilterFunc { - return func(ctx *context.Context) { + return func(ctx *beegoCtx.Context) { + var ( + spanCtx context.Context + span opentracing.Span + ) operationName := builder.operationName(ctx) - span, spanCtx := opentracing.StartSpanFromContext(ctx.Request.Context(), operationName) + if preSpan := opentracing.SpanFromContext(ctx.Request.Context()); preSpan == nil { + inject := opentracingKit.HTTPToContext(opentracing.GlobalTracer(), operationName, logKit.NewNopLogger()) + spanCtx = inject(ctx.Request.Context(), ctx.Request) + span = opentracing.SpanFromContext(spanCtx) + } else { + span, spanCtx = opentracing.StartSpanFromContext(ctx.Request.Context(), operationName) + } + defer span.Finish() newReq := ctx.Request.Clone(spanCtx) @@ -49,7 +64,7 @@ func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.Filt } } -func (builder *FilterChainBuilder) operationName(ctx *context.Context) string { +func (builder *FilterChainBuilder) operationName(ctx *beegoCtx.Context) string { operationName := ctx.Input.URL() // it means that there is not any span, so we create a span as the root span. // TODO, if we support multiple servers, this need to be changed From 5a1fa4e1ec36e874b48e29afa9ab03eee39c2d80 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 10 Aug 2020 18:46:16 +0800 Subject: [PATCH 084/207] specify index --- pkg/common/kv.go | 38 +++++-- pkg/orm/db.go | 64 +++++++++-- pkg/orm/db_alias.go | 13 +-- pkg/orm/db_alias_test.go | 15 +-- pkg/orm/db_hints_test.go | 76 ------------- pkg/orm/db_oracle.go | 24 +++++ pkg/orm/db_postgres.go | 7 ++ pkg/orm/db_sqlite.go | 21 ++++ pkg/orm/db_tables.go | 9 ++ pkg/orm/do_nothing_orm.go | 5 +- pkg/orm/filter_orm_decorator.go | 5 +- pkg/orm/filter_orm_decorator_test.go | 3 +- pkg/orm/{ => hints}/db_hints.go | 80 +++++++++++--- pkg/orm/hints/db_hints_test.go | 154 +++++++++++++++++++++++++++ pkg/orm/models_test.go | 3 +- pkg/orm/orm.go | 44 ++++---- pkg/orm/orm_log.go | 25 +---- pkg/orm/orm_queryset.go | 28 ++++- pkg/orm/orm_test.go | 27 +++-- pkg/orm/types.go | 47 +++++--- 20 files changed, 499 insertions(+), 189 deletions(-) delete mode 100644 pkg/orm/db_hints_test.go rename pkg/orm/{ => hints}/db_hints.go (50%) create mode 100644 pkg/orm/hints/db_hints_test.go diff --git a/pkg/common/kv.go b/pkg/common/kv.go index 8468f4fe..26e786f9 100644 --- a/pkg/common/kv.go +++ b/pkg/common/kv.go @@ -36,14 +36,25 @@ func (s *SimpleKV) GetValue() interface{} { return s.Value } -// KVs will store SimpleKV collection as map -type KVs struct { +// KVs interface +type KVs interface { + GetValueOr(key interface{}, defValue interface{}) interface{} + Contains(key interface{}) bool + IfContains(key interface{}, action func(value interface{})) KVs + Put(key interface{}, value interface{}) KVs + Clone() KVs +} + +// SimpleKVs will store SimpleKV collection as map +type SimpleKVs struct { kvs map[interface{}]interface{} } +var _ KVs = new(SimpleKVs) + // GetValueOr returns the value for a given key, if non-existant // it returns defValue -func (kvs *KVs) GetValueOr(key interface{}, defValue interface{}) interface{} { +func (kvs *SimpleKVs) GetValueOr(key interface{}, defValue interface{}) interface{} { v, ok := kvs.kvs[key] if ok { return v @@ -52,13 +63,13 @@ func (kvs *KVs) GetValueOr(key interface{}, defValue interface{}) interface{} { } // Contains checks if a key exists -func (kvs *KVs) Contains(key interface{}) bool { +func (kvs *SimpleKVs) Contains(key interface{}) bool { _, ok := kvs.kvs[key] return ok } // IfContains invokes the action on a key if it exists -func (kvs *KVs) IfContains(key interface{}, action func(value interface{})) *KVs { +func (kvs *SimpleKVs) IfContains(key interface{}, action func(value interface{})) KVs { v, ok := kvs.kvs[key] if ok { action(v) @@ -67,14 +78,25 @@ func (kvs *KVs) IfContains(key interface{}, action func(value interface{})) *KVs } // Put stores the value -func (kvs *KVs) Put(key interface{}, value interface{}) *KVs { +func (kvs *SimpleKVs) Put(key interface{}, value interface{}) KVs { kvs.kvs[key] = value return kvs } +// Clone +func (kvs *SimpleKVs) Clone() KVs { + newKVs := new(SimpleKVs) + + for key, value := range kvs.kvs { + newKVs.Put(key, value) + } + + return newKVs +} + // NewKVs creates the *KVs instance -func NewKVs(kvs ...KV) *KVs { - res := &KVs{ +func NewKVs(kvs ...KV) KVs { + res := &SimpleKVs{ kvs: make(map[interface{}]interface{}, len(kvs)), } for _, kv := range kvs { diff --git a/pkg/orm/db.go b/pkg/orm/db.go index 9a1827e8..573247f0 100644 --- a/pkg/orm/db.go +++ b/pkg/orm/db.go @@ -18,6 +18,7 @@ import ( "database/sql" "errors" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" "reflect" "strings" "time" @@ -738,8 +739,10 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } tables := newDbTables(mi, d.ins) + var specifyIndexes string if qs != nil { tables.parseRelated(qs.related, qs.relDepth) + specifyIndexes = tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) } where, args := tables.getCondSQL(cond, false, tz) @@ -790,9 +793,12 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con sets := strings.Join(cols, ", ") + " " if d.ins.SupportUpdateJoin() { - query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET %s%s", Q, mi.table, Q, join, sets, where) + query = fmt.Sprintf("UPDATE %s%s%s T0 %s%sSET %s%s", Q, mi.table, Q, specifyIndexes, join, sets, where) } else { - supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where) + supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s%s", + Q, mi.fields.pk.column, Q, + Q, mi.table, Q, + specifyIndexes, join, where) query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery) } @@ -843,8 +849,10 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con tables := newDbTables(mi, d.ins) tables.skipEnd = true + var specifyIndexes string if qs != nil { tables.parseRelated(qs.related, qs.relDepth) + specifyIndexes = tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) } if cond == nil || cond.IsEmpty() { @@ -857,7 +865,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con join := tables.getJoinSQL() cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) - query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) + query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s", cols, Q, mi.table, Q, specifyIndexes, join, where) d.ins.ReplaceMarks(&query) @@ -1002,6 +1010,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi orderBy := tables.getOrderSQL(qs.orders) limit := tables.getLimitSQL(mi, offset, rlimit) join := tables.getJoinSQL() + specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) for _, tbl := range tables.tables { if tbl.sel { @@ -1015,9 +1024,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi if qs.distinct { sqlSelect += " DISTINCT" } - query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) + query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", + sqlSelect, sels, Q, mi.table, Q, + specifyIndexes, join, where, groupBy, orderBy, limit) - if qs.forupdate { + if qs.forUpdate { query += " FOR UPDATE" } @@ -1153,10 +1164,13 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition groupBy := tables.getGroupSQL(qs.groups) tables.getOrderSQL(qs.orders) join := tables.getJoinSQL() + specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) Q := d.ins.TableQuote() - query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s", Q, mi.table, Q, join, where, groupBy) + query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s%s", + Q, mi.table, Q, + specifyIndexes, join, where, groupBy) if groupBy != "" { query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query) @@ -1680,6 +1694,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond orderBy := tables.getOrderSQL(qs.orders) limit := tables.getLimitSQL(mi, qs.offset, qs.limit) join := tables.getJoinSQL() + specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) sels := strings.Join(cols, ", ") @@ -1687,7 +1702,10 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond if qs.distinct { sqlSelect += " DISTINCT" } - query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) + query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", + sqlSelect, sels, + Q, mi.table, Q, + specifyIndexes, join, where, groupBy, orderBy, limit) d.ins.ReplaceMarks(&query) @@ -1781,10 +1799,6 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond return cnt, nil } -func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) { - return 0, nil -} - // flag of update joined record. func (d *dbBase) SupportUpdateJoin() bool { return true @@ -1900,3 +1914,31 @@ func (d *dbBase) ShowColumnsQuery(table string) string { func (d *dbBase) IndexExists(dbQuerier, string, string) bool { panic(ErrNotImplement) } + +// GenerateSpecifyIndex return a specifying index clause +func (d *dbBase) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { + var s []string + Q := d.TableQuote() + for _, index := range indexes { + tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q) + s = append(s, tmp) + } + + var useWay string + + switch useIndex { + case hints.KeyUseIndex: + useWay = `USE` + case hints.KeyForceIndex: + useWay = `FORCE` + case hints.KeyIgnoreIndex: + useWay = `IGNORE` + default: + DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") + return `` + } + + return fmt.Sprintf(` %s INDEX(%s) `, useWay, strings.Join(s, `,`)) +} + + diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index 5f1e3ea3..93f282af 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" "sync" "time" @@ -363,7 +364,7 @@ func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.K var stmtCache *lru.Cache var stmtCacheSize int - maxStmtCacheSize := kvs.GetValueOr(maxStmtCacheSizeKey, 0).(int) + maxStmtCacheSize := kvs.GetValueOr(hints.KeyMaxStmtCacheSize, 0).(int) if maxStmtCacheSize > 0 { _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) if errC != nil { @@ -398,15 +399,15 @@ func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.K detectTZ(al) - kvs.IfContains(maxIdleConnectionsKey, func(value interface{}) { + kvs.IfContains(hints.KeyMaxIdleConnections, func(value interface{}) { if m, ok := value.(int); ok { SetMaxIdleConns(al, m) } - }).IfContains(maxOpenConnectionsKey, func(value interface{}) { + }).IfContains(hints.KeyMaxOpenConnections, func(value interface{}) { if m, ok := value.(int); ok { SetMaxOpenConns(al, m) } - }).IfContains(connMaxLifetimeKey, func(value interface{}) { + }).IfContains(hints.KeyConnMaxLifetime, func(value interface{}) { if m, ok := value.(time.Duration); ok { SetConnMaxLifetime(al, m) } @@ -422,7 +423,7 @@ func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -func RegisterDataBase(aliasName, driverName, dataSource string, hints ...common.KV) error { +func RegisterDataBase(aliasName, driverName, dataSource string, params ...common.KV) error { var ( err error db *sql.DB @@ -436,7 +437,7 @@ func RegisterDataBase(aliasName, driverName, dataSource string, hints ...common. goto end } - al, err = addAliasWthDB(aliasName, driverName, db, hints...) + al, err = addAliasWthDB(aliasName, driverName, db, params...) if err != nil { goto end } diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index 111657d7..576214fc 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -15,6 +15,7 @@ package orm import ( + "github.com/astaxie/beego/pkg/orm/hints" "testing" "time" @@ -23,9 +24,9 @@ import ( func TestRegisterDataBase(t *testing.T) { err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, - MaxIdleConnections(20), - MaxOpenConnections(300), - ConnMaxLifetime(time.Minute)) + hints.MaxIdleConnections(20), + hints.MaxOpenConnections(300), + hints.ConnMaxLifetime(time.Minute)) assert.Nil(t, err) al := getDbAlias("test-params") @@ -37,7 +38,7 @@ func TestRegisterDataBase(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(-1)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(-1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -47,7 +48,7 @@ func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(0)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(0)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -57,7 +58,7 @@ func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(1)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -67,7 +68,7 @@ func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(841)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(841)) assert.Nil(t, err) al := getDbAlias(aliasName) diff --git a/pkg/orm/db_hints_test.go b/pkg/orm/db_hints_test.go deleted file mode 100644 index 13f8ccde..00000000 --- a/pkg/orm/db_hints_test.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2020 beego-dev -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "github.com/stretchr/testify/assert" - "testing" - "time" -) - -func TestNewHint_time(t *testing.T) { - key := "qweqwe" - value := time.Second - hint := NewHint(key, value) - - assert.Equal(t, hint.GetKey(), key) - assert.Equal(t, hint.GetValue(), value) -} - -func TestNewHint_int(t *testing.T) { - key := "qweqwe" - value := 281230 - hint := NewHint(key, value) - - assert.Equal(t, hint.GetKey(), key) - assert.Equal(t, hint.GetValue(), value) -} - -func TestNewHint_float(t *testing.T) { - key := "qweqwe" - value := 21.2459753 - hint := NewHint(key, value) - - assert.Equal(t, hint.GetKey(), key) - assert.Equal(t, hint.GetValue(), value) -} - -func TestMaxOpenConnections(t *testing.T) { - i := 887423 - hint := MaxOpenConnections(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), maxOpenConnectionsKey) -} - -func TestConnMaxLifetime(t *testing.T) { - i := time.Hour - hint := ConnMaxLifetime(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), connMaxLifetimeKey) -} - -func TestMaxIdleConnections(t *testing.T) { - i := 42316 - hint := MaxIdleConnections(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), maxIdleConnectionsKey) -} - -func TestMaxStmtCacheSize(t *testing.T) { - i := 94157 - hint := MaxStmtCacheSize(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), maxStmtCacheSizeKey) -} \ No newline at end of file diff --git a/pkg/orm/db_oracle.go b/pkg/orm/db_oracle.go index 5d121f83..fa49e16b 100644 --- a/pkg/orm/db_oracle.go +++ b/pkg/orm/db_oracle.go @@ -16,6 +16,7 @@ package orm import ( "fmt" + "github.com/astaxie/beego/pkg/orm/hints" "strings" ) @@ -96,6 +97,29 @@ func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool return cnt > 0 } +func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { + var s []string + Q := d.TableQuote() + for _, index := range indexes { + tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q) + s = append(s, tmp) + } + + var hint string + + switch useIndex { + case hints.KeyUseIndex, hints.KeyForceIndex: + hint = `INDEX` + case hints.KeyIgnoreIndex: + hint = `NO_INDEX` + default: + DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") + return `` + } + + return fmt.Sprintf(` /*+ %s(%s %s)*/ `, hint, tableName, strings.Join(s, `,`)) +} + // execute insert sql with given struct and given values. // insert the given values, not the field values in struct. func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { diff --git a/pkg/orm/db_postgres.go b/pkg/orm/db_postgres.go index c488fb38..cf1a3413 100644 --- a/pkg/orm/db_postgres.go +++ b/pkg/orm/db_postgres.go @@ -92,6 +92,7 @@ func (d *dbBasePostgres) MaxLimit() uint64 { return 0 } + // postgresql quote is ". func (d *dbBasePostgres) TableQuote() string { return `"` @@ -181,6 +182,12 @@ func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bo return cnt > 0 } +// GenerateSpecifyIndex return a specifying index clause +func (d *dbBasePostgres) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { + DebugLog.Println("[WARN] Not support any specifying index action, so that action is ignored") + return `` +} + // create new postgresql dbBaser. func newdbBasePostgres() dbBaser { b := new(dbBasePostgres) diff --git a/pkg/orm/db_sqlite.go b/pkg/orm/db_sqlite.go index 1d62ee34..244aae7a 100644 --- a/pkg/orm/db_sqlite.go +++ b/pkg/orm/db_sqlite.go @@ -17,7 +17,9 @@ package orm import ( "database/sql" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" "reflect" + "strings" "time" ) @@ -153,6 +155,25 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool return false } +// GenerateSpecifyIndex return a specifying index clause +func (d *dbBaseSqlite) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { + var s []string + Q := d.TableQuote() + for _, index := range indexes { + tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q) + s = append(s, tmp) + } + + switch useIndex { + case hints.KeyUseIndex, hints.KeyForceIndex: + return fmt.Sprintf(` INDEXED BY %s `, strings.Join(s, `,`)) + default: + DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") + return `` + } +} + + // create new sqlite dbBaser. func newdbBaseSqlite() dbBaser { b := new(dbBaseSqlite) diff --git a/pkg/orm/db_tables.go b/pkg/orm/db_tables.go index 4b21a6fc..d7e99639 100644 --- a/pkg/orm/db_tables.go +++ b/pkg/orm/db_tables.go @@ -472,6 +472,15 @@ func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits return } +// getIndexSql generate index sql. +func (t *dbTables) getIndexSql(tableName string,useIndex int, indexes []string) (clause string) { + if len(indexes) == 0 { + return + } + + return t.base.GenerateSpecifyIndex(tableName, useIndex, indexes) +} + // crete new tables collection. func newDbTables(mi *modelInfo, base dbBaser) *dbTables { tables := &dbTables{} diff --git a/pkg/orm/do_nothing_orm.go b/pkg/orm/do_nothing_orm.go index 87b0a2ae..686b7752 100644 --- a/pkg/orm/do_nothing_orm.go +++ b/pkg/orm/do_nothing_orm.go @@ -17,6 +17,7 @@ package orm import ( "context" "database/sql" + "github.com/astaxie/beego/pkg/common" ) // DoNothingOrm won't do anything, usually you use this to custom your mock Ormer implementation @@ -52,11 +53,11 @@ func (d *DoNothingOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{}, return false, 0, nil } -func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { +func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) { return 0, nil } -func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { +func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { return 0, nil } diff --git a/pkg/orm/filter_orm_decorator.go b/pkg/orm/filter_orm_decorator.go index eb26ea68..2f32d8c6 100644 --- a/pkg/orm/filter_orm_decorator.go +++ b/pkg/orm/filter_orm_decorator.go @@ -17,6 +17,7 @@ package orm import ( "context" "database/sql" + "github.com/astaxie/beego/pkg/common" "reflect" "time" ) @@ -133,11 +134,11 @@ func (f *filterOrmDecorator) ReadOrCreateWithCtx(ctx context.Context, md interfa return ok, res, err } -func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { +func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) { return f.LoadRelatedWithCtx(context.Background(), md, name, args...) } -func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { +func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { var ( res int64 err error diff --git a/pkg/orm/filter_orm_decorator_test.go b/pkg/orm/filter_orm_decorator_test.go index d1099eaf..abb8322c 100644 --- a/pkg/orm/filter_orm_decorator_test.go +++ b/pkg/orm/filter_orm_decorator_test.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "errors" + "github.com/astaxie/beego/pkg/common" "sync" "testing" @@ -360,7 +361,7 @@ func (f *filterMockOrm) ReadForUpdateWithCtx(ctx context.Context, md interface{} return errors.New("read for update error") } -func (f *filterMockOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { +func (f *filterMockOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { return 99, errors.New("load related error") } diff --git a/pkg/orm/db_hints.go b/pkg/orm/hints/db_hints.go similarity index 50% rename from pkg/orm/db_hints.go rename to pkg/orm/hints/db_hints.go index 551c7357..f708f310 100644 --- a/pkg/orm/db_hints.go +++ b/pkg/orm/hints/db_hints.go @@ -12,13 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -package orm +package hints import ( "github.com/astaxie/beego/pkg/common" "time" ) +const ( + //db level + KeyMaxIdleConnections = iota + KeyMaxOpenConnections + KeyConnMaxLifetime + KeyMaxStmtCacheSize + + //query level + KeyForceIndex + KeyUseIndex + KeyIgnoreIndex + KeyForUpdate + KeyLimit + KeyOffset + KeyOrderBy + KeyRelDepth +) + type Hint struct { key interface{} value interface{} @@ -36,33 +54,71 @@ func (s *Hint) GetValue() interface{} { return s.value } -const ( - maxIdleConnectionsKey = "MaxIdleConnections" - maxOpenConnectionsKey = "MaxOpenConnections" - connMaxLifetimeKey = "ConnMaxLifetime" - maxStmtCacheSizeKey = "MaxStmtCacheSize" -) - var _ common.KV = new(Hint) // MaxIdleConnections return a hint about MaxIdleConnections func MaxIdleConnections(v int) *Hint { - return NewHint(maxIdleConnectionsKey, v) + return NewHint(KeyMaxIdleConnections, v) } // MaxOpenConnections return a hint about MaxOpenConnections func MaxOpenConnections(v int) *Hint { - return NewHint(maxOpenConnectionsKey, v) + return NewHint(KeyMaxOpenConnections, v) } // ConnMaxLifetime return a hint about ConnMaxLifetime func ConnMaxLifetime(v time.Duration) *Hint { - return NewHint(connMaxLifetimeKey, v) + return NewHint(KeyConnMaxLifetime, v) } // MaxStmtCacheSize return a hint about MaxStmtCacheSize func MaxStmtCacheSize(v int) *Hint { - return NewHint(maxStmtCacheSizeKey, v) + return NewHint(KeyMaxStmtCacheSize, v) +} + +// ForceIndex return a hint about ForceIndex +func ForceIndex(indexes ...string) *Hint { + return NewHint(KeyForceIndex, indexes) +} + +// UseIndex return a hint about UseIndex +func UseIndex(indexes ...string) *Hint { + return NewHint(KeyUseIndex, indexes) +} + +// IgnoreIndex return a hint about IgnoreIndex +func IgnoreIndex(indexes ...string) *Hint { + return NewHint(KeyIgnoreIndex, indexes) +} + +// ForUpdate return a hint about ForUpdate +func ForUpdate() *Hint { + return NewHint(KeyForUpdate, true) +} + +// DefaultRelDepth return a hint about DefaultRelDepth +func DefaultRelDepth() *Hint { + return NewHint(KeyRelDepth, true) +} + +// RelDepth return a hint about RelDepth +func RelDepth(d int) *Hint { + return NewHint(KeyRelDepth, d) +} + +// Limit return a hint about Limit +func Limit(d int64) *Hint { + return NewHint(KeyLimit, d) +} + +// Offset return a hint about Offset +func Offset(d int64) *Hint { + return NewHint(KeyOffset, d) +} + +// OrderBy return a hint about OrderBy +func OrderBy(s string) *Hint { + return NewHint(KeyOrderBy, s) } // NewHint return a hint diff --git a/pkg/orm/hints/db_hints_test.go b/pkg/orm/hints/db_hints_test.go new file mode 100644 index 00000000..5ab44b08 --- /dev/null +++ b/pkg/orm/hints/db_hints_test.go @@ -0,0 +1,154 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hints + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestNewHint_time(t *testing.T) { + key := "qweqwe" + value := time.Second + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestNewHint_int(t *testing.T) { + key := "qweqwe" + value := 281230 + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestNewHint_float(t *testing.T) { + key := "qweqwe" + value := 21.2459753 + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestMaxOpenConnections(t *testing.T) { + i := 887423 + hint := MaxOpenConnections(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), KeyMaxOpenConnections) +} + +func TestConnMaxLifetime(t *testing.T) { + i := time.Hour + hint := ConnMaxLifetime(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), KeyConnMaxLifetime) +} + +func TestMaxIdleConnections(t *testing.T) { + i := 42316 + hint := MaxIdleConnections(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), KeyMaxIdleConnections) +} + +func TestMaxStmtCacheSize(t *testing.T) { + i := 94157 + hint := MaxStmtCacheSize(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), KeyMaxStmtCacheSize) +} + +func TestForceIndex(t *testing.T) { + s := []string{`f_index1`, `f_index2`, `f_index3`} + hint := ForceIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyForceIndex) +} + +func TestForceIndex_0(t *testing.T) { + var s []string + hint := ForceIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyForceIndex) +} + +func TestIgnoreIndex(t *testing.T) { + s := []string{`i_index1`, `i_index2`, `i_index3`} + hint := IgnoreIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyIgnoreIndex) +} + +func TestIgnoreIndex_0(t *testing.T) { + var s []string + hint := IgnoreIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyIgnoreIndex) +} + +func TestUseIndex(t *testing.T) { + s := []string{`u_index1`, `u_index2`, `u_index3`} + hint := UseIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyUseIndex) +} + +func TestUseIndex_0(t *testing.T) { + var s []string + hint := UseIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyUseIndex) +} + +func TestForUpdate(t *testing.T) { + hint := ForUpdate() + assert.Equal(t, hint.GetValue(), true) + assert.Equal(t, hint.GetKey(), KeyForUpdate) +} + +func TestDefaultRelDepth(t *testing.T) { + hint := DefaultRelDepth() + assert.Equal(t, hint.GetValue(), true) + assert.Equal(t, hint.GetKey(), KeyRelDepth) +} + +func TestRelDepth(t *testing.T) { + hint := RelDepth(157965) + assert.Equal(t, hint.GetValue(), 157965) + assert.Equal(t, hint.GetKey(), KeyRelDepth) +} + +func TestLimit(t *testing.T) { + hint := Limit(1579625) + assert.Equal(t, hint.GetValue(), int64(1579625)) + assert.Equal(t, hint.GetKey(), KeyLimit) +} + +func TestOffset(t *testing.T) { + hint := Offset(int64(1572123965)) + assert.Equal(t, hint.GetValue(), int64(1572123965)) + assert.Equal(t, hint.GetKey(), KeyOffset) +} + +func TestOrderBy(t *testing.T) { + hint := OrderBy(`-ID`) + assert.Equal(t, hint.GetValue(), `-ID`) + assert.Equal(t, hint.GetKey(), KeyOrderBy) +} \ No newline at end of file diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index ae166dc7..935c2073 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -18,6 +18,7 @@ import ( "database/sql" "encoding/json" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" "os" "strings" "time" @@ -488,7 +489,7 @@ func init() { os.Exit(2) } - err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, MaxIdleConnections(20)) + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, hints.MaxIdleConnections(20)) if err != nil { panic(fmt.Sprintf("can not register database: %v", err)) diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index d79053af..fb63d4e5 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -59,6 +59,7 @@ import ( "errors" "fmt" "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/orm/hints" "os" "reflect" "time" @@ -99,6 +100,7 @@ type ormBase struct { var _ DQL = new(ormBase) var _ DML = new(ormBase) +var _ DriverGetter = new(ormBase) // get model info and model reflect value func (o *ormBase) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { @@ -302,11 +304,10 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri // for _,tag := range post.Tags{...} // // make sure the relation is defined in model struct tags. -func (o *ormBase) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { +func (o *ormBase) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) { return o.LoadRelatedWithCtx(context.Background(), md, name, args...) } - -func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { +func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { _, fi, ind, qseter := o.queryRelated(md, name) qs := qseter.(*querySet) @@ -314,24 +315,29 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s var relDepth int var limit, offset int64 var order string - for i, arg := range args { - switch i { - case 0: - if v, ok := arg.(bool); ok { - if v { - relDepth = DefaultRelsDepth - } - } else if v, ok := arg.(int); ok { - relDepth = v + + kvs := common.NewKVs(args...) + kvs.IfContains(hints.KeyRelDepth, func(value interface{}) { + if v, ok := value.(bool); ok { + if v { + relDepth = DefaultRelsDepth } - case 1: - limit = ToInt64(arg) - case 2: - offset = ToInt64(arg) - case 3: - order, _ = arg.(string) + } else if v, ok := value.(int); ok { + relDepth = v } - } + }).IfContains(hints.KeyLimit, func(value interface{}) { + if v, ok := value.(int64); ok { + limit = v + } + }).IfContains(hints.KeyOffset, func(value interface{}) { + if v, ok := value.(int64); ok { + offset = v + } + }).IfContains(hints.KeyOrderBy, func(value interface{}) { + if v, ok := value.(string); ok { + order = v + } + }) switch fi.fieldType { case RelOneToOne, RelForeignKey, RelReverseOne: diff --git a/pkg/orm/orm_log.go b/pkg/orm/orm_log.go index 5bb3a24f..d8df7e36 100644 --- a/pkg/orm/orm_log.go +++ b/pkg/orm/orm_log.go @@ -127,10 +127,7 @@ var _ txer = new(dbQueryLog) var _ txEnder = new(dbQueryLog) func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) { - a := time.Now() - stmt, err := d.db.Prepare(query) - debugLogQueies(d.alias, "db.Prepare", query, a, err) - return stmt, err + return d.PrepareContext(context.Background(), query) } func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { @@ -141,10 +138,7 @@ func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stm } func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { - a := time.Now() - res, err := d.db.Exec(query, args...) - debugLogQueies(d.alias, "db.Exec", query, a, err, args...) - return res, err + return d.ExecContext(context.Background(), query, args...) } func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { @@ -155,10 +149,7 @@ func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...inte } func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { - a := time.Now() - res, err := d.db.Query(query, args...) - debugLogQueies(d.alias, "db.Query", query, a, err, args...) - return res, err + return d.QueryContext(context.Background(), query, args...) } func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { @@ -169,10 +160,7 @@ func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...int } func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { - a := time.Now() - res := d.db.QueryRow(query, args...) - debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) - return res + return d.QueryRowContext(context.Background(), query, args...) } func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { @@ -183,10 +171,7 @@ func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ... } func (d *dbQueryLog) Begin() (*sql.Tx, error) { - a := time.Now() - tx, err := d.db.(txer).Begin() - debugLogQueies(d.alias, "db.Begin", "START TRANSACTION", a, err) - return tx, err + return d.BeginTx(context.Background(), nil) } func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { diff --git a/pkg/orm/orm_queryset.go b/pkg/orm/orm_queryset.go index 83168de7..734fc738 100644 --- a/pkg/orm/orm_queryset.go +++ b/pkg/orm/orm_queryset.go @@ -17,6 +17,7 @@ package orm import ( "context" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" ) type colValue struct { @@ -71,7 +72,9 @@ type querySet struct { groups []string orders []string distinct bool - forupdate bool + forUpdate bool + useIndex int + indexes []string orm *ormBase ctx context.Context forContext bool @@ -148,7 +151,28 @@ func (o querySet) Distinct() QuerySeter { // add FOR UPDATE to SELECT func (o querySet) ForUpdate() QuerySeter { - o.forupdate = true + o.forUpdate = true + return &o +} + +// ForceIndex force index for query +func (o querySet) ForceIndex(indexes ...string) QuerySeter { + o.useIndex = hints.KeyForceIndex + o.indexes = indexes + return &o +} + +// UseIndex use index for query +func (o querySet) UseIndex(indexes ...string) QuerySeter { + o.useIndex = hints.KeyUseIndex + o.indexes = indexes + return &o +} + +// IgnoreIndex ignore index for query +func (o querySet) IgnoreIndex(indexes ...string) QuerySeter { + o.useIndex = hints.KeyIgnoreIndex + o.indexes = indexes return &o } diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index f5242a46..1d173426 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -21,6 +21,7 @@ import ( "context" "database/sql" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" "io/ioutil" "math" "os" @@ -1279,24 +1280,32 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) - num, err = dORM.LoadRelated(&user, "Posts", true) + num, err = dORM.LoadRelated(&user, "Posts", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) - num, err = dORM.LoadRelated(&user, "Posts", true, 1) + num, err = dORM.LoadRelated(&user, "Posts", + hints.DefaultRelDepth(), + hints.Limit(1)) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(len(user.Posts), 1)) - num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") + num, err = dORM.LoadRelated(&user, "Posts", + hints.DefaultRelDepth(), + hints.OrderBy("-Id")) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) - num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") + num, err = dORM.LoadRelated(&user, "Posts", + hints.DefaultRelDepth(), + hints.Limit(1), + hints.Offset(1), + hints.OrderBy("Id")) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(len(user.Posts), 1)) @@ -1318,7 +1327,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(profile.User == nil, false)) throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) - num, err = dORM.LoadRelated(&profile, "User", true) + num, err = dORM.LoadRelated(&profile, "User", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(profile.User == nil, false)) @@ -1335,7 +1344,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(user.Profile == nil, false)) throwFailNow(t, AssertIs(user.Profile.Age, 30)) - num, err = dORM.LoadRelated(&user, "Profile", true) + num, err = dORM.LoadRelated(&user, "Profile", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(user.Profile == nil, false)) @@ -1355,7 +1364,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(post.User == nil, false)) throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) - num, err = dORM.LoadRelated(&post, "User", true) + num, err = dORM.LoadRelated(&post, "User", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(post.User == nil, false)) @@ -1375,7 +1384,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(len(post.Tags), 2)) throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) - num, err = dORM.LoadRelated(&post, "Tags", true) + num, err = dORM.LoadRelated(&post, "Tags", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(len(post.Tags), 2)) @@ -1396,7 +1405,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) - num, err = dORM.LoadRelated(&tag, "Posts", true) + num, err = dORM.LoadRelated(&tag, "Posts", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) diff --git a/pkg/orm/types.go b/pkg/orm/types.go index 9624fd94..0be2b809 100644 --- a/pkg/orm/types.go +++ b/pkg/orm/types.go @@ -17,6 +17,7 @@ package orm import ( "context" "database/sql" + "github.com/astaxie/beego/pkg/common" "reflect" "time" ) @@ -175,14 +176,14 @@ type DQL interface { // example: // Ormer.LoadRelated(post,"Tags") // for _,tag := range post.Tags{...} - // args[0] bool true useDefaultRelsDepth ; false depth 0 - // args[0] int loadRelationDepth - // args[1] int limit default limit 1000 - // args[2] int offset default offset 0 - // args[3] string order for example : "-Id" + // hints.DefaultRelDepth useDefaultRelsDepth ; or depth 0 + // hints.RelDepth loadRelationDepth + // hints.Limit limit default limit 1000 + // hints.Offset int offset default offset 0 + // hints.OrderBy string order for example : "-Id" // make sure the relation is defined in model struct tags. - LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) - LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) + LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) + LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) // create a models to models queryer // for example: @@ -282,6 +283,21 @@ type QuerySeter interface { // for example: // qs.OrderBy("-status") OrderBy(exprs ...string) QuerySeter + // add FORCE INDEX expression. + // for example: + // qs.ForceIndex(`idx_name1`,`idx_name2`) + // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive + ForceIndex(indexes ...string) QuerySeter + // add USE INDEX expression. + // for example: + // qs.UseIndex(`idx_name1`,`idx_name2`) + // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive + UseIndex(indexes ...string) QuerySeter + // add IGNORE INDEX expression. + // for example: + // qs.IgnoreIndex(`idx_name1`,`idx_name2`) + // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive + IgnoreIndex(indexes ...string) QuerySeter // set relation model to query together. // it will query relation models and assign to parent model. // for example: @@ -527,24 +543,27 @@ type txEnder interface { // base database struct type dbBaser interface { Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error + ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) + Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) + Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) - SupportUpdateJoin() bool UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) + + Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + + SupportUpdateJoin() bool OperatorSQL(string) string GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) GenerateOperatorLeftCol(*fieldInfo, string, *string) PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) - ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) - RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) MaxLimit() uint64 TableQuote() string ReplaceMarks(*string) @@ -559,4 +578,6 @@ type dbBaser interface { IndexExists(dbQuerier, string, string) bool collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) setval(dbQuerier, *modelInfo, []string) error + + GenerateSpecifyIndex(tableName string,useIndex int ,indexes []string) string } From 882f1273c8e9afc26984802cc42a027358087a0d Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Mon, 10 Aug 2020 23:27:03 +0800 Subject: [PATCH 085/207] add UT for specifying indexes --- pkg/orm/models_test.go | 9 +++++++++ pkg/orm/orm_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index 935c2073..52524501 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -382,6 +382,15 @@ type InLine struct { Email string } +type Index struct { + // Common Fields + Id int `orm:"column(id)"` + + // Other Fields + F1 int `orm:"column(f1);unique"` + F2 int `orm:"column(f2);unique"` +} + func NewInLine() *InLine { return new(InLine) } diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index 1d173426..58447adb 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -201,6 +201,7 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(IntegerPk)) RegisterModel(new(UintPk)) RegisterModel(new(PtrPk)) + RegisterModel(new(Index)) err := RunSyncdb("default", true, Debug) throwFail(t, err) @@ -225,6 +226,7 @@ func TestRegisterModels(t *testing.T) { RegisterModel(new(IntegerPk)) RegisterModel(new(UintPk)) RegisterModel(new(PtrPk)) + RegisterModel(new(Index)) BootStrap() @@ -794,6 +796,32 @@ func TestExpr(t *testing.T) { // throwFail(t, AssertIs(num, 3)) } +func TestSpecifyIndex(t *testing.T) { + var index *Index + index = &Index{ + F1: 1, + F2: 2, + } + _, _ = dORM.Insert(index) + throwFailNow(t, AssertIs(index.Id, 1)) + + index = &Index{ + F1: 3, + F2: 4, + } + _, _ = dORM.Insert(index) + throwFailNow(t, AssertIs(index.Id, 2)) + + _ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).ForceIndex(`f1`).One(index) + throwFailNow(t, AssertIs(index.F2, 2)) + + _ = dORM.QueryTable(&Index{}).Filter(`f1`, `3`).UseIndex(`f1`, `f2`).One(index) + throwFailNow(t, AssertIs(index.F2, 4)) + + _ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).IgnoreIndex(`f1`, `f2`).One(index) + throwFailNow(t, AssertIs(index.F2, 2)) +} + func TestOperators(t *testing.T) { qs := dORM.QueryTable("user") num, err := qs.Filter("user_name", "slene").Count() From f8c0e6fec56100a0290d0fd51427e97ef99781fc Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Tue, 11 Aug 2020 00:06:36 +0800 Subject: [PATCH 086/207] fix UT --- pkg/orm/models_test.go | 4 ++-- pkg/orm/orm_test.go | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index 52524501..85815edd 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -387,8 +387,8 @@ type Index struct { Id int `orm:"column(id)"` // Other Fields - F1 int `orm:"column(f1);unique"` - F2 int `orm:"column(f2);unique"` + F1 int `orm:"column(f1);index"` + F2 int `orm:"column(f2);index"` } func NewInLine() *InLine { diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index 58447adb..e08b1b12 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -812,13 +812,13 @@ func TestSpecifyIndex(t *testing.T) { _, _ = dORM.Insert(index) throwFailNow(t, AssertIs(index.Id, 2)) - _ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).ForceIndex(`f1`).One(index) + _ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).ForceIndex(`index_f1`).One(index) throwFailNow(t, AssertIs(index.F2, 2)) - _ = dORM.QueryTable(&Index{}).Filter(`f1`, `3`).UseIndex(`f1`, `f2`).One(index) - throwFailNow(t, AssertIs(index.F2, 4)) + _ = dORM.QueryTable(&Index{}).Filter(`f2`, `4`).UseIndex(`index_f2`).One(index) + throwFailNow(t, AssertIs(index.F1, 3)) - _ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).IgnoreIndex(`f1`, `f2`).One(index) + _ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).IgnoreIndex(`index_f1`, `index_f2`).One(index) throwFailNow(t, AssertIs(index.F2, 2)) } From c22af4c61199ed1e6c664fbf3ff9919201d2c6fa Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 10 Aug 2020 23:04:57 +0800 Subject: [PATCH 087/207] Fix Tracing and prometheus bug --- build_info.go | 8 +- config.go | 4 +- config/config.go | 6 +- config/fake.go | 16 +++ context/output.go | 1 - go.sum | 4 + pkg/admin_test.go | 18 +++- pkg/app.go | 1 + pkg/filter.go | 3 + pkg/filter_chain_test.go | 3 +- pkg/hooks.go | 2 +- pkg/httplib/filter.go | 2 +- pkg/httplib/filter/opentracing/filter.go | 40 ++++---- pkg/httplib/filter/opentracing/filter_test.go | 2 +- pkg/httplib/filter/prometheus/filter.go | 8 +- pkg/httplib/filter/prometheus/filter_test.go | 2 +- pkg/orm/db_alias.go | 3 +- pkg/orm/db_alias_test.go | 1 - pkg/orm/db_hints_test.go | 2 +- pkg/orm/do_nothing_orm.go | 11 ++- ...ing_omr_test.go => do_nothing_orm_test.go} | 2 +- pkg/orm/filter.go | 10 +- pkg/orm/filter/opentracing/filter.go | 38 ++++--- pkg/orm/filter/opentracing/filter_test.go | 4 +- pkg/orm/filter/prometheus/filter.go | 2 +- pkg/orm/filter/prometheus/filter_test.go | 2 +- pkg/orm/filter_orm_decorator.go | 98 ++++++++++--------- pkg/orm/filter_orm_decorator_test.go | 50 +++++----- pkg/orm/filter_test.go | 3 +- pkg/orm/invocation.go | 19 ++-- pkg/orm/model_utils_test.go | 2 +- pkg/orm/models_test.go | 1 - pkg/orm/orm.go | 38 +++---- pkg/orm/orm_test.go | 2 - pkg/orm/types.go | 8 +- pkg/router.go | 5 +- pkg/session/sess_file_test.go | 2 +- pkg/web/doc.go | 2 +- pkg/web/filter/opentracing/filter.go | 23 +++-- pkg/web/filter/opentracing/filter_test.go | 2 +- pkg/web/filter/prometheus/filter.go | 2 +- pkg/web/filter/prometheus/filter_test.go | 2 +- 42 files changed, 257 insertions(+), 197 deletions(-) rename pkg/orm/{do_nothing_omr_test.go => do_nothing_orm_test.go} (99%) diff --git a/build_info.go b/build_info.go index 896bbdf3..59e78127 100644 --- a/build_info.go +++ b/build_info.go @@ -16,15 +16,15 @@ package beego var ( // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildVersion string + BuildVersion string // Deprecated: using pkg/, we will delete this in v2.1.0 BuildGitRevision string // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildStatus string + BuildStatus string // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildTag string + BuildTag string // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildTime string + BuildTime string // Deprecated: using pkg/, we will delete this in v2.1.0 GoVersion string diff --git a/config.go b/config.go index d707542a..7917528e 100644 --- a/config.go +++ b/config.go @@ -15,13 +15,13 @@ package beego import ( + "crypto/tls" "fmt" "os" "path/filepath" "reflect" "runtime" "strings" - "crypto/tls" "github.com/astaxie/beego/config" "github.com/astaxie/beego/context" @@ -163,7 +163,7 @@ func init() { } appConfigPath = filepath.Join(WorkPath, "conf", filename) if configPath := os.Getenv("BEEGO_CONFIG_PATH"); configPath != "" { - appConfigPath = configPath + appConfigPath = configPath } if !utils.FileExists(appConfigPath) { appConfigPath = filepath.Join(AppPath, "conf", filename) diff --git a/config/config.go b/config/config.go index f46f862b..db2e96f6 100644 --- a/config/config.go +++ b/config/config.go @@ -51,9 +51,9 @@ import ( // Deprecated: using pkg/config, we will delete this in v2.1.0 type Configer interface { // Deprecated: using pkg/config, we will delete this in v2.1.0 - Set(key, val string) error //support section::key type in given key when using ini type. + Set(key, val string) error //support section::key type in given key when using ini type. // Deprecated: using pkg/config, we will delete this in v2.1.0 - String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. // Deprecated: using pkg/config, we will delete this in v2.1.0 Strings(key string) []string //get string slice // Deprecated: using pkg/config, we will delete this in v2.1.0 @@ -65,7 +65,7 @@ type Configer interface { // Deprecated: using pkg/config, we will delete this in v2.1.0 Float(key string) (float64, error) // Deprecated: using pkg/config, we will delete this in v2.1.0 - DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. // Deprecated: using pkg/config, we will delete this in v2.1.0 DefaultStrings(key string, defaultVal []string) []string //get string slice // Deprecated: using pkg/config, we will delete this in v2.1.0 diff --git a/config/fake.go b/config/fake.go index 07e56ce2..8093ad61 100644 --- a/config/fake.go +++ b/config/fake.go @@ -27,15 +27,18 @@ type fakeConfigContainer struct { func (c *fakeConfigContainer) getData(key string) string { return c.data[strings.ToLower(key)] } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Set(key, val string) error { c.data[strings.ToLower(key)] = val return nil } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) String(key string) string { return c.getData(key) } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { v := c.String(key) @@ -44,6 +47,7 @@ func (c *fakeConfigContainer) DefaultString(key string, defaultval string) strin } return v } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Strings(key string) []string { v := c.String(key) @@ -52,6 +56,7 @@ func (c *fakeConfigContainer) Strings(key string) []string { } return strings.Split(v, ";") } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { v := c.Strings(key) @@ -60,10 +65,12 @@ func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) [] } return v } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Int(key string) (int, error) { return strconv.Atoi(c.getData(key)) } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { v, err := c.Int(key) @@ -72,10 +79,12 @@ func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { } return v } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Int64(key string) (int64, error) { return strconv.ParseInt(c.getData(key), 10, 64) } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { v, err := c.Int64(key) @@ -84,10 +93,12 @@ func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { } return v } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Bool(key string) (bool, error) { return ParseBool(c.getData(key)) } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { v, err := c.Bool(key) @@ -96,10 +107,12 @@ func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { } return v } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Float(key string) (float64, error) { return strconv.ParseFloat(c.getData(key), 64) } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { v, err := c.Float(key) @@ -108,6 +121,7 @@ func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float } return v } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { if v, ok := c.data[strings.ToLower(key)]; ok { @@ -115,10 +129,12 @@ func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { } return nil, errors.New("key not find") } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) GetSection(section string) (map[string]string, error) { return nil, errors.New("not implement in the fakeConfigContainer") } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) SaveConfigFile(filename string) error { return errors.New("not implement in the fakeConfigContainer") diff --git a/context/output.go b/context/output.go index eaa75720..7409e4e5 100644 --- a/context/output.go +++ b/context/output.go @@ -58,7 +58,6 @@ func (output *BeegoOutput) Clear() { output.Status = 0 } - // Header sets response header item string via given key. func (output *BeegoOutput) Header(key, val string) { output.Context.ResponseWriter.Header().Set(key, val) diff --git a/go.sum b/go.sum index 12b76333..75247943 100644 --- a/go.sum +++ b/go.sum @@ -185,6 +185,7 @@ golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -219,6 +220,9 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20200117065230-39095c1d176c h1:FodBYPZKH5tAN2O60HlglMwXGAeV/4k+NKbli79M/2c= +golang.org/x/tools v0.0.0-20200117065230-39095c1d176c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= diff --git a/pkg/admin_test.go b/pkg/admin_test.go index e7eae771..5094aeed 100644 --- a/pkg/admin_test.go +++ b/pkg/admin_test.go @@ -6,10 +6,11 @@ import ( "fmt" "net/http" "net/http/httptest" - "reflect" "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/astaxie/beego/pkg/toolbox" ) @@ -230,10 +231,19 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { t.Errorf("invalid response map length: got %d want %d", len(decodedResponseBody), len(expectedResponseBody)) } + assert.Equal(t, len(expectedResponseBody), len(decodedResponseBody)) + assert.Equal(t, 2, len(decodedResponseBody)) - if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { - t.Errorf("handler returned unexpected body: got %v want %v", - decodedResponseBody, expectedResponseBody) + var database, cache map[string]interface{} + if decodedResponseBody[0]["message"] == "database" { + database = decodedResponseBody[0] + cache = decodedResponseBody[1] + } else { + database = decodedResponseBody[1] + cache = decodedResponseBody[0] } + assert.Equal(t, expectedResponseBody[0], database) + assert.Equal(t, expectedResponseBody[1], cache) + } diff --git a/pkg/app.go b/pkg/app.go index d94d56b5..ea71ce4e 100644 --- a/pkg/app.go +++ b/pkg/app.go @@ -498,6 +498,7 @@ func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *A // InsertFilterChain adds a FilterFunc built by filterChain. // This filter will be executed before all filters. +// the filter's behavior is like stack func InsertFilterChain(pattern string, filterChain FilterChain, params ...bool) *App { BeeApp.Handlers.InsertFilterChain(pattern, filterChain, params...) return BeeApp diff --git a/pkg/filter.go b/pkg/filter.go index 543d7901..911cb848 100644 --- a/pkg/filter.go +++ b/pkg/filter.go @@ -33,6 +33,7 @@ type FilterFunc func(ctx *context.Context) // when a request with a matching URL arrives. type FilterRouter struct { filterFunc FilterFunc + next *FilterRouter tree *Tree pattern string returnOnOutput bool @@ -81,6 +82,8 @@ func (f *FilterRouter) filter(ctx *context.Context, urlPath string, preFilterPar ctx.Input.SetParam(k, v) } } + } else if f.next != nil { + return f.next.filter(ctx, urlPath, preFilterParams) } if f.returnOnOutput && ctx.ResponseWriter.Started { return true, true diff --git a/pkg/filter_chain_test.go b/pkg/filter_chain_test.go index 42397a60..f1f86088 100644 --- a/pkg/filter_chain_test.go +++ b/pkg/filter_chain_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -39,7 +39,6 @@ func TestControllerRegister_InsertFilterChain(t *testing.T) { ctx.Output.Body([]byte("hello")) }) - r, _ := http.NewRequest("GET", "/chain/user", nil) w := httptest.NewRecorder() diff --git a/pkg/hooks.go b/pkg/hooks.go index f511e216..3f778cdc 100644 --- a/pkg/hooks.go +++ b/pkg/hooks.go @@ -111,4 +111,4 @@ func registerCommentRouter() error { } return nil -} \ No newline at end of file +} diff --git a/pkg/httplib/filter.go b/pkg/httplib/filter.go index 72a497d0..5daed64c 100644 --- a/pkg/httplib/filter.go +++ b/pkg/httplib/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/httplib/filter/opentracing/filter.go b/pkg/httplib/filter/opentracing/filter.go index 5f409c63..6cc4d6b0 100644 --- a/pkg/httplib/filter/opentracing/filter.go +++ b/pkg/httplib/filter/opentracing/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,14 +17,11 @@ package opentracing import ( "context" "net/http" - "strconv" + "github.com/astaxie/beego/pkg/httplib" logKit "github.com/go-kit/kit/log" opentracingKit "github.com/go-kit/kit/tracing/opentracing" "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/log" - - "github.com/astaxie/beego/pkg/httplib" ) type FilterChainBuilder struct { @@ -38,14 +35,8 @@ func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filt return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { method := req.GetRequest().Method - host := req.GetRequest().URL.Host - path := req.GetRequest().URL.Path - proto := req.GetRequest().Proto - - scheme := req.GetRequest().URL.Scheme - - operationName := host + path + "#" + method + operationName := method + "#" + req.GetRequest().URL.String() span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName) defer span.Finish() @@ -54,21 +45,24 @@ func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filt resp, err := next(spanCtx, req) if resp != nil { - span.SetTag("status", strconv.Itoa(resp.StatusCode)) + span.SetTag("http.status_code", resp.StatusCode) } - - span.SetTag("method", method) - span.SetTag("host", host) - span.SetTag("path", path) - span.SetTag("proto", proto) - span.SetTag("scheme", scheme) - - span.LogFields(log.String("url", req.GetRequest().URL.String())) - + span.SetTag("http.method", method) + span.SetTag("peer.hostname", req.GetRequest().URL.Host) + span.SetTag("http.url", req.GetRequest().URL.String()) + span.SetTag("http.scheme", req.GetRequest().URL.Scheme) + span.SetTag("span.kind", "client") + span.SetTag("component", "beego") if err != nil { - span.LogFields(log.String("error", err.Error())) + span.SetTag("error", true) + span.SetTag("message", err.Error()) + } else if resp != nil && !(resp.StatusCode < 300 && resp.StatusCode >= 200) { + span.SetTag("error", true) } + span.SetTag("peer.address", req.GetRequest().RemoteAddr) + span.SetTag("http.proto", req.GetRequest().Proto) + if builder.CustomSpanFunc != nil { builder.CustomSpanFunc(span, ctx, req, resp, err) } diff --git a/pkg/httplib/filter/opentracing/filter_test.go b/pkg/httplib/filter/opentracing/filter_test.go index aa687541..8849a9ad 100644 --- a/pkg/httplib/filter/opentracing/filter_test.go +++ b/pkg/httplib/filter/opentracing/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/httplib/filter/prometheus/filter.go b/pkg/httplib/filter/prometheus/filter.go index a0b24d67..e7f7316f 100644 --- a/pkg/httplib/filter/prometheus/filter.go +++ b/pkg/httplib/filter/prometheus/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -63,11 +63,13 @@ func (builder *FilterChainBuilder) report(startTime time.Time, endTime time.Time host := req.GetRequest().URL.Host path := req.GetRequest().URL.Path - status := resp.StatusCode + status := -1 + if resp != nil { + status = resp.StatusCode + } dur := int(endTime.Sub(startTime) / time.Millisecond) - builder.summaryVec.WithLabelValues(proto, scheme, method, host, path, strconv.Itoa(status), strconv.Itoa(dur), strconv.FormatBool(err == nil)) } diff --git a/pkg/httplib/filter/prometheus/filter_test.go b/pkg/httplib/filter/prometheus/filter_test.go index e15d82e5..2964e6c5 100644 --- a/pkg/httplib/filter/prometheus/filter_test.go +++ b/pkg/httplib/filter/prometheus/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index 5f1e3ea3..e9b39a3d 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -357,7 +357,7 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV return al, nil } -func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.KV)(*alias, error){ +func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.KV) (*alias, error) { kvs := common.NewKVs(params...) var stmtCache *lru.Cache @@ -429,7 +429,6 @@ func RegisterDataBase(aliasName, driverName, dataSource string, hints ...common. al *alias ) - db, err = sql.Open(driverName, dataSource) if err != nil { err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index 111657d7..6275cb2a 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -75,7 +75,6 @@ func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { assert.Equal(t, al.DB.stmtDecoratorsLimit, 841) } - func TestDBCache(t *testing.T) { dataBaseCache.add("test1", &alias{}) dataBaseCache.add("default", &alias{}) diff --git a/pkg/orm/db_hints_test.go b/pkg/orm/db_hints_test.go index 13f8ccde..bb713171 100644 --- a/pkg/orm/db_hints_test.go +++ b/pkg/orm/db_hints_test.go @@ -73,4 +73,4 @@ func TestMaxStmtCacheSize(t *testing.T) { hint := MaxStmtCacheSize(i) assert.Equal(t, hint.GetValue(), i) assert.Equal(t, hint.GetKey(), maxStmtCacheSizeKey) -} \ No newline at end of file +} diff --git a/pkg/orm/do_nothing_orm.go b/pkg/orm/do_nothing_orm.go index 87b0a2ae..f460794c 100644 --- a/pkg/orm/do_nothing_orm.go +++ b/pkg/orm/do_nothing_orm.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import ( var _ Ormer = new(DoNothingOrm) type DoNothingOrm struct { + } func (d *DoNothingOrm) Read(md interface{}, cols ...string) error { @@ -148,19 +149,19 @@ func (d *DoNothingOrm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOpti return nil, nil } -func (d *DoNothingOrm) DoTx(task func(txOrm TxOrmer) error) error { +func (d *DoNothingOrm) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error { return nil } -func (d *DoNothingOrm) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error { +func (d *DoNothingOrm) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error { return nil } -func (d *DoNothingOrm) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { +func (d *DoNothingOrm) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { return nil } -func (d *DoNothingOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { +func (d *DoNothingOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { return nil } diff --git a/pkg/orm/do_nothing_omr_test.go b/pkg/orm/do_nothing_orm_test.go similarity index 99% rename from pkg/orm/do_nothing_omr_test.go rename to pkg/orm/do_nothing_orm_test.go index 92cde38b..4d477353 100644 --- a/pkg/orm/do_nothing_omr_test.go +++ b/pkg/orm/do_nothing_orm_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/orm/filter.go b/pkg/orm/filter.go index d04b8c42..03a30022 100644 --- a/pkg/orm/filter.go +++ b/pkg/orm/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ import ( // don't forget to call next(...) inside your Filter type FilterChain func(next Filter) Filter -// Filter's behavior is a little big strang. +// Filter's behavior is a little big strange. // it's only be called when users call methods of Ormer type Filter func(ctx context.Context, inv *Invocation) @@ -31,6 +31,6 @@ var globalFilterChains = make([]FilterChain, 0, 4) // AddGlobalFilterChain adds a new FilterChain // All orm instances built after this invocation will use this filterChain, // but instances built before this invocation will not be affected -func AddGlobalFilterChain(filterChain FilterChain) { - globalFilterChains = append(globalFilterChains, filterChain) -} \ No newline at end of file +func AddGlobalFilterChain(filterChain ...FilterChain) { + globalFilterChains = append(globalFilterChains, filterChain...) +} diff --git a/pkg/orm/filter/opentracing/filter.go b/pkg/orm/filter/opentracing/filter.go index a55ae6d2..405e39ea 100644 --- a/pkg/orm/filter/opentracing/filter.go +++ b/pkg/orm/filter/opentracing/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package opentracing import ( "context" + "strings" "github.com/opentracing/opentracing-go" @@ -27,6 +28,8 @@ import ( // for example: // if we want to trace QuerySetter // actually we trace invoking "QueryTable" and "QueryTableWithCtx" +// the method Begin*, Commit and Rollback are ignored. +// When use using those methods, it means that they want to manager their transaction manually, so we won't handle them. type FilterChainBuilder struct { // CustomSpanFunc users are able to custom their span CustomSpanFunc func(span opentracing.Span, ctx context.Context, inv *orm.Invocation) @@ -35,25 +38,34 @@ type FilterChainBuilder struct { func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { return func(ctx context.Context, inv *orm.Invocation) { operationName := builder.operationName(ctx, inv) - span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName) - defer span.Finish() - - next(spanCtx, inv) - span.SetTag("Method", inv.Method) - span.SetTag("Table", inv.GetTableName()) - span.SetTag("InsideTx", inv.InsideTx) - span.SetTag("TxName", spanCtx.Value(orm.TxNameKey)) - - if builder.CustomSpanFunc != nil { - builder.CustomSpanFunc(span, spanCtx, inv) + if strings.HasPrefix(inv.Method, "Begin") || inv.Method == "Commit" || inv.Method == "Rollback" { + next(ctx, inv) + return } + span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName) + defer span.Finish() + next(spanCtx, inv) + builder.buildSpan(span, spanCtx, inv) + } +} + +func (builder *FilterChainBuilder) buildSpan(span opentracing.Span, ctx context.Context, inv *orm.Invocation) { + span.SetTag("orm.method", inv.Method) + span.SetTag("orm.table", inv.GetTableName()) + span.SetTag("orm.insideTx", inv.InsideTx) + span.SetTag("orm.txName", ctx.Value(orm.TxNameKey)) + span.SetTag("span.kind", "client") + span.SetTag("component", "beego") + + if builder.CustomSpanFunc != nil { + builder.CustomSpanFunc(span, ctx, inv) } } func (builder *FilterChainBuilder) operationName(ctx context.Context, inv *orm.Invocation) string { if n, ok := ctx.Value(orm.TxNameKey).(string); ok { - return inv.Method + "#" + n + return inv.Method + "#tx(" + n + ")" } return inv.Method + "#" + inv.GetTableName() } diff --git a/pkg/orm/filter/opentracing/filter_test.go b/pkg/orm/filter/opentracing/filter_test.go index 1428df8a..7df12a92 100644 --- a/pkg/orm/filter/opentracing/filter_test.go +++ b/pkg/orm/filter/opentracing/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -40,4 +40,4 @@ func TestFilterChainBuilder_FilterChain(t *testing.T) { TxStartTime: time.Now(), } builder.FilterChain(next)(context.Background(), inv) -} \ No newline at end of file +} diff --git a/pkg/orm/filter/prometheus/filter.go b/pkg/orm/filter/prometheus/filter.go index 33fdf78f..2e67d85c 100644 --- a/pkg/orm/filter/prometheus/filter.go +++ b/pkg/orm/filter/prometheus/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/orm/filter/prometheus/filter_test.go b/pkg/orm/filter/prometheus/filter_test.go index a71e8f50..34766fb4 100644 --- a/pkg/orm/filter/prometheus/filter_test.go +++ b/pkg/orm/filter/prometheus/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/orm/filter_orm_decorator.go b/pkg/orm/filter_orm_decorator.go index eb26ea68..279d299f 100644 --- a/pkg/orm/filter_orm_decorator.go +++ b/pkg/orm/filter_orm_decorator.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -21,7 +21,12 @@ import ( "time" ) -const TxNameKey = "TxName" +const ( + TxNameKey = "TxName" +) + +var _ Ormer = new(filterOrmDecorator) +var _ TxOrmer = new(filterOrmDecorator) type filterOrmDecorator struct { ormer @@ -40,7 +45,7 @@ func NewFilterOrmDecorator(delegate Ormer, filterChains ...FilterChain) Ormer { ormer: delegate, TxBeginner: delegate, root: func(ctx context.Context, inv *Invocation) { - inv.execute() + inv.execute(ctx) }, } @@ -58,7 +63,7 @@ func NewFilterTxOrmDecorator(delegate TxOrmer, root Filter, txName string) TxOrm root: root, insideTx: true, txStartTime: time.Now(), - txName: txName, + txName: txName, } return res } @@ -76,8 +81,8 @@ func (f *filterOrmDecorator) ReadWithCtx(ctx context.Context, md interface{}, co mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - err = f.ormer.ReadWithCtx(ctx, md, cols...) + f: func(c context.Context) { + err = f.ormer.ReadWithCtx(c, md, cols...) }, } f.root(ctx, inv) @@ -98,8 +103,8 @@ func (f *filterOrmDecorator) ReadForUpdateWithCtx(ctx context.Context, md interf mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - err = f.ormer.ReadForUpdateWithCtx(ctx, md, cols...) + f: func(c context.Context) { + err = f.ormer.ReadForUpdateWithCtx(c, md, cols...) }, } f.root(ctx, inv) @@ -125,8 +130,8 @@ func (f *filterOrmDecorator) ReadOrCreateWithCtx(ctx context.Context, md interfa mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - ok, res, err = f.ormer.ReadOrCreateWithCtx(ctx, md, col1, cols...) + f: func(c context.Context) { + ok, res, err = f.ormer.ReadOrCreateWithCtx(c, md, col1, cols...) }, } f.root(ctx, inv) @@ -151,8 +156,8 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res, err = f.ormer.LoadRelatedWithCtx(ctx, md, name, args...) + f: func(c context.Context) { + res, err = f.ormer.LoadRelatedWithCtx(c, md, name, args...) }, } f.root(ctx, inv) @@ -176,8 +181,8 @@ func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{} mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res = f.ormer.QueryM2MWithCtx(ctx, md, name) + f: func(c context.Context) { + res = f.ormer.QueryM2MWithCtx(c, md, name) }, } f.root(ctx, inv) @@ -190,10 +195,10 @@ func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QueryS func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { var ( - res QuerySeter + res QuerySeter name string - md interface{} - mi *modelInfo + md interface{} + mi *modelInfo ) if table, ok := ptrStructOrTableName.(string); ok { @@ -212,10 +217,10 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT Args: []interface{}{ptrStructOrTableName}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - Md: md, - mi: mi, - f: func() { - res = f.ormer.QueryTableWithCtx(ctx, ptrStructOrTableName) + Md: md, + mi: mi, + f: func(c context.Context) { + res = f.ormer.QueryTableWithCtx(c, ptrStructOrTableName) }, } f.root(ctx, inv) @@ -230,7 +235,7 @@ func (f *filterOrmDecorator) DBStats() *sql.DBStats { Method: "DBStats", InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { + f: func(c context.Context) { res = f.ormer.DBStats() }, } @@ -255,8 +260,8 @@ func (f *filterOrmDecorator) InsertWithCtx(ctx context.Context, md interface{}) mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res, err = f.ormer.InsertWithCtx(ctx, md) + f: func(c context.Context) { + res, err = f.ormer.InsertWithCtx(c, md) }, } f.root(ctx, inv) @@ -280,8 +285,8 @@ func (f *filterOrmDecorator) InsertOrUpdateWithCtx(ctx context.Context, md inter mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res, err = f.ormer.InsertOrUpdateWithCtx(ctx, md, colConflitAndArgs...) + f: func(c context.Context) { + res, err = f.ormer.InsertOrUpdateWithCtx(c, md, colConflitAndArgs...) }, } f.root(ctx, inv) @@ -316,8 +321,8 @@ func (f *filterOrmDecorator) InsertMultiWithCtx(ctx context.Context, bulk int, m mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res, err = f.ormer.InsertMultiWithCtx(ctx, bulk, mds) + f: func(c context.Context) { + res, err = f.ormer.InsertMultiWithCtx(c, bulk, mds) }, } f.root(ctx, inv) @@ -341,8 +346,8 @@ func (f *filterOrmDecorator) UpdateWithCtx(ctx context.Context, md interface{}, mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res, err = f.ormer.UpdateWithCtx(ctx, md, cols...) + f: func(c context.Context) { + res, err = f.ormer.UpdateWithCtx(c, md, cols...) }, } f.root(ctx, inv) @@ -366,8 +371,8 @@ func (f *filterOrmDecorator) DeleteWithCtx(ctx context.Context, md interface{}, mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res, err = f.ormer.DeleteWithCtx(ctx, md, cols...) + f: func(c context.Context) { + res, err = f.ormer.DeleteWithCtx(c, md, cols...) }, } f.root(ctx, inv) @@ -387,8 +392,8 @@ func (f *filterOrmDecorator) RawWithCtx(ctx context.Context, query string, args Args: []interface{}{query, args}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res = f.ormer.RawWithCtx(ctx, query, args...) + f: func(c context.Context) { + res = f.ormer.RawWithCtx(c, query, args...) }, } f.root(ctx, inv) @@ -403,7 +408,7 @@ func (f *filterOrmDecorator) Driver() Driver { Method: "Driver", InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { + f: func(c context.Context) { res = f.ormer.Driver() }, } @@ -433,28 +438,28 @@ func (f *filterOrmDecorator) BeginWithCtxAndOpts(ctx context.Context, opts *sql. Args: []interface{}{opts}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res, err = f.TxBeginner.BeginWithCtxAndOpts(ctx, opts) - res = NewFilterTxOrmDecorator(res, f.root, getTxNameFromCtx(ctx)) + f: func(c context.Context) { + res, err = f.TxBeginner.BeginWithCtxAndOpts(c, opts) + res = NewFilterTxOrmDecorator(res, f.root, getTxNameFromCtx(c)) }, } f.root(ctx, inv) return res, err } -func (f *filterOrmDecorator) DoTx(task func(txOrm TxOrmer) error) error { +func (f *filterOrmDecorator) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error { return f.DoTxWithCtxAndOpts(context.Background(), nil, task) } -func (f *filterOrmDecorator) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error { +func (f *filterOrmDecorator) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error { return f.DoTxWithCtxAndOpts(ctx, nil, task) } -func (f *filterOrmDecorator) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { +func (f *filterOrmDecorator) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { return f.DoTxWithCtxAndOpts(context.Background(), opts, task) } -func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { +func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { var ( err error ) @@ -465,8 +470,8 @@ func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.T InsideTx: f.insideTx, TxStartTime: f.txStartTime, TxName: getTxNameFromCtx(ctx), - f: func() { - err = f.TxBeginner.DoTxWithCtxAndOpts(ctx, opts, task) + f: func(c context.Context) { + err = doTxTemplate(f, c, opts, task) }, } f.root(ctx, inv) @@ -483,7 +488,7 @@ func (f *filterOrmDecorator) Commit() error { InsideTx: f.insideTx, TxStartTime: f.txStartTime, TxName: f.txName, - f: func() { + f: func(c context.Context) { err = f.TxCommitter.Commit() }, } @@ -501,7 +506,7 @@ func (f *filterOrmDecorator) Rollback() error { InsideTx: f.insideTx, TxStartTime: f.txStartTime, TxName: f.txName, - f: func() { + f: func(c context.Context) { err = f.TxCommitter.Rollback() }, } @@ -516,4 +521,3 @@ func getTxNameFromCtx(ctx context.Context) string { } return txName } - diff --git a/pkg/orm/filter_orm_decorator_test.go b/pkg/orm/filter_orm_decorator_test.go index d1099eaf..4e837a4e 100644 --- a/pkg/orm/filter_orm_decorator_test.go +++ b/pkg/orm/filter_orm_decorator_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -130,49 +130,49 @@ func TestFilterOrmDecorator_DoTx(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) { - assert.Equal(t, "DoTxWithCtxAndOpts", inv.Method) - assert.Equal(t, 2, len(inv.Args)) - assert.Equal(t, "", inv.GetTableName()) - assert.False(t, inv.InsideTx) + if inv.Method == "DoTxWithCtxAndOpts" { + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + assert.False(t, inv.InsideTx) + } + next(ctx, inv) } }) - err := od.DoTx(func(txOrm TxOrmer) error { - return errors.New("tx error") + err := od.DoTx(func(c context.Context, txOrm TxOrmer) error { + return nil }) assert.NotNil(t, err) - assert.Equal(t, "tx error", err.Error()) - err = od.DoTxWithCtx(context.Background(), func(txOrm TxOrmer) error { - return errors.New("tx ctx error") + err = od.DoTxWithCtx(context.Background(), func(c context.Context, txOrm TxOrmer) error { + return nil }) assert.NotNil(t, err) - assert.Equal(t, "tx ctx error", err.Error()) - err = od.DoTxWithOpts(nil, func(txOrm TxOrmer) error { - return errors.New("tx opts error") + err = od.DoTxWithOpts(nil, func(c context.Context, txOrm TxOrmer) error { + return nil }) assert.NotNil(t, err) - assert.Equal(t, "tx opts error", err.Error()) + od = NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) { - assert.Equal(t, "DoTxWithCtxAndOpts", inv.Method) - assert.Equal(t, 2, len(inv.Args)) - assert.Equal(t, "", inv.GetTableName()) - assert.Equal(t, "do tx name", inv.TxName) - assert.False(t, inv.InsideTx) + if inv.Method == "DoTxWithCtxAndOpts" { + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + assert.Equal(t, "do tx name", inv.TxName) + assert.False(t, inv.InsideTx) + } next(ctx, inv) } }) ctx := context.WithValue(context.Background(), TxNameKey, "do tx name") - err = od.DoTxWithCtxAndOpts(ctx, nil, func(txOrm TxOrmer) error { - return errors.New("tx ctx opts error") + err = od.DoTxWithCtxAndOpts(ctx, nil, func(c context.Context, txOrm TxOrmer) error { + return nil }) assert.NotNil(t, err) - assert.Equal(t, "tx ctx opts error", err.Error()) } func TestFilterOrmDecorator_Driver(t *testing.T) { @@ -347,6 +347,8 @@ func TestFilterOrmDecorator_ReadOrCreate(t *testing.T) { assert.Equal(t, int64(13), i) } +var _ Ormer = new(filterMockOrm) + // filterMockOrm is only used in this test file type filterMockOrm struct { DoNothingOrm @@ -376,8 +378,8 @@ func (f *filterMockOrm) InsertWithCtx(ctx context.Context, md interface{}) (int6 return 100, errors.New("insert error") } -func (f *filterMockOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { - return task(nil) +func (f *filterMockOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(c context.Context, txOrm TxOrmer) error) error { + return task(ctx, nil) } func (f *filterMockOrm) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { diff --git a/pkg/orm/filter_test.go b/pkg/orm/filter_test.go index 0f2944c7..b2ca4ae1 100644 --- a/pkg/orm/filter_test.go +++ b/pkg/orm/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -28,4 +28,5 @@ func TestAddGlobalFilterChain(t *testing.T) { } }) assert.Equal(t, 1, len(globalFilterChains)) + globalFilterChains = nil } diff --git a/pkg/orm/invocation.go b/pkg/orm/invocation.go index 1c9fee09..e935b7ea 100644 --- a/pkg/orm/invocation.go +++ b/pkg/orm/invocation.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ package orm import ( + "context" "time" ) @@ -22,27 +23,27 @@ import ( type Invocation struct { Method string // Md may be nil in some cases. It depends on method - Md interface{} + Md interface{} // the args are all arguments except context.Context - Args []interface{} + Args []interface{} mi *modelInfo // f is the Orm operation - f func() + f func(ctx context.Context) // insideTx indicates whether this is inside a transaction - InsideTx bool + InsideTx bool TxStartTime time.Time - TxName string + TxName string } func (inv *Invocation) GetTableName() string { - if inv.mi != nil{ + if inv.mi != nil { return inv.mi.table } return "" } -func (inv *Invocation) execute() { - inv.f() +func (inv *Invocation) execute(ctx context.Context) { + inv.f(ctx) } diff --git a/pkg/orm/model_utils_test.go b/pkg/orm/model_utils_test.go index ea38d90a..b65aadcb 100644 --- a/pkg/orm/model_utils_test.go +++ b/pkg/orm/model_utils_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index ae166dc7..09ef4f15 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -27,7 +27,6 @@ import ( _ "github.com/mattn/go-sqlite3" // As tidb can't use go get, so disable the tidb testing now // _ "github.com/pingcap/tidb" - ) // A slice string field. diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index d79053af..cc678fc8 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -522,19 +522,24 @@ func (o *orm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxO return taskTxOrm, nil } -func (o *orm) DoTx(task func(txOrm TxOrmer) error) error { +func (o *orm) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error { return o.DoTxWithCtx(context.Background(), task) } -func (o *orm) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error { +func (o *orm) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error { return o.DoTxWithCtxAndOpts(ctx, nil, task) } -func (o *orm) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { +func (o *orm) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { return o.DoTxWithCtxAndOpts(context.Background(), opts, task) } -func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { +func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { + return doTxTemplate(o, ctx, opts, task) +} + +func doTxTemplate(o TxBeginner, ctx context.Context, opts *sql.TxOptions, + task func(ctx context.Context, txOrm TxOrmer) error) error { _txOrm, err := o.BeginWithCtxAndOpts(ctx, opts) if err != nil { return err @@ -553,9 +558,8 @@ func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task } } }() - var taskTxOrm = _txOrm - err = task(taskTxOrm) + err = task(ctx, taskTxOrm) panicked = false return err } @@ -582,18 +586,11 @@ func NewOrm() Ormer { // NewOrmUsingDB create new orm with the name func NewOrmUsingDB(aliasName string) Ormer { - o := new(orm) if al, ok := dataBaseCache.get(aliasName); ok { - o.alias = al - if Debug { - o.db = newDbQueryLog(al, al.DB) - } else { - o.db = al.DB - } + return newDBWithAlias(al) } else { panic(fmt.Errorf(" unknown db alias name `%s`", aliasName)) } - return o } // NewOrmWithDB create a new ormer object with specify *sql.DB for query @@ -603,14 +600,21 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...common.KV) return nil, err } + return newDBWithAlias(al), nil +} + +func newDBWithAlias(al *alias) Ormer { o := new(orm) o.alias = al if Debug { - o.db = newDbQueryLog(o.alias, db) + o.db = newDbQueryLog(al, al.DB) } else { - o.db = db + o.db = al.DB } - return o, nil + if len(globalFilterChains) > 0 { + return NewFilterOrmDecorator(o, globalFilterChains...) + } + return o } diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index f5242a46..e3dafecd 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -2486,5 +2486,3 @@ func TestInsertOrUpdate(t *testing.T) { throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) } } - - diff --git a/pkg/orm/types.go b/pkg/orm/types.go index 9624fd94..59688588 100644 --- a/pkg/orm/types.go +++ b/pkg/orm/types.go @@ -95,10 +95,10 @@ type TxBeginner interface { BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) //closure control transaction - DoTx(task func(txOrm TxOrmer) error) error - DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error - DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error - DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error + DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error + DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error + DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error + DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error } type TxCommitter interface { diff --git a/pkg/router.go b/pkg/router.go index 8caba94a..6b25d7e3 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -468,12 +468,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter // // do something // } // } -func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, params...bool) { +func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, params ...bool) { root := p.chainRoot filterFunc := chain(root.filterFunc) p.chainRoot = newFilterRouter(pattern, BConfig.RouterCaseSensitive, filterFunc, params...) -} + p.chainRoot.next = root +} // add Filter into func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { diff --git a/pkg/session/sess_file_test.go b/pkg/session/sess_file_test.go index 64b8d94a..a27d30a6 100644 --- a/pkg/session/sess_file_test.go +++ b/pkg/session/sess_file_test.go @@ -57,7 +57,7 @@ func TestFileProvider_SessionExist(t *testing.T) { _ = fp.SessionInit(180, sessionPath) exists, err := fp.SessionExist(sid) - if err != nil{ + if err != nil { t.Error(err) } if exists { diff --git a/pkg/web/doc.go b/pkg/web/doc.go index 2001f4ca..1425a729 100644 --- a/pkg/web/doc.go +++ b/pkg/web/doc.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/web/filter/opentracing/filter.go b/pkg/web/filter/opentracing/filter.go index 822d5e4d..e6ee9150 100644 --- a/pkg/web/filter/opentracing/filter.go +++ b/pkg/web/filter/opentracing/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -31,7 +31,6 @@ type FilterChainBuilder struct { CustomSpanFunc func(span opentracing.Span, ctx *beegoCtx.Context) } - func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.FilterFunc { return func(ctx *beegoCtx.Context) { var ( @@ -55,9 +54,21 @@ func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.Filt next(ctx) // if you think we need to do more things, feel free to create an issue to tell us - span.SetTag("status", ctx.Output.Status) - span.SetTag("method", ctx.Input.Method()) - span.SetTag("route", ctx.Input.GetData("RouterPattern")) + span.SetTag("http.status_code", ctx.ResponseWriter.Status) + span.SetTag("http.method", ctx.Input.Method()) + span.SetTag("peer.hostname", ctx.Request.Host) + span.SetTag("http.url", ctx.Request.URL.String()) + span.SetTag("http.scheme", ctx.Request.URL.Scheme) + span.SetTag("span.kind", "server") + span.SetTag("component", "beego") + if ctx.Output.IsServerError() || ctx.Output.IsClientError() { + span.SetTag("error", true) + } + span.SetTag("peer.address", ctx.Request.RemoteAddr) + span.SetTag("http.proto", ctx.Request.Proto) + + span.SetTag("beego.route", ctx.Input.GetData("RouterPattern")) + if builder.CustomSpanFunc != nil { builder.CustomSpanFunc(span, ctx) } @@ -70,7 +81,7 @@ func (builder *FilterChainBuilder) operationName(ctx *beegoCtx.Context) string { // TODO, if we support multiple servers, this need to be changed route, found := beego.BeeApp.Handlers.FindRouter(ctx) if found { - operationName = route.GetPattern() + operationName = ctx.Input.Method() + "#" + route.GetPattern() } return operationName } diff --git a/pkg/web/filter/opentracing/filter_test.go b/pkg/web/filter/opentracing/filter_test.go index 65f1f24e..750ea7a9 100644 --- a/pkg/web/filter/opentracing/filter_test.go +++ b/pkg/web/filter/opentracing/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/web/filter/prometheus/filter.go b/pkg/web/filter/prometheus/filter.go index bd47dcec..8f4b46e3 100644 --- a/pkg/web/filter/prometheus/filter.go +++ b/pkg/web/filter/prometheus/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/web/filter/prometheus/filter_test.go b/pkg/web/filter/prometheus/filter_test.go index 7d2e2acf..822892bc 100644 --- a/pkg/web/filter/prometheus/filter_test.go +++ b/pkg/web/filter/prometheus/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. From ce698aacf6a5b7d2ebe3ccc25204b65cc0fdeeb3 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 11 Aug 2020 12:06:02 +0800 Subject: [PATCH 088/207] rm some methods --- pkg/common/kv.go | 19 ------------------- pkg/common/kv_test.go | 10 ++++------ 2 files changed, 4 insertions(+), 25 deletions(-) diff --git a/pkg/common/kv.go b/pkg/common/kv.go index 26e786f9..80797aa9 100644 --- a/pkg/common/kv.go +++ b/pkg/common/kv.go @@ -41,8 +41,6 @@ type KVs interface { GetValueOr(key interface{}, defValue interface{}) interface{} Contains(key interface{}) bool IfContains(key interface{}, action func(value interface{})) KVs - Put(key interface{}, value interface{}) KVs - Clone() KVs } // SimpleKVs will store SimpleKV collection as map @@ -77,23 +75,6 @@ func (kvs *SimpleKVs) IfContains(key interface{}, action func(value interface{}) return kvs } -// Put stores the value -func (kvs *SimpleKVs) Put(key interface{}, value interface{}) KVs { - kvs.kvs[key] = value - return kvs -} - -// Clone -func (kvs *SimpleKVs) Clone() KVs { - newKVs := new(SimpleKVs) - - for key, value := range kvs.kvs { - newKVs.Put(key, value) - } - - return newKVs -} - // NewKVs creates the *KVs instance func NewKVs(kvs ...KV) KVs { res := &SimpleKVs{ diff --git a/pkg/common/kv_test.go b/pkg/common/kv_test.go index 275c6753..7b52a300 100644 --- a/pkg/common/kv_test.go +++ b/pkg/common/kv_test.go @@ -29,12 +29,10 @@ func TestKVs(t *testing.T) { assert.True(t, kvs.Contains(key)) - kvs.IfContains(key, func(value interface{}) { - kvs.Put("my-key1", "") - }) - - assert.True(t, kvs.Contains("my-key1")) - v := kvs.GetValueOr(key, 13) assert.Equal(t, 12, v) + + v = kvs.GetValueOr(`key-not-exists`, 8546) + assert.Equal(t, 8546, v) + } From 9ca9535c48b38387a7af5ff5d8c5235465ba7273 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 11 Aug 2020 16:53:31 +0800 Subject: [PATCH 089/207] fix:return error after inserting data when primary key is string --- pkg/orm/db.go | 7 ++++++- pkg/orm/db_oracle.go | 7 ++++++- pkg/orm/models_test.go | 5 +++++ pkg/orm/orm_test.go | 19 +++++++++++++++++++ 4 files changed, 36 insertions(+), 2 deletions(-) diff --git a/pkg/orm/db.go b/pkg/orm/db.go index 9a1827e8..9024cf01 100644 --- a/pkg/orm/db.go +++ b/pkg/orm/db.go @@ -484,7 +484,12 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s if isMulti { return res.RowsAffected() } - return res.LastInsertId() + + lastInsertId, err:=res.LastInsertId() + if err != nil { + DebugLog.Println("[WARN] return LastInsertId error:", err) + } + return lastInsertId, nil } return 0, err } diff --git a/pkg/orm/db_oracle.go b/pkg/orm/db_oracle.go index 5d121f83..c8e71849 100644 --- a/pkg/orm/db_oracle.go +++ b/pkg/orm/db_oracle.go @@ -126,7 +126,12 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam if isMulti { return res.RowsAffected() } - return res.LastInsertId() + + lastInsertId, err := res.LastInsertId() + if err != nil { + DebugLog.Println("[WARN] return LastInsertId error:", err) + } + return lastInsertId, nil } return 0, err } diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index ae166dc7..3c7bbbc8 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -412,6 +412,11 @@ type PtrPk struct { Positive bool } +type StrPk struct { + Id string `orm:"column(id);size(64);pk"` + Value string +} + var DBARGS = struct { Driver string Source string diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index f5242a46..1a8a88ea 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -200,6 +200,7 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(IntegerPk)) RegisterModel(new(UintPk)) RegisterModel(new(PtrPk)) + RegisterModel(new(StrPk)) err := RunSyncdb("default", true, Debug) throwFail(t, err) @@ -224,6 +225,7 @@ func TestRegisterModels(t *testing.T) { RegisterModel(new(IntegerPk)) RegisterModel(new(UintPk)) RegisterModel(new(PtrPk)) + RegisterModel(new(StrPk)) BootStrap() @@ -2487,4 +2489,21 @@ func TestInsertOrUpdate(t *testing.T) { } } +func TestStrPkInsert(t *testing.T) { + RegisterModel(new(StrPk)) + value := `StrPkValues(*56` + strPk := &StrPk{ + Id: "1", + Value: value, + } + + var err error + _, err = dORM.Insert(strPk) + throwFailNow(t, AssertIs(err, nil)) + + var vForTesting StrPk + err = dORM.QueryTable(new(StrPk)).Filter(`id`, `1`).One(&vForTesting) + throwFailNow(t, AssertIs(err, nil)) + throwFailNow(t, AssertIs(vForTesting.Value, value)) +} From 7267f5e573daa593b48954acd6002cedce9467e3 Mon Sep 17 00:00:00 2001 From: Phillip Stagnet Date: Tue, 11 Aug 2020 16:09:29 +0200 Subject: [PATCH 090/207] Add new config option into provider struct --- pkg/session/redis/sess_redis.go | 31 +++++++------- pkg/session/redis_cluster/redis_cluster.go | 36 ++++++++-------- .../redis_sentinel/sess_redis_sentinel.go | 42 +++++++++---------- 3 files changed, 54 insertions(+), 55 deletions(-) diff --git a/pkg/session/redis/sess_redis.go b/pkg/session/redis/sess_redis.go index 7e991ef5..100f9e1e 100644 --- a/pkg/session/redis/sess_redis.go +++ b/pkg/session/redis/sess_redis.go @@ -109,12 +109,15 @@ func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { // Provider redis session provider type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - poollist *redis.Client + maxlifetime int64 + savePath string + poolsize int + password string + dbNum int + idleTimeout time.Duration + idleCheckFrequency time.Duration + maxRetries int + poollist *redis.Client } // SessionInit init redis session @@ -149,25 +152,22 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } else { rp.dbNum = 0 } - var idleTimeout time.Duration = 0 if len(configs) > 4 { timeout, err := strconv.Atoi(configs[4]) if err == nil && timeout > 0 { - idleTimeout = time.Duration(timeout) * time.Second + rp.idleTimeout = time.Duration(timeout) * time.Second } } - var idleCheckFrequency time.Duration = 0 if len(configs) > 5 { checkFrequency, err := strconv.Atoi(configs[5]) if err == nil && checkFrequency > 0 { - idleCheckFrequency = time.Duration(checkFrequency) * time.Second + rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second } } - var maxRetries = 0 if len(configs) > 6 { retries, err := strconv.Atoi(configs[6]) if err == nil && retries > 0 { - maxRetries = retries + rp.maxRetries = retries } } @@ -176,9 +176,9 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { Password: rp.password, PoolSize: rp.poolsize, DB: rp.dbNum, - IdleTimeout: idleTimeout, - IdleCheckFrequency: idleCheckFrequency, - MaxRetries: maxRetries, + IdleTimeout: rp.idleTimeout, + IdleCheckFrequency: rp.idleCheckFrequency, + MaxRetries: rp.maxRetries, }) return rp.poollist.Ping().Err() @@ -249,4 +249,3 @@ func (rp *Provider) SessionAll() int { func init() { session.Register("redis", redispder) } - diff --git a/pkg/session/redis_cluster/redis_cluster.go b/pkg/session/redis_cluster/redis_cluster.go index 75dc0e63..d6f051c1 100644 --- a/pkg/session/redis_cluster/redis_cluster.go +++ b/pkg/session/redis_cluster/redis_cluster.go @@ -107,12 +107,15 @@ func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { // Provider redis_cluster session provider type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - poollist *rediss.ClusterClient + maxlifetime int64 + savePath string + poolsize int + password string + dbNum int + idleTimeout time.Duration + idleCheckFrequency time.Duration + maxRetries int + poollist *rediss.ClusterClient } // SessionInit init redis_cluster session @@ -147,35 +150,32 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } else { rp.dbNum = 0 } - var idleTimeout time.Duration = 0 if len(configs) > 4 { timeout, err := strconv.Atoi(configs[4]) if err == nil && timeout > 0 { - idleTimeout = time.Duration(timeout) * time.Second + rp.idleTimeout = time.Duration(timeout) * time.Second } } - var idleCheckFrequency time.Duration = 0 if len(configs) > 5 { checkFrequency, err := strconv.Atoi(configs[5]) if err == nil && checkFrequency > 0 { - idleCheckFrequency = time.Duration(checkFrequency) * time.Second + rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second } } - var maxRetries = 0 if len(configs) > 6 { retries, err := strconv.Atoi(configs[6]) if err == nil && retries > 0 { - maxRetries = retries + rp.maxRetries = retries } } rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ - Addrs: strings.Split(rp.savePath, ";"), - Password: rp.password, - PoolSize: rp.poolsize, - IdleTimeout: idleTimeout, - IdleCheckFrequency: idleCheckFrequency, - MaxRetries: maxRetries, + Addrs: strings.Split(rp.savePath, ";"), + Password: rp.password, + PoolSize: rp.poolsize, + IdleTimeout: rp.idleTimeout, + IdleCheckFrequency: rp.idleCheckFrequency, + MaxRetries: rp.maxRetries, }) return rp.poollist.Ping().Err() } diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel.go b/pkg/session/redis_sentinel/sess_redis_sentinel.go index da287b8d..67790096 100644 --- a/pkg/session/redis_sentinel/sess_redis_sentinel.go +++ b/pkg/session/redis_sentinel/sess_redis_sentinel.go @@ -107,13 +107,16 @@ func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { // Provider redis_sentinel session provider type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - poollist *redis.Client - masterName string + maxlifetime int64 + savePath string + poolsize int + password string + dbNum int + idleTimeout time.Duration + idleCheckFrequency time.Duration + maxRetries int + poollist *redis.Client + masterName string } // SessionInit init redis_sentinel session @@ -157,37 +160,34 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } else { rp.masterName = "mymaster" } - var idleTimeout time.Duration = 0 if len(configs) > 5 { timeout, err := strconv.Atoi(configs[4]) if err == nil && timeout > 0 { - idleTimeout = time.Duration(timeout) * time.Second + rp.idleTimeout = time.Duration(timeout) * time.Second } } - var idleCheckFrequency time.Duration = 0 if len(configs) > 6 { checkFrequency, err := strconv.Atoi(configs[5]) if err == nil && checkFrequency > 0 { - idleCheckFrequency = time.Duration(checkFrequency) * time.Second + rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second } } - var maxRetries = 0 if len(configs) > 7 { retries, err := strconv.Atoi(configs[6]) if err == nil && retries > 0 { - maxRetries = retries + rp.maxRetries = retries } } rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ - SentinelAddrs: strings.Split(rp.savePath, ";"), - Password: rp.password, - PoolSize: rp.poolsize, - DB: rp.dbNum, - MasterName: rp.masterName, - IdleTimeout: idleTimeout, - IdleCheckFrequency: idleCheckFrequency, - MaxRetries: maxRetries, + SentinelAddrs: strings.Split(rp.savePath, ";"), + Password: rp.password, + PoolSize: rp.poolsize, + DB: rp.dbNum, + MasterName: rp.masterName, + IdleTimeout: rp.idleTimeout, + IdleCheckFrequency: rp.idleCheckFrequency, + MaxRetries: rp.maxRetries, }) return rp.poollist.Ping().Err() From 813a4df3c53b5f14453dec6753d121ddfb188f8a Mon Sep 17 00:00:00 2001 From: Phillip Stagnet Date: Tue, 11 Aug 2020 16:21:43 +0200 Subject: [PATCH 091/207] Make sure expiry time is in seconds --- pkg/session/redis/sess_redis.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/session/redis/sess_redis.go b/pkg/session/redis/sess_redis.go index 100f9e1e..6e1fbae6 100644 --- a/pkg/session/redis/sess_redis.go +++ b/pkg/session/redis/sess_redis.go @@ -224,7 +224,7 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Do(c.Context(), "SET", sid, "", "EX", rp.maxlifetime) } else { c.Rename(oldsid, sid) - c.Expire(sid, time.Duration(rp.maxlifetime)) + c.Expire(sid, time.Duration(rp.maxlifetime) * time.Second) } return rp.SessionRead(sid) } From 7ce0fde171192624f108cf803813dc6251048ba7 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Thu, 13 Aug 2020 19:14:00 +0800 Subject: [PATCH 092/207] fix:return error when calling ``InsertOrUpdate`` is successful with string primary key --- pkg/orm/db.go | 9 +++++++-- pkg/orm/db_mysql.go | 7 ++++++- pkg/orm/orm_test.go | 20 +++++++++++++++++--- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/pkg/orm/db.go b/pkg/orm/db.go index dc4b5a3f..dea8845f 100644 --- a/pkg/orm/db.go +++ b/pkg/orm/db.go @@ -486,7 +486,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s return res.RowsAffected() } - lastInsertId, err:=res.LastInsertId() + lastInsertId, err := res.LastInsertId() if err != nil { DebugLog.Println("[WARN] return LastInsertId error:", err) } @@ -591,7 +591,12 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a if isMulti { return res.RowsAffected() } - return res.LastInsertId() + + lastInsertId, err := res.LastInsertId() + if err != nil { + DebugLog.Println("[WARN] return LastInsertId error:", err) + } + return lastInsertId, nil } return 0, err } diff --git a/pkg/orm/db_mysql.go b/pkg/orm/db_mysql.go index 6e99058e..0db26e5b 100644 --- a/pkg/orm/db_mysql.go +++ b/pkg/orm/db_mysql.go @@ -164,7 +164,12 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val if isMulti { return res.RowsAffected() } - return res.LastInsertId() + + lastInsertId, err := res.LastInsertId() + if err != nil { + DebugLog.Println("[WARN] return LastInsertId error:", err) + } + return lastInsertId, nil } return 0, err } diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index fd752a94..d4a47686 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -2528,19 +2528,33 @@ func TestInsertOrUpdate(t *testing.T) { func TestStrPkInsert(t *testing.T) { RegisterModel(new(StrPk)) + pk := `1` value := `StrPkValues(*56` strPk := &StrPk{ - Id: "1", + Id: pk, Value: value, } var err error _, err = dORM.Insert(strPk) throwFailNow(t, AssertIs(err, nil)) - + var vForTesting StrPk - err = dORM.QueryTable(new(StrPk)).Filter(`id`, `1`).One(&vForTesting) + err = dORM.QueryTable(new(StrPk)).Filter(`id`, pk).One(&vForTesting) throwFailNow(t, AssertIs(err, nil)) throwFailNow(t, AssertIs(vForTesting.Value, value)) + + value2 := `s8s5da7as` + strPkForUpsert := &StrPk{ + Id: pk, + Value: value2, + } + _, err = dORM.InsertOrUpdate(strPkForUpsert, `id`) + throwFailNow(t, AssertIs(err, nil)) + + var vForTesting2 StrPk + err = dORM.QueryTable(new(StrPk)).Filter(`id`, pk).One(&vForTesting2) + throwFailNow(t, AssertIs(err, nil)) + throwFailNow(t, AssertIs(vForTesting2.Value, value2)) } From bdec93986be1b8aafbf922ce563139c3ce595a75 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Thu, 13 Aug 2020 14:14:10 +0800 Subject: [PATCH 093/207] Bean: Support autowire by tag Orm: Support default value filter --- go.mod | 1 + go.sum | 5 + pkg/bean/context.go | 21 ++ pkg/bean/doc.go | 17 ++ pkg/bean/factory.go | 25 ++ pkg/bean/metadata.go | 28 +++ pkg/bean/tag_auto_wire_bean_factory.go | 231 ++++++++++++++++++ pkg/bean/tag_auto_wire_bean_factory_test.go | 75 ++++++ pkg/bean/time_type_adapter.go | 36 +++ pkg/bean/time_type_adapter_test.go | 29 +++ pkg/bean/type_adapter.go | 26 ++ pkg/orm/filter/bean/default_value_filter.go | 136 +++++++++++ .../filter/bean/default_value_filter_test.go | 73 ++++++ pkg/orm/invocation.go | 9 + 14 files changed, 712 insertions(+) create mode 100644 pkg/bean/context.go create mode 100644 pkg/bean/doc.go create mode 100644 pkg/bean/factory.go create mode 100644 pkg/bean/metadata.go create mode 100644 pkg/bean/tag_auto_wire_bean_factory.go create mode 100644 pkg/bean/tag_auto_wire_bean_factory_test.go create mode 100644 pkg/bean/time_type_adapter.go create mode 100644 pkg/bean/time_type_adapter_test.go create mode 100644 pkg/bean/type_adapter.go create mode 100644 pkg/orm/filter/bean/default_value_filter.go create mode 100644 pkg/orm/filter/bean/default_value_filter_test.go diff --git a/go.mod b/go.mod index 3ad8576a..b3f2c2e7 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( golang.org/x/tools v0.0.0-20200117065230-39095c1d176c google.golang.org/grpc v1.31.0 // indirect gopkg.in/yaml.v2 v2.2.8 + github.com/pkg/errors v0.9.1 ) replace golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 => github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85 diff --git a/go.sum b/go.sum index 12b76333..a54afad6 100644 --- a/go.sum +++ b/go.sum @@ -135,6 +135,8 @@ github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= @@ -185,6 +187,7 @@ golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -219,6 +222,8 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20200117065230-39095c1d176c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= diff --git a/pkg/bean/context.go b/pkg/bean/context.go new file mode 100644 index 00000000..93261628 --- /dev/null +++ b/pkg/bean/context.go @@ -0,0 +1,21 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +// ApplicationContext define for future +// when we decide to support DI, IoC, this will be core API +type ApplicationContext interface { + +} diff --git a/pkg/bean/doc.go b/pkg/bean/doc.go new file mode 100644 index 00000000..212e8aaf --- /dev/null +++ b/pkg/bean/doc.go @@ -0,0 +1,17 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// bean is a basic package +// it should not depend on other modules except common module, log module and config module +package bean diff --git a/pkg/bean/factory.go b/pkg/bean/factory.go new file mode 100644 index 00000000..698474c4 --- /dev/null +++ b/pkg/bean/factory.go @@ -0,0 +1,25 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" +) + +// AutoWireBeanFactory wire the bean based on ApplicationContext and context.Context +type AutoWireBeanFactory interface { + // AutoWire will wire the bean. + AutoWire(ctx context.Context, appCtx ApplicationContext, bean interface{}) error +} \ No newline at end of file diff --git a/pkg/bean/metadata.go b/pkg/bean/metadata.go new file mode 100644 index 00000000..8c423692 --- /dev/null +++ b/pkg/bean/metadata.go @@ -0,0 +1,28 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +// BeanMetadata, in other words, bean's config. +// it could be read from config file +type BeanMetadata struct { + // Fields: field name => field metadata + Fields map[string]*FieldMetadata +} + +// FieldMetadata contains metadata +type FieldMetadata struct { + // default value in string format + DftValue string +} diff --git a/pkg/bean/tag_auto_wire_bean_factory.go b/pkg/bean/tag_auto_wire_bean_factory.go new file mode 100644 index 00000000..ea8fd907 --- /dev/null +++ b/pkg/bean/tag_auto_wire_bean_factory.go @@ -0,0 +1,231 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "fmt" + "reflect" + "strconv" + + "github.com/pkg/errors" + + "github.com/astaxie/beego/pkg/logs" +) + +const DefaultValueTagKey = "default" + +// TagAutoWireBeanFactory wire the bean based on Fields' tag +// if field's value is "zero value", we will execute injection +// see reflect.Value.IsZero() +// If field's kind is one of(reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Slice +// reflect.UnsafePointer, reflect.Array, reflect.Uintptr, reflect.Complex64, reflect.Complex128 +// reflect.Ptr, reflect.Struct), +// it will be ignored +type TagAutoWireBeanFactory struct { + // we allow user register their TypeAdapter + Adapters map[string]TypeAdapter + + // FieldTagParser is an extension point which means that you can custom how to read field's metadata from tag + FieldTagParser func(field reflect.StructField) *FieldMetadata +} + +// NewTagAutoWireBeanFactory create an instance of TagAutoWireBeanFactory +// by default, we register Time adapter, the time will be parse by using layout "2006-01-02 15:04:05" +// If you need more adapter, you can implement interface TypeAdapter +func NewTagAutoWireBeanFactory() *TagAutoWireBeanFactory { + return &TagAutoWireBeanFactory{ + Adapters: map[string]TypeAdapter{ + "Time": &TimeTypeAdapter{Layout: "2006-01-02 15:04:05"}, + }, + + FieldTagParser: func(field reflect.StructField) *FieldMetadata { + return &FieldMetadata{ + DftValue: field.Tag.Get(DefaultValueTagKey), + } + }, + } +} + +// AutoWire use value from appCtx to wire the bean, or use default value, or do nothing +func (t *TagAutoWireBeanFactory) AutoWire(ctx context.Context, appCtx ApplicationContext, bean interface{}) error { + if bean == nil { + return nil + } + + v := reflect.Indirect(reflect.ValueOf(bean)) + + bm := t.getConfig(v) + + // field name, field metadata + for fn, fm := range bm.Fields { + + fValue := v.FieldByName(fn) + if len(fm.DftValue) == 0 || !t.needInject(fValue) || !fValue.CanSet() { + continue + } + + // handle type adapter + typeName := fValue.Type().Name() + if adapter, ok := t.Adapters[typeName]; ok { + dftValue, err := adapter.DefaultValue(ctx, fm.DftValue) + if err == nil { + fValue.Set(reflect.ValueOf(dftValue)) + continue + } else { + return err + } + } + + switch fValue.Kind() { + case reflect.Bool: + if v, err := strconv.ParseBool(fm.DftValue); err != nil { + return errors.WithMessage(err, + fmt.Sprintf("can not convert the field[%s]'s default value[%s] to bool value", + fn, fm.DftValue)) + } else { + fValue.SetBool(v) + continue + } + case reflect.Int: + if err := t.setIntXValue(fm.DftValue, 0, fn, fValue); err != nil { + return err + } + continue + case reflect.Int8: + if err := t.setIntXValue(fm.DftValue, 8, fn, fValue); err != nil { + return err + } + continue + case reflect.Int16: + if err := t.setIntXValue(fm.DftValue, 16, fn, fValue); err != nil { + return err + } + continue + + case reflect.Int32: + if err := t.setIntXValue(fm.DftValue, 32, fn, fValue); err != nil { + return err + } + continue + + case reflect.Int64: + if err := t.setIntXValue(fm.DftValue, 64, fn, fValue); err != nil { + return err + } + continue + + case reflect.Uint: + if err := t.setUIntXValue(fm.DftValue, 0, fn, fValue); err != nil { + return err + } + + case reflect.Uint8: + if err := t.setUIntXValue(fm.DftValue, 8, fn, fValue); err != nil { + return err + } + continue + + case reflect.Uint16: + if err := t.setUIntXValue(fm.DftValue, 16, fn, fValue); err != nil { + return err + } + continue + case reflect.Uint32: + if err := t.setUIntXValue(fm.DftValue, 32, fn, fValue); err != nil { + return err + } + continue + + case reflect.Uint64: + if err := t.setUIntXValue(fm.DftValue, 64, fn, fValue); err != nil { + return err + } + continue + + case reflect.Float32: + if err := t.setFloatXValue(fm.DftValue, 32, fn, fValue); err != nil { + return err + } + continue + case reflect.Float64: + if err := t.setFloatXValue(fm.DftValue, 64, fn, fValue); err != nil { + return err + } + continue + + case reflect.String: + fValue.SetString(fm.DftValue) + continue + + // case reflect.Ptr: + // case reflect.Struct: + default: + logs.Warn("this field[%s] has default setting, but we don't support this type: %s", + fn, fValue.Kind().String()) + } + } + return nil +} + +func (t *TagAutoWireBeanFactory) setFloatXValue(dftValue string, bitSize int, fn string, fv reflect.Value) error { + if v, err := strconv.ParseFloat(dftValue, bitSize); err != nil { + return errors.WithMessage(err, + fmt.Sprintf("can not convert the field[%s]'s default value[%s] to float%d value", + fn, dftValue, bitSize)) + } else { + fv.SetFloat(v) + return nil + } +} + +func (t *TagAutoWireBeanFactory) setUIntXValue(dftValue string, bitSize int, fn string, fv reflect.Value) error { + if v, err := strconv.ParseUint(dftValue, 10, bitSize); err != nil { + return errors.WithMessage(err, + fmt.Sprintf("can not convert the field[%s]'s default value[%s] to uint%d value", + fn, dftValue, bitSize)) + } else { + fv.SetUint(v) + return nil + } +} + +func (t *TagAutoWireBeanFactory) setIntXValue(dftValue string, bitSize int, fn string, fv reflect.Value) error { + if v, err := strconv.ParseInt(dftValue, 10, bitSize); err != nil { + return errors.WithMessage(err, + fmt.Sprintf("can not convert the field[%s]'s default value[%s] to int%d value", + fn, dftValue, bitSize)) + } else { + fv.SetInt(v) + return nil + } +} + +func (t *TagAutoWireBeanFactory) needInject(fValue reflect.Value) bool { + return fValue.IsZero() +} + +// getConfig never return nil +func (t *TagAutoWireBeanFactory) getConfig(beanValue reflect.Value) *BeanMetadata { + fms := make(map[string]*FieldMetadata, beanValue.NumField()) + for i := 0; i < beanValue.NumField(); i++ { + // f => StructField + f := beanValue.Type().Field(i) + fms[f.Name] = t.FieldTagParser(f) + } + return &BeanMetadata{ + Fields: fms, + } +} diff --git a/pkg/bean/tag_auto_wire_bean_factory_test.go b/pkg/bean/tag_auto_wire_bean_factory_test.go new file mode 100644 index 00000000..2d83c537 --- /dev/null +++ b/pkg/bean/tag_auto_wire_bean_factory_test.go @@ -0,0 +1,75 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTagAutoWireBeanFactory_AutoWire(t *testing.T) { + factory := NewTagAutoWireBeanFactory() + bm := &ComplicateStruct{} + err := factory.AutoWire(context.Background(), nil, bm) + assert.Nil(t, err) + assert.Equal(t, 12, bm.IntValue) + assert.Equal(t, "hello, strValue", bm.StrValue) + + assert.Equal(t, int8(8), bm.Int8Value) + assert.Equal(t, int16(16), bm.Int16Value) + assert.Equal(t, int32(32), bm.Int32Value) + assert.Equal(t, int64(64), bm.Int64Value) + + assert.Equal(t, uint(13), bm.UintValue) + assert.Equal(t, uint8(88), bm.Uint8Value) + assert.Equal(t, uint16(1616), bm.Uint16Value) + assert.Equal(t, uint32(3232), bm.Uint32Value) + assert.Equal(t, uint64(6464), bm.Uint64Value) + + assert.Equal(t, float32(32.32), bm.Float32Value) + assert.Equal(t, float64(64.64), bm.Float64Value) + + assert.True(t, bm.BoolValue) + assert.Equal(t, 0, bm.ignoreInt) + + assert.NotNil(t, bm.TimeValue) +} + +type ComplicateStruct struct { + IntValue int `default:"12"` + StrValue string `default:"hello, strValue"` + Int8Value int8 `default:"8"` + Int16Value int16 `default:"16"` + Int32Value int32 `default:"32"` + Int64Value int64 `default:"64"` + + UintValue uint `default:"13"` + Uint8Value uint8 `default:"88"` + Uint16Value uint16 `default:"1616"` + Uint32Value uint32 `default:"3232"` + Uint64Value uint64 `default:"6464"` + + Float32Value float32 `default:"32.32"` + Float64Value float64 `default:"64.64"` + + BoolValue bool `default:"true"` + + ignoreInt int `default:"11"` + + TimeValue time.Time `default:"2018-02-03 12:13:14.000"` +} diff --git a/pkg/bean/time_type_adapter.go b/pkg/bean/time_type_adapter.go new file mode 100644 index 00000000..846eb694 --- /dev/null +++ b/pkg/bean/time_type_adapter.go @@ -0,0 +1,36 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "time" + +) + +// TimeTypeAdapter process the time.Time +type TimeTypeAdapter struct { + Layout string +} + +// DefaultValue parse the DftValue to time.Time +// and if the DftValue == now +// time.Now() is returned +func (t *TimeTypeAdapter) DefaultValue(ctx context.Context, dftValue string) (interface{}, error) { + if dftValue == "now"{ + return time.Now(), nil + } + return time.Parse(t.Layout, dftValue) +} diff --git a/pkg/bean/time_type_adapter_test.go b/pkg/bean/time_type_adapter_test.go new file mode 100644 index 00000000..9c097048 --- /dev/null +++ b/pkg/bean/time_type_adapter_test.go @@ -0,0 +1,29 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTimeTypeAdapter_DefaultValue(t *testing.T) { + typeAdapter := &TimeTypeAdapter{Layout: "2006-01-02 15:04:05"} + tm, err := typeAdapter.DefaultValue(context.Background(), "2018-02-03 12:34:11") + assert.Nil(t, err) + assert.NotNil(t, tm) +} diff --git a/pkg/bean/type_adapter.go b/pkg/bean/type_adapter.go new file mode 100644 index 00000000..ba675b64 --- /dev/null +++ b/pkg/bean/type_adapter.go @@ -0,0 +1,26 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" +) + +// TypeAdapter is an abstraction that define some behavior of target type +// usually, we don't use this to support basic type since golang has many restriction for basic types +// This is an important extension point +type TypeAdapter interface { + DefaultValue(ctx context.Context, dftValue string) (interface{}, error) +} diff --git a/pkg/orm/filter/bean/default_value_filter.go b/pkg/orm/filter/bean/default_value_filter.go new file mode 100644 index 00000000..80aef43d --- /dev/null +++ b/pkg/orm/filter/bean/default_value_filter.go @@ -0,0 +1,136 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "reflect" + "strings" + + "github.com/astaxie/beego/pkg/bean" + "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/orm" +) + +// DefaultValueFilterChainBuilder only works for InsertXXX method, +// But InsertOrUpdate and InsertOrUpdateWithCtx is more dangerous than other methods. +// so we won't handle those two methods unless you set includeInsertOrUpdate to true +// And if the element is not pointer, this filter doesn't work +type DefaultValueFilterChainBuilder struct { + factory bean.AutoWireBeanFactory + compatibleWithOldStyle bool + + // only the includeInsertOrUpdate is true, this filter will handle those two methods + includeInsertOrUpdate bool +} + +// NewDefaultValueFilterChainBuilder will create an instance of DefaultValueFilterChainBuilder +// In beego v1.x, the default value config looks like orm:default(xxxx) +// But the default value in 2.x is default:xxx +// so if you want to be compatible with v1.x, please pass true as compatibleWithOldStyle +func NewDefaultValueFilterChainBuilder(typeAdapters map[string]bean.TypeAdapter, + includeInsertOrUpdate bool, + compatibleWithOldStyle bool) *DefaultValueFilterChainBuilder { + factory := bean.NewTagAutoWireBeanFactory() + + if compatibleWithOldStyle { + newParser := factory.FieldTagParser + factory.FieldTagParser = func(field reflect.StructField) *bean.FieldMetadata { + if newParser != nil && field.Tag.Get(bean.DefaultValueTagKey) != "" { + return newParser(field) + } else { + res := &bean.FieldMetadata{} + ormMeta := field.Tag.Get("orm") + ormMetaParts := strings.Split(ormMeta, ";") + for _, p := range ormMetaParts { + if strings.HasPrefix(p, "default(") && strings.HasSuffix(p, ")") { + res.DftValue = p[8 : len(p)-1] + } + } + return res + } + } + } + + for k, v := range typeAdapters { + factory.Adapters[k] = v + } + + return &DefaultValueFilterChainBuilder{ + factory: factory, + compatibleWithOldStyle: compatibleWithOldStyle, + includeInsertOrUpdate: includeInsertOrUpdate, + } +} + +func (d *DefaultValueFilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { + return func(ctx context.Context, inv *orm.Invocation) { + switch inv.Method { + case "Insert", "InsertWithCtx": + d.handleInsert(ctx, inv) + break + case "InsertOrUpdate", "InsertOrUpdateWithCtx": + d.handleInsertOrUpdate(ctx, inv) + break + case "InsertMulti", "InsertMultiWithCtx": + d.handleInsertMulti(ctx, inv) + break + } + next(ctx, inv) + } +} + +func (d *DefaultValueFilterChainBuilder) handleInsert(ctx context.Context, inv *orm.Invocation) { + d.setDefaultValue(ctx, inv.Args[0]) +} + +func (d *DefaultValueFilterChainBuilder) handleInsertOrUpdate(ctx context.Context, inv *orm.Invocation) { + if d.includeInsertOrUpdate { + ins := inv.Args[0] + if ins == nil { + return + } + + pkName := inv.GetPkFieldName() + pkField := reflect.Indirect(reflect.ValueOf(ins)).FieldByName(pkName) + + if pkField.IsZero() { + d.setDefaultValue(ctx, ins) + } + } +} + +func (d *DefaultValueFilterChainBuilder) handleInsertMulti(ctx context.Context, inv *orm.Invocation) { + mds := inv.Args[1] + + if t := reflect.TypeOf(mds).Kind(); t != reflect.Array && t != reflect.Slice { + // do nothing + return + } + + mdsArr := reflect.Indirect(reflect.ValueOf(mds)) + for i := 0; i < mdsArr.Len(); i++ { + d.setDefaultValue(ctx, mdsArr.Index(i).Interface()) + } + logs.Warn("%v", mdsArr.Index(0).Interface()) +} + +func (d *DefaultValueFilterChainBuilder) setDefaultValue(ctx context.Context, ins interface{}) { + err := d.factory.AutoWire(ctx, nil, ins) + if err != nil { + logs.Error("try to wire the bean for orm.Insert failed. "+ + "the default value is not set: %v, ", err) + } +} diff --git a/pkg/orm/filter/bean/default_value_filter_test.go b/pkg/orm/filter/bean/default_value_filter_test.go new file mode 100644 index 00000000..6b038f27 --- /dev/null +++ b/pkg/orm/filter/bean/default_value_filter_test.go @@ -0,0 +1,73 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/orm" +) + +func TestDefaultValueFilterChainBuilder_FilterChain(t *testing.T) { + builder := NewDefaultValueFilterChainBuilder(nil, true, true) + o := orm.NewFilterOrmDecorator(&defaultValueTestOrm{}, builder.FilterChain) + + // test insert + entity := &DefaultValueTestEntity{} + _, _ = o.Insert(entity) + assert.Equal(t, 12, entity.Age) + assert.Equal(t, 13, entity.AgeInOldStyle) + assert.Equal(t, 0, entity.AgeIgnore) + + // test InsertOrUpdate + entity = &DefaultValueTestEntity{} + orm.RegisterModel(entity) + + + _, _ = o.InsertOrUpdate(entity) + assert.Equal(t, 12, entity.Age) + assert.Equal(t, 13, entity.AgeInOldStyle) + + // we won't set the default value because we find the pk is not Zero value + entity.Id = 3 + entity.AgeInOldStyle = 0 + _, _ = o.InsertOrUpdate(entity) + assert.Equal(t, 0, entity.AgeInOldStyle) + + entity = &DefaultValueTestEntity{} + + // the entity is not array, it will be ignored + _, _ = o.InsertMulti(3, entity) + assert.Equal(t, 0, entity.Age) + assert.Equal(t, 0, entity.AgeInOldStyle) + + _, _ = o.InsertMulti(3, []*DefaultValueTestEntity{entity}) + assert.Equal(t, 12, entity.Age) + assert.Equal(t, 13, entity.AgeInOldStyle) + +} + +type defaultValueTestOrm struct { + orm.DoNothingOrm +} + +type DefaultValueTestEntity struct { + Id int`orm:pk` + Age int `default:"12"` + AgeInOldStyle int `orm:"default(13);bee()"` + AgeIgnore int +} diff --git a/pkg/orm/invocation.go b/pkg/orm/invocation.go index 1c9fee09..586ec573 100644 --- a/pkg/orm/invocation.go +++ b/pkg/orm/invocation.go @@ -46,3 +46,12 @@ func (inv *Invocation) GetTableName() string { func (inv *Invocation) execute() { inv.f() } + +// GetPkFieldName return the primary key of this table +// if not found, "" is returned +func (inv *Invocation) GetPkFieldName() string { + if inv.mi.fields.pk != nil { + return inv.mi.fields.pk.name + } + return "" +} From 139c393f08b0c69b1aa6cf5ef932c5f05f2b56b3 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Fri, 14 Aug 2020 09:59:11 +0800 Subject: [PATCH 094/207] add const ErrLastInsertIdUnavailable --- pkg/orm/db.go | 4 ++-- pkg/orm/db_mysql.go | 2 +- pkg/orm/db_oracle.go | 2 +- pkg/orm/orm.go | 2 ++ 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pkg/orm/db.go b/pkg/orm/db.go index dea8845f..2477b132 100644 --- a/pkg/orm/db.go +++ b/pkg/orm/db.go @@ -488,7 +488,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s lastInsertId, err := res.LastInsertId() if err != nil { - DebugLog.Println("[WARN] return LastInsertId error:", err) + DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) } return lastInsertId, nil } @@ -594,7 +594,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a lastInsertId, err := res.LastInsertId() if err != nil { - DebugLog.Println("[WARN] return LastInsertId error:", err) + DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) } return lastInsertId, nil } diff --git a/pkg/orm/db_mysql.go b/pkg/orm/db_mysql.go index 0db26e5b..11665fb2 100644 --- a/pkg/orm/db_mysql.go +++ b/pkg/orm/db_mysql.go @@ -167,7 +167,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val lastInsertId, err := res.LastInsertId() if err != nil { - DebugLog.Println("[WARN] return LastInsertId error:", err) + DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) } return lastInsertId, nil } diff --git a/pkg/orm/db_oracle.go b/pkg/orm/db_oracle.go index 91a30f81..5177fb89 100644 --- a/pkg/orm/db_oracle.go +++ b/pkg/orm/db_oracle.go @@ -153,7 +153,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam lastInsertId, err := res.LastInsertId() if err != nil { - DebugLog.Println("[WARN] return LastInsertId error:", err) + DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) } return lastInsertId, nil } diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 895636b6..5d81c764 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -85,6 +85,8 @@ var ( ErrStmtClosed = errors.New(" stmt already closed") ErrArgs = errors.New(" args error may be empty") ErrNotImplement = errors.New("have not implement") + + ErrLastInsertIdUnavailable = errors.New(" last insert id is unavailable") ) // Params stores the Params From 739b8bab0c0cbeef3f7ef23300eee506fd3f26f2 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Fri, 14 Aug 2020 10:31:08 +0800 Subject: [PATCH 095/207] fix UT --- pkg/orm/orm_test.go | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index d4a47686..0d4451cd 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -2538,7 +2538,7 @@ func TestStrPkInsert(t *testing.T) { var err error _, err = dORM.Insert(strPk) throwFailNow(t, AssertIs(err, nil)) - + var vForTesting StrPk err = dORM.QueryTable(new(StrPk)).Filter(`id`, pk).One(&vForTesting) throwFailNow(t, AssertIs(err, nil)) @@ -2549,12 +2549,19 @@ func TestStrPkInsert(t *testing.T) { Id: pk, Value: value2, } + _, err = dORM.InsertOrUpdate(strPkForUpsert, `id`) - throwFailNow(t, AssertIs(err, nil)) - - var vForTesting2 StrPk - err = dORM.QueryTable(new(StrPk)).Filter(`id`, pk).One(&vForTesting2) - throwFailNow(t, AssertIs(err, nil)) - throwFailNow(t, AssertIs(vForTesting2.Value, value2)) + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + var vForTesting2 StrPk + err = dORM.QueryTable(new(StrPk)).Filter(`id`, pk).One(&vForTesting2) + throwFailNow(t, AssertIs(err, nil)) + throwFailNow(t, AssertIs(vForTesting2.Value, value2)) + } } From 7b899aa9af47abe594a8bee1aa0db625310c2022 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Fri, 14 Aug 2020 15:09:47 +0800 Subject: [PATCH 096/207] add ErrLastInsertIdUnavailable --- pkg/orm/db.go | 8 ++++++-- pkg/orm/db_mysql.go | 4 +++- pkg/orm/db_oracle.go | 4 +++- pkg/orm/orm_test.go | 5 ++++- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/pkg/orm/db.go b/pkg/orm/db.go index 2477b132..0b6d8ac1 100644 --- a/pkg/orm/db.go +++ b/pkg/orm/db.go @@ -489,8 +489,10 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s lastInsertId, err := res.LastInsertId() if err != nil { DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) + return lastInsertId, ErrLastInsertIdUnavailable + }else{ + return lastInsertId, nil } - return lastInsertId, nil } return 0, err } @@ -595,8 +597,10 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a lastInsertId, err := res.LastInsertId() if err != nil { DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) + return lastInsertId, ErrLastInsertIdUnavailable + }else{ + return lastInsertId, nil } - return lastInsertId, nil } return 0, err } diff --git a/pkg/orm/db_mysql.go b/pkg/orm/db_mysql.go index 11665fb2..efa5a50b 100644 --- a/pkg/orm/db_mysql.go +++ b/pkg/orm/db_mysql.go @@ -168,8 +168,10 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val lastInsertId, err := res.LastInsertId() if err != nil { DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) + return lastInsertId, ErrLastInsertIdUnavailable + }else{ + return lastInsertId, nil } - return lastInsertId, nil } return 0, err } diff --git a/pkg/orm/db_oracle.go b/pkg/orm/db_oracle.go index 5177fb89..d384d33e 100644 --- a/pkg/orm/db_oracle.go +++ b/pkg/orm/db_oracle.go @@ -154,8 +154,10 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam lastInsertId, err := res.LastInsertId() if err != nil { DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) + return lastInsertId, ErrLastInsertIdUnavailable + }else{ + return lastInsertId, nil } - return lastInsertId, nil } return 0, err } diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index 0d4451cd..c759309e 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -2537,7 +2537,9 @@ func TestStrPkInsert(t *testing.T) { var err error _, err = dORM.Insert(strPk) - throwFailNow(t, AssertIs(err, nil)) + if err != ErrLastInsertIdUnavailable { + throwFailNow(t, AssertIs(err, nil)) + } var vForTesting StrPk err = dORM.QueryTable(new(StrPk)).Filter(`id`, pk).One(&vForTesting) @@ -2554,6 +2556,7 @@ func TestStrPkInsert(t *testing.T) { if err != nil { fmt.Println(err) if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else if err == ErrLastInsertIdUnavailable { } else { throwFailNow(t, err) } From 7442919f5ace987ea5cc2970ca82905698662290 Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Fri, 14 Aug 2020 10:57:25 +0800 Subject: [PATCH 097/207] fix issue #3776 --- pkg/orm/models_test.go | 20 ++++++++++++++------ pkg/orm/orm_raw.go | 21 +++++++++++++++++++-- pkg/orm/orm_test.go | 15 ++++++++++++++- 3 files changed, 47 insertions(+), 9 deletions(-) diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index b53da05b..7fba89b1 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -54,18 +54,24 @@ func (e *SliceStringField) FieldType() int { } func (e *SliceStringField) SetRaw(value interface{}) error { - switch d := value.(type) { - case []string: - e.Set(d) - case string: - if len(d) > 0 { - parts := strings.Split(d, ",") + f := func(str string) { + if len(str) > 0 { + parts := strings.Split(str, ",") v := make([]string, 0, len(parts)) for _, p := range parts { v = append(v, strings.TrimSpace(p)) } e.Set(v) } + } + + switch d := value.(type) { + case []string: + e.Set(d) + case string: + f(d) + case []byte: + f(string(d)) default: return fmt.Errorf(" unknown value `%v`", value) } @@ -97,6 +103,8 @@ func (e *JSONFieldTest) SetRaw(value interface{}) error { switch d := value.(type) { case string: return json.Unmarshal([]byte(d), e) + case []byte: + return json.Unmarshal(d, e) default: return fmt.Errorf(" unknown value `%v`", value) } diff --git a/pkg/orm/orm_raw.go b/pkg/orm/orm_raw.go index 2f214f93..687f7099 100644 --- a/pkg/orm/orm_raw.go +++ b/pkg/orm/orm_raw.go @@ -17,6 +17,7 @@ package orm import ( "database/sql" "fmt" + "github.com/pkg/errors" "reflect" "time" ) @@ -369,7 +370,15 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { field.Set(mf) field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) } - o.setFieldValue(field, value) + if fi.isFielder { + fd := field.Addr().Interface().(Fielder) + err := fd.SetRaw(value) + if err != nil { + return errors.Errorf("set raw error:%s", err) + } + } else { + o.setFieldValue(field, value) + } } } } else { @@ -510,7 +519,15 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { field.Set(mf) field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) } - o.setFieldValue(field, value) + if fi.isFielder { + fd := field.Addr().Interface().(Fielder) + err := fd.SetRaw(value) + if err != nil { + return 0, errors.Errorf("set raw error:%s", err) + } + } else { + o.setFieldValue(field, value) + } } } } else { diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index fd752a94..6e65ad7a 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -777,6 +777,20 @@ func TestCustomField(t *testing.T) { throwFailNow(t, AssertIs(user.Extra.Name, "beego")) throwFailNow(t, AssertIs(user.Extra.Data, "orm")) + + var users []User + Q := dDbBaser.TableQuote() + n, err := dORM.Raw(fmt.Sprintf("SELECT * FROM %suser%s where id=?", Q, Q), 2).QueryRows(&users) + throwFailNow(t, err) + throwFailNow(t, AssertIs(n, 1)) + throwFailNow(t, AssertIs(users[0].Extra.Name, "beego")) + throwFailNow(t, AssertIs(users[0].Extra.Data, "orm")) + + user = User{} + err = dORM.Raw(fmt.Sprintf("SELECT * FROM %suser%s where id=?", Q, Q), 2).QueryRow(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.Extra.Name, "beego")) + throwFailNow(t, AssertIs(user.Extra.Data, "orm")) } func TestExpr(t *testing.T) { @@ -2543,4 +2557,3 @@ func TestStrPkInsert(t *testing.T) { throwFailNow(t, AssertIs(err, nil)) throwFailNow(t, AssertIs(vForTesting.Value, value)) } - From b4a85c8f13e0ac54cac93dfde144aa000f1540be Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 15 Aug 2020 09:39:58 +0800 Subject: [PATCH 098/207] Remove files --- .travis.yml | 2 +- admin.go | 461 --- admin_test.go | 249 -- adminui.go | 356 --- app.go | 516 ---- beego.go | 132 - build_info.go | 34 - cache/README.md | 59 - cache/cache.go | 103 - cache/cache_test.go | 191 -- cache/conv.go | 100 - cache/conv_test.go | 143 - cache/file.go | 258 -- cache/memcache/memcache.go | 188 -- cache/memcache/memcache_test.go | 108 - cache/memory.go | 256 -- cache/redis/redis.go | 272 -- cache/redis/redis_test.go | 144 - cache/ssdb/ssdb.go | 231 -- cache/ssdb/ssdb_test.go | 104 - config.go | 557 ---- config/config.go | 259 -- config/config_test.go | 55 - config/env/env.go | 92 - config/env/env_test.go | 75 - config/fake.go | 151 - config/ini.go | 524 ---- config/ini_test.go | 190 -- config/json.go | 290 -- config/json_test.go | 222 -- config/xml/xml.go | 248 -- config/xml/xml_test.go | 125 - config/yaml/yaml.go | 337 --- config/yaml/yaml_test.go | 115 - config_test.go | 146 - context/acceptencoder.go | 232 -- context/acceptencoder_test.go | 59 - context/context.go | 263 -- context/context_test.go | 54 - context/input.go | 709 ----- context/input_test.go | 217 -- context/output.go | 413 --- context/param/conv.go | 78 - context/param/methodparams.go | 69 - context/param/options.go | 37 - context/param/parsers.go | 149 - context/param/parsers_test.go | 84 - context/renderer.go | 12 - context/response.go | 27 - controller.go | 774 ----- controller_test.go | 181 -- doc.go | 19 - error.go | 492 ---- error_test.go | 88 - filter.go | 47 - filter_test.go | 68 - flash.go | 119 - flash_test.go | 54 - fs.go | 77 - go.mod | 1 + go.sum | 26 + grace/grace.go | 175 -- grace/server.go | 362 --- hooks.go | 104 - httplib/README.md | 97 - httplib/httplib.go | 697 ----- httplib/httplib_test.go | 286 -- logs/README.md | 72 - logs/accesslog.go | 83 - logs/alils/alils.go | 186 -- logs/alils/config.go | 13 - logs/alils/log.pb.go | 1038 ------- logs/alils/log_config.go | 42 - logs/alils/log_project.go | 819 ------ logs/alils/log_store.go | 271 -- logs/alils/machine_group.go | 91 - logs/alils/request.go | 62 - logs/alils/signature.go | 111 - logs/conn.go | 119 - logs/conn_test.go | 80 - logs/console.go | 99 - logs/console_test.go | 64 - logs/es/es.go | 102 - logs/file.go | 409 --- logs/file_test.go | 420 --- logs/jianliao.go | 72 - logs/log.go | 669 ----- logs/logger.go | 176 -- logs/logger_test.go | 57 - logs/multifile.go | 119 - logs/multifile_test.go | 78 - logs/slack.go | 60 - logs/smtp.go | 149 - logs/smtp_test.go | 24 - metric/prometheus.go | 101 - metric/prometheus_test.go | 42 - migration/ddl.go | 395 --- migration/doc.go | 32 - migration/migration.go | 330 --- mime.go | 556 ---- namespace.go | 433 --- namespace_test.go | 168 -- orm/README.md | 159 -- orm/cmd.go | 283 -- orm/cmd_utils.go | 320 --- orm/db.go | 1908 ------------- orm/db_alias.go | 487 ---- orm/db_mysql.go | 189 -- orm/db_oracle.go | 142 - orm/db_postgres.go | 195 -- orm/db_sqlite.go | 167 -- orm/db_tables.go | 482 ---- orm/db_tidb.go | 63 - orm/db_utils.go | 177 -- orm/models.go | 99 - orm/models_boot.go | 347 --- orm/models_fields.go | 783 ------ orm/models_info_f.go | 473 ---- orm/models_info_m.go | 148 - orm/models_test.go | 497 ---- orm/models_utils.go | 227 -- orm/orm.go | 602 ---- orm/orm_conds.go | 153 - orm/orm_log.go | 222 -- orm/orm_object.go | 87 - orm/orm_querym2m.go | 140 - orm/orm_queryset.go | 300 -- orm/orm_raw.go | 867 ------ orm/orm_test.go | 2500 ----------------- orm/qb.go | 62 - orm/qb_mysql.go | 185 -- orm/qb_tidb.go | 182 -- orm/types.go | 474 ---- orm/utils.go | 319 --- orm/utils_test.go | 70 - parser.go | 591 ---- plugins/apiauth/apiauth.go | 165 -- plugins/apiauth/apiauth_test.go | 20 - plugins/auth/basic.go | 107 - plugins/authz/authz.go | 86 - plugins/authz/authz_model.conf | 14 - plugins/authz/authz_policy.csv | 7 - plugins/authz/authz_test.go | 107 - plugins/cors/cors.go | 228 -- plugins/cors/cors_test.go | 253 -- policy.go | 100 - router.go | 1085 ------- router_test.go | 732 ----- session/README.md | 114 - session/couchbase/sess_couchbase.go | 247 -- session/ledis/ledis_session.go | 173 -- session/memcache/sess_memcache.go | 230 -- session/mysql/sess_mysql.go | 228 -- session/postgres/sess_postgresql.go | 243 -- session/redis/sess_redis.go | 261 -- session/redis_cluster/redis_cluster.go | 221 -- session/redis_sentinel/sess_redis_sentinel.go | 234 -- .../sess_redis_sentinel_test.go | 90 - session/sess_cookie.go | 180 -- session/sess_cookie_test.go | 105 - session/sess_file.go | 315 --- session/sess_file_test.go | 386 --- session/sess_mem.go | 196 -- session/sess_mem_test.go | 58 - session/sess_test.go | 131 - session/sess_utils.go | 207 -- session/session.go | 377 --- session/ssdb/sess_ssdb.go | 199 -- staticfile.go | 234 -- staticfile_test.go | 99 - swagger/swagger.go | 174 -- template.go | 417 --- template_test.go | 316 --- templatefunc.go | 798 ------ templatefunc_test.go | 380 --- testing/assertions.go | 15 - testing/client.go | 65 - toolbox/healthcheck.go | 48 - toolbox/profile.go | 184 -- toolbox/profile_test.go | 28 - toolbox/statistics.go | 149 - toolbox/statistics_test.go | 40 - toolbox/task.go | 640 ----- toolbox/task_test.go | 85 - tree.go | 590 ---- tree_test.go | 306 -- unregroute_test.go | 226 -- utils/caller.go | 25 - utils/caller_test.go | 28 - utils/captcha/LICENSE | 19 - utils/captcha/README.md | 45 - utils/captcha/captcha.go | 270 -- utils/captcha/image.go | 501 ---- utils/captcha/image_test.go | 52 - utils/captcha/siprng.go | 277 -- utils/captcha/siprng_test.go | 33 - utils/debug.go | 478 ---- utils/debug_test.go | 46 - utils/file.go | 101 - utils/file_test.go | 76 - utils/mail.go | 424 --- utils/mail_test.go | 41 - utils/pagination/controller.go | 26 - utils/pagination/doc.go | 58 - utils/pagination/paginator.go | 189 -- utils/pagination/utils.go | 34 - utils/rand.go | 44 - utils/rand_test.go | 33 - utils/safemap.go | 91 - utils/safemap_test.go | 89 - utils/slice.go | 170 -- utils/slice_test.go | 29 - utils/testdata/grepe.test | 7 - utils/utils.go | 89 - utils/utils_test.go | 36 - validation/README.md | 147 - validation/util.go | 298 -- validation/util_test.go | 128 - validation/validation.go | 456 --- validation/validation_test.go | 609 ---- validation/validators.go | 738 ----- 221 files changed, 28 insertions(+), 52357 deletions(-) delete mode 100644 admin.go delete mode 100644 admin_test.go delete mode 100644 adminui.go delete mode 100644 app.go delete mode 100644 beego.go delete mode 100644 build_info.go delete mode 100644 cache/README.md delete mode 100644 cache/cache.go delete mode 100644 cache/cache_test.go delete mode 100644 cache/conv.go delete mode 100644 cache/conv_test.go delete mode 100644 cache/file.go delete mode 100644 cache/memcache/memcache.go delete mode 100644 cache/memcache/memcache_test.go delete mode 100644 cache/memory.go delete mode 100644 cache/redis/redis.go delete mode 100644 cache/redis/redis_test.go delete mode 100644 cache/ssdb/ssdb.go delete mode 100644 cache/ssdb/ssdb_test.go delete mode 100644 config.go delete mode 100644 config/config.go delete mode 100644 config/config_test.go delete mode 100644 config/env/env.go delete mode 100644 config/env/env_test.go delete mode 100644 config/fake.go delete mode 100644 config/ini.go delete mode 100644 config/ini_test.go delete mode 100644 config/json.go delete mode 100644 config/json_test.go delete mode 100644 config/xml/xml.go delete mode 100644 config/xml/xml_test.go delete mode 100644 config/yaml/yaml.go delete mode 100644 config/yaml/yaml_test.go delete mode 100644 config_test.go delete mode 100644 context/acceptencoder.go delete mode 100644 context/acceptencoder_test.go delete mode 100644 context/context.go delete mode 100644 context/context_test.go delete mode 100644 context/input.go delete mode 100644 context/input_test.go delete mode 100644 context/output.go delete mode 100644 context/param/conv.go delete mode 100644 context/param/methodparams.go delete mode 100644 context/param/options.go delete mode 100644 context/param/parsers.go delete mode 100644 context/param/parsers_test.go delete mode 100644 context/renderer.go delete mode 100644 context/response.go delete mode 100644 controller.go delete mode 100644 controller_test.go delete mode 100644 doc.go delete mode 100644 error.go delete mode 100644 error_test.go delete mode 100644 filter.go delete mode 100644 filter_test.go delete mode 100644 flash.go delete mode 100644 flash_test.go delete mode 100644 fs.go delete mode 100644 grace/grace.go delete mode 100644 grace/server.go delete mode 100644 hooks.go delete mode 100644 httplib/README.md delete mode 100644 httplib/httplib.go delete mode 100644 httplib/httplib_test.go delete mode 100644 logs/README.md delete mode 100644 logs/accesslog.go delete mode 100644 logs/alils/alils.go delete mode 100755 logs/alils/config.go delete mode 100755 logs/alils/log.pb.go delete mode 100755 logs/alils/log_config.go delete mode 100755 logs/alils/log_project.go delete mode 100755 logs/alils/log_store.go delete mode 100755 logs/alils/machine_group.go delete mode 100755 logs/alils/request.go delete mode 100755 logs/alils/signature.go delete mode 100644 logs/conn.go delete mode 100644 logs/conn_test.go delete mode 100644 logs/console.go delete mode 100644 logs/console_test.go delete mode 100644 logs/es/es.go delete mode 100644 logs/file.go delete mode 100644 logs/file_test.go delete mode 100644 logs/jianliao.go delete mode 100644 logs/log.go delete mode 100644 logs/logger.go delete mode 100644 logs/logger_test.go delete mode 100644 logs/multifile.go delete mode 100644 logs/multifile_test.go delete mode 100644 logs/slack.go delete mode 100644 logs/smtp.go delete mode 100644 logs/smtp_test.go delete mode 100644 metric/prometheus.go delete mode 100644 metric/prometheus_test.go delete mode 100644 migration/ddl.go delete mode 100644 migration/doc.go delete mode 100644 migration/migration.go delete mode 100644 mime.go delete mode 100644 namespace.go delete mode 100644 namespace_test.go delete mode 100644 orm/README.md delete mode 100644 orm/cmd.go delete mode 100644 orm/cmd_utils.go delete mode 100644 orm/db.go delete mode 100644 orm/db_alias.go delete mode 100644 orm/db_mysql.go delete mode 100644 orm/db_oracle.go delete mode 100644 orm/db_postgres.go delete mode 100644 orm/db_sqlite.go delete mode 100644 orm/db_tables.go delete mode 100644 orm/db_tidb.go delete mode 100644 orm/db_utils.go delete mode 100644 orm/models.go delete mode 100644 orm/models_boot.go delete mode 100644 orm/models_fields.go delete mode 100644 orm/models_info_f.go delete mode 100644 orm/models_info_m.go delete mode 100644 orm/models_test.go delete mode 100644 orm/models_utils.go delete mode 100644 orm/orm.go delete mode 100644 orm/orm_conds.go delete mode 100644 orm/orm_log.go delete mode 100644 orm/orm_object.go delete mode 100644 orm/orm_querym2m.go delete mode 100644 orm/orm_queryset.go delete mode 100644 orm/orm_raw.go delete mode 100644 orm/orm_test.go delete mode 100644 orm/qb.go delete mode 100644 orm/qb_mysql.go delete mode 100644 orm/qb_tidb.go delete mode 100644 orm/types.go delete mode 100644 orm/utils.go delete mode 100644 orm/utils_test.go delete mode 100644 parser.go delete mode 100644 plugins/apiauth/apiauth.go delete mode 100644 plugins/apiauth/apiauth_test.go delete mode 100644 plugins/auth/basic.go delete mode 100644 plugins/authz/authz.go delete mode 100644 plugins/authz/authz_model.conf delete mode 100644 plugins/authz/authz_policy.csv delete mode 100644 plugins/authz/authz_test.go delete mode 100644 plugins/cors/cors.go delete mode 100644 plugins/cors/cors_test.go delete mode 100644 policy.go delete mode 100644 router.go delete mode 100644 router_test.go delete mode 100644 session/README.md delete mode 100644 session/couchbase/sess_couchbase.go delete mode 100644 session/ledis/ledis_session.go delete mode 100644 session/memcache/sess_memcache.go delete mode 100644 session/mysql/sess_mysql.go delete mode 100644 session/postgres/sess_postgresql.go delete mode 100644 session/redis/sess_redis.go delete mode 100644 session/redis_cluster/redis_cluster.go delete mode 100644 session/redis_sentinel/sess_redis_sentinel.go delete mode 100644 session/redis_sentinel/sess_redis_sentinel_test.go delete mode 100644 session/sess_cookie.go delete mode 100644 session/sess_cookie_test.go delete mode 100644 session/sess_file.go delete mode 100644 session/sess_file_test.go delete mode 100644 session/sess_mem.go delete mode 100644 session/sess_mem_test.go delete mode 100644 session/sess_test.go delete mode 100644 session/sess_utils.go delete mode 100644 session/session.go delete mode 100644 session/ssdb/sess_ssdb.go delete mode 100644 staticfile.go delete mode 100644 staticfile_test.go delete mode 100644 swagger/swagger.go delete mode 100644 template.go delete mode 100644 template_test.go delete mode 100644 templatefunc.go delete mode 100644 templatefunc_test.go delete mode 100644 testing/assertions.go delete mode 100644 testing/client.go delete mode 100644 toolbox/healthcheck.go delete mode 100644 toolbox/profile.go delete mode 100644 toolbox/profile_test.go delete mode 100644 toolbox/statistics.go delete mode 100644 toolbox/statistics_test.go delete mode 100644 toolbox/task.go delete mode 100644 toolbox/task_test.go delete mode 100644 tree.go delete mode 100644 tree_test.go delete mode 100644 unregroute_test.go delete mode 100644 utils/caller.go delete mode 100644 utils/caller_test.go delete mode 100644 utils/captcha/LICENSE delete mode 100644 utils/captcha/README.md delete mode 100644 utils/captcha/captcha.go delete mode 100644 utils/captcha/image.go delete mode 100644 utils/captcha/image_test.go delete mode 100644 utils/captcha/siprng.go delete mode 100644 utils/captcha/siprng_test.go delete mode 100644 utils/debug.go delete mode 100644 utils/debug_test.go delete mode 100644 utils/file.go delete mode 100644 utils/file_test.go delete mode 100644 utils/mail.go delete mode 100644 utils/mail_test.go delete mode 100644 utils/pagination/controller.go delete mode 100644 utils/pagination/doc.go delete mode 100644 utils/pagination/paginator.go delete mode 100644 utils/pagination/utils.go delete mode 100644 utils/rand.go delete mode 100644 utils/rand_test.go delete mode 100644 utils/safemap.go delete mode 100644 utils/safemap_test.go delete mode 100644 utils/slice.go delete mode 100644 utils/slice_test.go delete mode 100644 utils/testdata/grepe.test delete mode 100644 utils/utils.go delete mode 100644 utils/utils_test.go delete mode 100644 validation/README.md delete mode 100644 validation/util.go delete mode 100644 validation/util_test.go delete mode 100644 validation/validation.go delete mode 100644 validation/validation_test.go delete mode 100644 validation/validators.go diff --git a/.travis.yml b/.travis.yml index 26c3732e..63b31c52 100644 --- a/.travis.yml +++ b/.travis.yml @@ -64,7 +64,7 @@ after_script: - rm -rf ./res/var/* script: - go test ./... - - staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" + - staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" ./pkg - unconvert $(go list ./... | grep -v /vendor/) - ineffassign . - find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s diff --git a/admin.go b/admin.go deleted file mode 100644 index d9c96dfd..00000000 --- a/admin.go +++ /dev/null @@ -1,461 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - "os" - "reflect" - "strconv" - "text/template" - "time" - - "github.com/prometheus/client_golang/prometheus/promhttp" - - "github.com/astaxie/beego/grace" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/toolbox" - "github.com/astaxie/beego/utils" -) - -// BeeAdminApp is the default adminApp used by admin module. -var beeAdminApp *adminApp - -// FilterMonitorFunc is default monitor filter when admin module is enable. -// if this func returns, admin module records qps for this request by condition of this function logic. -// usage: -// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool { -// if method == "POST" { -// return false -// } -// if t.Nanoseconds() < 100 { -// return false -// } -// if strings.HasPrefix(requestPath, "/astaxie") { -// return false -// } -// return true -// } -// beego.FilterMonitorFunc = MyFilterMonitor. -// Deprecated: using pkg/, we will delete this in v2.1.0 -var FilterMonitorFunc func(string, string, time.Duration, string, int) bool - -func init() { - beeAdminApp = &adminApp{ - routers: make(map[string]http.HandlerFunc), - } - // keep in mind that all data should be html escaped to avoid XSS attack - beeAdminApp.Route("/", adminIndex) - beeAdminApp.Route("/qps", qpsIndex) - beeAdminApp.Route("/prof", profIndex) - beeAdminApp.Route("/healthcheck", healthcheck) - beeAdminApp.Route("/task", taskStatus) - beeAdminApp.Route("/listconf", listConf) - beeAdminApp.Route("/metrics", promhttp.Handler().ServeHTTP) - FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true } -} - -// AdminIndex is the default http.Handler for admin module. -// it matches url pattern "/". -func adminIndex(rw http.ResponseWriter, _ *http.Request) { - writeTemplate(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) -} - -// QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter. -// it's registered with url pattern "/qps" in admin module. -func qpsIndex(rw http.ResponseWriter, _ *http.Request) { - data := make(map[interface{}]interface{}) - data["Content"] = toolbox.StatisticsMap.GetMap() - - // do html escape before display path, avoid xss - if content, ok := (data["Content"]).(M); ok { - if resultLists, ok := (content["Data"]).([][]string); ok { - for i := range resultLists { - if len(resultLists[i]) > 0 { - resultLists[i][0] = template.HTMLEscapeString(resultLists[i][0]) - } - } - } - } - - writeTemplate(rw, data, qpsTpl, defaultScriptsTpl) -} - -// ListConf is the http.Handler of displaying all beego configuration values as key/value pair. -// it's registered with url pattern "/listconf" in admin module. -func listConf(rw http.ResponseWriter, r *http.Request) { - r.ParseForm() - command := r.Form.Get("command") - if command == "" { - rw.Write([]byte("command not support")) - return - } - - data := make(map[interface{}]interface{}) - switch command { - case "conf": - m := make(M) - list("BConfig", BConfig, m) - m["AppConfigPath"] = template.HTMLEscapeString(appConfigPath) - m["AppConfigProvider"] = template.HTMLEscapeString(appConfigProvider) - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - tmpl = template.Must(tmpl.Parse(configTpl)) - tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) - - data["Content"] = m - - tmpl.Execute(rw, data) - - case "router": - content := PrintTree() - content["Fields"] = []string{ - "Router Pattern", - "Methods", - "Controller", - } - data["Content"] = content - data["Title"] = "Routers" - writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) - case "filter": - var ( - content = M{ - "Fields": []string{ - "Router Pattern", - "Filter Function", - }, - } - filterTypes = []string{} - filterTypeData = make(M) - ) - - if BeeApp.Handlers.enableFilter { - var filterType string - for k, fr := range map[int]string{ - BeforeStatic: "Before Static", - BeforeRouter: "Before Router", - BeforeExec: "Before Exec", - AfterExec: "After Exec", - FinishRouter: "Finish Router"} { - if bf := BeeApp.Handlers.filters[k]; len(bf) > 0 { - filterType = fr - filterTypes = append(filterTypes, filterType) - resultList := new([][]string) - for _, f := range bf { - var result = []string{ - // void xss - template.HTMLEscapeString(f.pattern), - template.HTMLEscapeString(utils.GetFuncName(f.filterFunc)), - } - *resultList = append(*resultList, result) - } - filterTypeData[filterType] = resultList - } - } - } - - content["Data"] = filterTypeData - content["Methods"] = filterTypes - - data["Content"] = content - data["Title"] = "Filters" - writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) - default: - rw.Write([]byte("command not support")) - } -} - -func list(root string, p interface{}, m M) { - pt := reflect.TypeOf(p) - pv := reflect.ValueOf(p) - if pt.Kind() == reflect.Ptr { - pt = pt.Elem() - pv = pv.Elem() - } - for i := 0; i < pv.NumField(); i++ { - var key string - if root == "" { - key = pt.Field(i).Name - } else { - key = root + "." + pt.Field(i).Name - } - if pv.Field(i).Kind() == reflect.Struct { - list(key, pv.Field(i).Interface(), m) - } else { - m[key] = pv.Field(i).Interface() - } - } -} - -// PrintTree prints all registered routers. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func PrintTree() M { - var ( - content = M{} - methods = []string{} - methodsData = make(M) - ) - for method, t := range BeeApp.Handlers.routers { - - resultList := new([][]string) - - printTree(resultList, t) - - methods = append(methods, template.HTMLEscapeString(method)) - methodsData[template.HTMLEscapeString(method)] = resultList - } - - content["Data"] = methodsData - content["Methods"] = methods - return content -} - -func printTree(resultList *[][]string, t *Tree) { - for _, tr := range t.fixrouters { - printTree(resultList, tr) - } - if t.wildcard != nil { - printTree(resultList, t.wildcard) - } - for _, l := range t.leaves { - if v, ok := l.runObject.(*ControllerInfo); ok { - if v.routerType == routerTypeBeego { - var result = []string{ - template.HTMLEscapeString(v.pattern), - template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), - template.HTMLEscapeString(v.controllerType.String()), - } - *resultList = append(*resultList, result) - } else if v.routerType == routerTypeRESTFul { - var result = []string{ - template.HTMLEscapeString(v.pattern), - template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), - "", - } - *resultList = append(*resultList, result) - } else if v.routerType == routerTypeHandler { - var result = []string{ - template.HTMLEscapeString(v.pattern), - "", - "", - } - *resultList = append(*resultList, result) - } - } - } -} - -// ProfIndex is a http.Handler for showing profile command. -// it's in url pattern "/prof" in admin module. -func profIndex(rw http.ResponseWriter, r *http.Request) { - r.ParseForm() - command := r.Form.Get("command") - if command == "" { - return - } - - var ( - format = r.Form.Get("format") - data = make(map[interface{}]interface{}) - result bytes.Buffer - ) - toolbox.ProcessInput(command, &result) - data["Content"] = template.HTMLEscapeString(result.String()) - - if format == "json" && command == "gc summary" { - dataJSON, err := json.Marshal(data) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - writeJSON(rw, dataJSON) - return - } - - data["Title"] = template.HTMLEscapeString(command) - defaultTpl := defaultScriptsTpl - if command == "gc summary" { - defaultTpl = gcAjaxTpl - } - writeTemplate(rw, data, profillingTpl, defaultTpl) -} - -// Healthcheck is a http.Handler calling health checking and showing the result. -// it's in "/healthcheck" pattern in admin module. -func healthcheck(rw http.ResponseWriter, r *http.Request) { - var ( - result []string - data = make(map[interface{}]interface{}) - resultList = new([][]string) - content = M{ - "Fields": []string{"Name", "Message", "Status"}, - } - ) - - for name, h := range toolbox.AdminCheckList { - if err := h.Check(); err != nil { - result = []string{ - "error", - template.HTMLEscapeString(name), - template.HTMLEscapeString(err.Error()), - } - } else { - result = []string{ - "success", - template.HTMLEscapeString(name), - "OK", - } - } - *resultList = append(*resultList, result) - } - - queryParams := r.URL.Query() - jsonFlag := queryParams.Get("json") - shouldReturnJSON, _ := strconv.ParseBool(jsonFlag) - - if shouldReturnJSON { - response := buildHealthCheckResponseList(resultList) - jsonResponse, err := json.Marshal(response) - - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - } else { - writeJSON(rw, jsonResponse) - } - return - } - - content["Data"] = resultList - data["Content"] = content - data["Title"] = "Health Check" - - writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) -} - -func buildHealthCheckResponseList(healthCheckResults *[][]string) []map[string]interface{} { - response := make([]map[string]interface{}, len(*healthCheckResults)) - - for i, healthCheckResult := range *healthCheckResults { - currentResultMap := make(map[string]interface{}) - - currentResultMap["name"] = healthCheckResult[0] - currentResultMap["message"] = healthCheckResult[1] - currentResultMap["status"] = healthCheckResult[2] - - response[i] = currentResultMap - } - - return response - -} - -func writeJSON(rw http.ResponseWriter, jsonData []byte) { - rw.Header().Set("Content-Type", "application/json") - rw.Write(jsonData) -} - -// TaskStatus is a http.Handler with running task status (task name, status and the last execution). -// it's in "/task" pattern in admin module. -func taskStatus(rw http.ResponseWriter, req *http.Request) { - data := make(map[interface{}]interface{}) - - // Run Task - req.ParseForm() - taskname := req.Form.Get("taskname") - if taskname != "" { - if t, ok := toolbox.AdminTaskList[taskname]; ok { - if err := t.Run(); err != nil { - data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", err))} - } - data["Message"] = []string{"success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus()))} - } else { - data["Message"] = []string{"warning", template.HTMLEscapeString(fmt.Sprintf("there's no task which named: %s", taskname))} - } - } - - // List Tasks - content := make(M) - resultList := new([][]string) - var fields = []string{ - "Task Name", - "Task Spec", - "Task Status", - "Last Time", - "", - } - for tname, tk := range toolbox.AdminTaskList { - result := []string{ - template.HTMLEscapeString(tname), - template.HTMLEscapeString(tk.GetSpec()), - template.HTMLEscapeString(tk.GetStatus()), - template.HTMLEscapeString(tk.GetPrev().String()), - } - *resultList = append(*resultList, result) - } - - content["Fields"] = fields - content["Data"] = resultList - data["Content"] = content - data["Title"] = "Tasks" - writeTemplate(rw, data, tasksTpl, defaultScriptsTpl) -} - -func writeTemplate(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - for _, tpl := range tpls { - tmpl = template.Must(tmpl.Parse(tpl)) - } - tmpl.Execute(rw, data) -} - -// adminApp is an http.HandlerFunc map used as beeAdminApp. -type adminApp struct { - routers map[string]http.HandlerFunc -} - -// Route adds http.HandlerFunc to adminApp with url pattern. -func (admin *adminApp) Route(pattern string, f http.HandlerFunc) { - admin.routers[pattern] = f -} - -// Run adminApp http server. -// Its addr is defined in configuration file as adminhttpaddr and adminhttpport. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (admin *adminApp) Run() { - if len(toolbox.AdminTaskList) > 0 { - toolbox.StartTask() - } - addr := BConfig.Listen.AdminAddr - - if BConfig.Listen.AdminPort != 0 { - addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort) - } - for p, f := range admin.routers { - http.Handle(p, f) - } - logs.Info("Admin server Running on %s", addr) - - var err error - if BConfig.Listen.Graceful { - err = grace.ListenAndServe(addr, nil) - } else { - err = http.ListenAndServe(addr, nil) - } - if err != nil { - logs.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) - } -} diff --git a/admin_test.go b/admin_test.go deleted file mode 100644 index 205c76c2..00000000 --- a/admin_test.go +++ /dev/null @@ -1,249 +0,0 @@ -package beego - -import ( - "encoding/json" - "errors" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/astaxie/beego/toolbox" -) - -type SampleDatabaseCheck struct { -} - -type SampleCacheCheck struct { -} - -func (dc *SampleDatabaseCheck) Check() error { - return nil -} - -func (cc *SampleCacheCheck) Check() error { - return errors.New("no cache detected") -} - -func TestList_01(t *testing.T) { - m := make(M) - list("BConfig", BConfig, m) - t.Log(m) - om := oldMap() - for k, v := range om { - if fmt.Sprint(m[k]) != fmt.Sprint(v) { - t.Log(k, "old-key", v, "new-key", m[k]) - t.FailNow() - } - } -} - -func oldMap() M { - m := make(M) - m["BConfig.AppName"] = BConfig.AppName - m["BConfig.RunMode"] = BConfig.RunMode - m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive - m["BConfig.ServerName"] = BConfig.ServerName - m["BConfig.RecoverPanic"] = BConfig.RecoverPanic - m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody - m["BConfig.EnableGzip"] = BConfig.EnableGzip - m["BConfig.MaxMemory"] = BConfig.MaxMemory - m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow - m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful - m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut - m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4 - m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP - m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr - m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort - m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS - m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr - m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort - m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile - m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile - m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin - m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr - m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort - m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi - m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo - m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender - m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs - m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName - m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator - m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex - m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir - m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip - m["BConfig.WebConfig.StaticCacheFileSize"] = BConfig.WebConfig.StaticCacheFileSize - m["BConfig.WebConfig.StaticCacheFileNum"] = BConfig.WebConfig.StaticCacheFileNum - m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft - m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight - m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath - m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF - m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire - m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn - m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider - m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName - m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime - m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig - m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime - m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie - m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain - m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly - m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs - m["BConfig.Log.EnableStaticLogs"] = BConfig.Log.EnableStaticLogs - m["BConfig.Log.AccessLogsFormat"] = BConfig.Log.AccessLogsFormat - m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum - m["BConfig.Log.Outputs"] = BConfig.Log.Outputs - return m -} - -func TestWriteJSON(t *testing.T) { - t.Log("Testing the adding of JSON to the response") - - w := httptest.NewRecorder() - originalBody := []int{1, 2, 3} - - res, _ := json.Marshal(originalBody) - - writeJSON(w, res) - - decodedBody := []int{} - err := json.NewDecoder(w.Body).Decode(&decodedBody) - - if err != nil { - t.Fatal("Could not decode response body into slice.") - } - - for i := range decodedBody { - if decodedBody[i] != originalBody[i] { - t.Fatalf("Expected %d but got %d in decoded body slice", originalBody[i], decodedBody[i]) - } - } -} - -func TestHealthCheckHandlerDefault(t *testing.T) { - endpointPath := "/healthcheck" - - toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) - toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) - - req, err := http.NewRequest("GET", endpointPath, nil) - if err != nil { - t.Fatal(err) - } - - w := httptest.NewRecorder() - - handler := http.HandlerFunc(healthcheck) - - handler.ServeHTTP(w, req) - - if status := w.Code; status != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusOK) - } - if !strings.Contains(w.Body.String(), "database") { - t.Errorf("Expected 'database' in generated template.") - } - -} - -func TestBuildHealthCheckResponseList(t *testing.T) { - healthCheckResults := [][]string{ - []string{ - "error", - "Database", - "Error occured whie starting the db", - }, - []string{ - "success", - "Cache", - "Cache started successfully", - }, - } - - responseList := buildHealthCheckResponseList(&healthCheckResults) - - if len(responseList) != len(healthCheckResults) { - t.Errorf("invalid response map length: got %d want %d", - len(responseList), len(healthCheckResults)) - } - - responseFields := []string{"name", "message", "status"} - - for _, response := range responseList { - for _, field := range responseFields { - _, ok := response[field] - if !ok { - t.Errorf("expected %s to be in the response %v", field, response) - } - } - - } - -} - -func TestHealthCheckHandlerReturnsJSON(t *testing.T) { - - toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) - toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) - - req, err := http.NewRequest("GET", "/healthcheck?json=true", nil) - if err != nil { - t.Fatal(err) - } - - w := httptest.NewRecorder() - - handler := http.HandlerFunc(healthcheck) - - handler.ServeHTTP(w, req) - if status := w.Code; status != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusOK) - } - - decodedResponseBody := []map[string]interface{}{} - expectedResponseBody := []map[string]interface{}{} - - expectedJSONString := []byte(` - [ - { - "message":"database", - "name":"success", - "status":"OK" - }, - { - "message":"cache", - "name":"error", - "status":"no cache detected" - } - ] - `) - - json.Unmarshal(expectedJSONString, &expectedResponseBody) - - json.Unmarshal(w.Body.Bytes(), &decodedResponseBody) - - if len(expectedResponseBody) != len(decodedResponseBody) { - t.Errorf("invalid response map length: got %d want %d", - len(decodedResponseBody), len(expectedResponseBody)) - } - assert.Equal(t, len(expectedResponseBody), len(decodedResponseBody)) - assert.Equal(t, 2, len(decodedResponseBody)) - - var database, cache map[string]interface{} - if decodedResponseBody[0]["message"] == "database" { - database = decodedResponseBody[0] - cache = decodedResponseBody[1] - } else { - database = decodedResponseBody[1] - cache = decodedResponseBody[0] - } - - assert.Equal(t, expectedResponseBody[0], database) - assert.Equal(t, expectedResponseBody[1], cache) - -} diff --git a/adminui.go b/adminui.go deleted file mode 100644 index cdcdef33..00000000 --- a/adminui.go +++ /dev/null @@ -1,356 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -var indexTpl = ` -{{define "content"}} -

Beego Admin Dashboard

-

-For detail usage please check our document: -

-

-Toolbox -

-

-Live Monitor -

-{{.Content}} -{{end}}` - -var profillingTpl = ` -{{define "content"}} -

{{.Title}}

-
-
{{.Content}}
-
-{{end}}` - -var defaultScriptsTpl = `` - -var gcAjaxTpl = ` -{{define "scripts"}} - -{{end}} -` - -var qpsTpl = `{{define "content"}} -

Requests statistics

- - - - {{range .Content.Fields}} - - {{end}} - - - - - {{range $i, $elem := .Content.Data}} - - - - - - - - - - - {{end}} - - -
- {{.}} -
{{index $elem 0}}{{index $elem 1}}{{index $elem 2}}{{index $elem 4}}{{index $elem 6}}{{index $elem 8}}{{index $elem 10}}
-{{end}}` - -var configTpl = ` -{{define "content"}} -

Configurations

-
-{{range $index, $elem := .Content}}
-{{$index}}={{$elem}}
-{{end}}
-
-{{end}} -` - -var routerAndFilterTpl = `{{define "content"}} - - -

{{.Title}}

- -{{range .Content.Methods}} - -
-
{{.}}
-
- - - - {{range $.Content.Fields}} - - {{end}} - - - - - {{$slice := index $.Content.Data .}} - {{range $i, $elem := $slice}} - - - {{range $elem}} - - {{end}} - - - {{end}} - - -
- {{.}} -
- {{.}} -
-
-
-{{end}} - - -{{end}}` - -var tasksTpl = `{{define "content"}} - -

{{.Title}}

- -{{if .Message }} -{{ $messageType := index .Message 0}} -

-{{index .Message 1}} -

-{{end}} - - - - - -{{range .Content.Fields}} - -{{end}} - - - - -{{range $i, $slice := .Content.Data}} - - {{range $slice}} - - {{end}} - - -{{end}} - -
-{{.}} -
- {{.}} - - Run -
- -{{end}}` - -var healthCheckTpl = ` -{{define "content"}} - -

{{.Title}}

- - - -{{range .Content.Fields}} - -{{end}} - - - -{{range $i, $slice := .Content.Data}} - {{ $header := index $slice 0}} - {{ if eq "success" $header}} - - {{else if eq "error" $header}} - - {{else}} - - {{end}} - {{range $j, $elem := $slice}} - {{if ne $j 0}} - - {{end}} - {{end}} - - -{{end}} - - -
- {{.}} -
- {{$elem}} - - {{$header}} -
-{{end}}` - -// The base dashboardTpl -var dashboardTpl = ` - - - - - - - - - - -Welcome to Beego Admin Dashboard - - - - - - - - - - - - - -
-{{template "content" .}} -
- - - - - - - -{{template "scripts" .}} - - -` diff --git a/app.go b/app.go deleted file mode 100644 index d86188c0..00000000 --- a/app.go +++ /dev/null @@ -1,516 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "crypto/tls" - "crypto/x509" - "fmt" - "io/ioutil" - "net" - "net/http" - "net/http/fcgi" - "os" - "path" - "strings" - "time" - - "github.com/astaxie/beego/grace" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" - "golang.org/x/crypto/acme/autocert" -) - -var ( - // BeeApp is an application instance - // Deprecated: using pkg/, we will delete this in v2.1.0 - BeeApp *App -) - -func init() { - // create beego application - BeeApp = NewApp() -} - -// App defines beego application with a new PatternServeMux. -type App struct { - Handlers *ControllerRegister - Server *http.Server -} - -// NewApp returns a new beego application. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NewApp() *App { - cr := NewControllerRegister() - app := &App{Handlers: cr, Server: &http.Server{}} - return app -} - -// MiddleWare function for http.Handler -// Deprecated: using pkg/, we will delete this in v2.1.0 -type MiddleWare func(http.Handler) http.Handler - -// Run beego application. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (app *App) Run(mws ...MiddleWare) { - addr := BConfig.Listen.HTTPAddr - - if BConfig.Listen.HTTPPort != 0 { - addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPAddr, BConfig.Listen.HTTPPort) - } - - var ( - err error - l net.Listener - endRunning = make(chan bool, 1) - ) - - // run cgi server - if BConfig.Listen.EnableFcgi { - if BConfig.Listen.EnableStdIo { - if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O - logs.Info("Use FCGI via standard I/O") - } else { - logs.Critical("Cannot use FCGI via standard I/O", err) - } - return - } - if BConfig.Listen.HTTPPort == 0 { - // remove the Socket file before start - if utils.FileExists(addr) { - os.Remove(addr) - } - l, err = net.Listen("unix", addr) - } else { - l, err = net.Listen("tcp", addr) - } - if err != nil { - logs.Critical("Listen: ", err) - } - if err = fcgi.Serve(l, app.Handlers); err != nil { - logs.Critical("fcgi.Serve: ", err) - } - return - } - - app.Server.Handler = app.Handlers - for i := len(mws) - 1; i >= 0; i-- { - if mws[i] == nil { - continue - } - app.Server.Handler = mws[i](app.Server.Handler) - } - app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second - app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second - app.Server.ErrorLog = logs.GetLogger("HTTP") - - // run graceful mode - if BConfig.Listen.Graceful { - httpsAddr := BConfig.Listen.HTTPSAddr - app.Server.Addr = httpsAddr - if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { - go func() { - time.Sleep(1000 * time.Microsecond) - if BConfig.Listen.HTTPSPort != 0 { - httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) - app.Server.Addr = httpsAddr - } - server := grace.NewServer(httpsAddr, app.Server.Handler) - server.Server.ReadTimeout = app.Server.ReadTimeout - server.Server.WriteTimeout = app.Server.WriteTimeout - if BConfig.Listen.EnableMutualHTTPS { - if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil { - logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) - time.Sleep(100 * time.Microsecond) - } - } else { - if BConfig.Listen.AutoTLS { - m := autocert.Manager{ - Prompt: autocert.AcceptTOS, - HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), - Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), - } - app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} - BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" - } - if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { - logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) - time.Sleep(100 * time.Microsecond) - } - } - endRunning <- true - }() - } - if BConfig.Listen.EnableHTTP { - go func() { - server := grace.NewServer(addr, app.Server.Handler) - server.Server.ReadTimeout = app.Server.ReadTimeout - server.Server.WriteTimeout = app.Server.WriteTimeout - if BConfig.Listen.ListenTCP4 { - server.Network = "tcp4" - } - if err := server.ListenAndServe(); err != nil { - logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) - time.Sleep(100 * time.Microsecond) - } - endRunning <- true - }() - } - <-endRunning - return - } - - // run normal mode - if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { - go func() { - time.Sleep(1000 * time.Microsecond) - if BConfig.Listen.HTTPSPort != 0 { - app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) - } else if BConfig.Listen.EnableHTTP { - logs.Info("Start https server error, conflict with http. Please reset https port") - return - } - logs.Info("https server Running on https://%s", app.Server.Addr) - if BConfig.Listen.AutoTLS { - m := autocert.Manager{ - Prompt: autocert.AcceptTOS, - HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), - Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), - } - app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} - BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" - } else if BConfig.Listen.EnableMutualHTTPS { - pool := x509.NewCertPool() - data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile) - if err != nil { - logs.Info("MutualHTTPS should provide TrustCaFile") - return - } - pool.AppendCertsFromPEM(data) - app.Server.TLSConfig = &tls.Config{ - ClientCAs: pool, - ClientAuth: BConfig.Listen.ClientAuth, - } - } - if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { - logs.Critical("ListenAndServeTLS: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - }() - - } - if BConfig.Listen.EnableHTTP { - go func() { - app.Server.Addr = addr - logs.Info("http server Running on http://%s", app.Server.Addr) - if BConfig.Listen.ListenTCP4 { - ln, err := net.Listen("tcp4", app.Server.Addr) - if err != nil { - logs.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - return - } - if err = app.Server.Serve(ln); err != nil { - logs.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - return - } - } else { - if err := app.Server.ListenAndServe(); err != nil { - logs.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - } - }() - } - <-endRunning -} - -// Router adds a patterned controller handler to BeeApp. -// it's an alias method of App.Router. -// usage: -// simple router -// beego.Router("/admin", &admin.UserController{}) -// beego.Router("/admin/index", &admin.ArticleController{}) -// -// regex router -// -// beego.Router("/api/:id([0-9]+)", &controllers.RController{}) -// -// custom rules -// beego.Router("/api/list",&RestController{},"*:ListFood") -// beego.Router("/api/create",&RestController{},"post:CreateFood") -// beego.Router("/api/update",&RestController{},"put:UpdateFood") -// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { - BeeApp.Handlers.Add(rootpath, c, mappingMethods...) - return BeeApp -} - -// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful -// in web applications that inherit most routes from a base webapp via the underscore -// import, and aim to overwrite only certain paths. -// The method parameter can be empty or "*" for all HTTP methods, or a particular -// method type (e.g. "GET" or "POST") for selective removal. -// -// Usage (replace "GET" with "*" for all methods): -// beego.UnregisterFixedRoute("/yourpreviouspath", "GET") -// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage") -// Deprecated: using pkg/, we will delete this in v2.1.0 -func UnregisterFixedRoute(fixedRoute string, method string) *App { - subPaths := splitPath(fixedRoute) - if method == "" || method == "*" { - for m := range HTTPMETHOD { - if _, ok := BeeApp.Handlers.routers[m]; !ok { - continue - } - if BeeApp.Handlers.routers[m].prefix == strings.Trim(fixedRoute, "/ ") { - findAndRemoveSingleTree(BeeApp.Handlers.routers[m]) - continue - } - findAndRemoveTree(subPaths, BeeApp.Handlers.routers[m], m) - } - return BeeApp - } - // Single HTTP method - um := strings.ToUpper(method) - if _, ok := BeeApp.Handlers.routers[um]; ok { - if BeeApp.Handlers.routers[um].prefix == strings.Trim(fixedRoute, "/ ") { - findAndRemoveSingleTree(BeeApp.Handlers.routers[um]) - return BeeApp - } - findAndRemoveTree(subPaths, BeeApp.Handlers.routers[um], um) - } - return BeeApp -} - -func findAndRemoveTree(paths []string, entryPointTree *Tree, method string) { - for i := range entryPointTree.fixrouters { - if entryPointTree.fixrouters[i].prefix == paths[0] { - if len(paths) == 1 { - if len(entryPointTree.fixrouters[i].fixrouters) > 0 { - // If the route had children subtrees, remove just the functional leaf, - // to allow children to function as before - if len(entryPointTree.fixrouters[i].leaves) > 0 { - entryPointTree.fixrouters[i].leaves[0] = nil - entryPointTree.fixrouters[i].leaves = entryPointTree.fixrouters[i].leaves[1:] - } - } else { - // Remove the *Tree from the fixrouters slice - entryPointTree.fixrouters[i] = nil - - if i == len(entryPointTree.fixrouters)-1 { - entryPointTree.fixrouters = entryPointTree.fixrouters[:i] - } else { - entryPointTree.fixrouters = append(entryPointTree.fixrouters[:i], entryPointTree.fixrouters[i+1:len(entryPointTree.fixrouters)]...) - } - } - return - } - findAndRemoveTree(paths[1:], entryPointTree.fixrouters[i], method) - } - } -} - -func findAndRemoveSingleTree(entryPointTree *Tree) { - if entryPointTree == nil { - return - } - if len(entryPointTree.fixrouters) > 0 { - // If the route had children subtrees, remove just the functional leaf, - // to allow children to function as before - if len(entryPointTree.leaves) > 0 { - entryPointTree.leaves[0] = nil - entryPointTree.leaves = entryPointTree.leaves[1:] - } - } -} - -// Include will generate router file in the router/xxx.go from the controller's comments -// usage: -// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) -// type BankAccount struct{ -// beego.Controller -// } -// -// register the function -// func (b *BankAccount)Mapping(){ -// b.Mapping("ShowAccount" , b.ShowAccount) -// b.Mapping("ModifyAccount", b.ModifyAccount) -//} -// -// //@router /account/:id [get] -// func (b *BankAccount) ShowAccount(){ -// //logic -// } -// -// -// //@router /account/:id [post] -// func (b *BankAccount) ModifyAccount(){ -// //logic -// } -// -// the comments @router url methodlist -// url support all the function Router's pattern -// methodlist [get post head put delete options *] -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Include(cList ...ControllerInterface) *App { - BeeApp.Handlers.Include(cList...) - return BeeApp -} - -// RESTRouter adds a restful controller handler to BeeApp. -// its' controller implements beego.ControllerInterface and -// defines a param "pattern/:objectId" to visit each resource. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func RESTRouter(rootpath string, c ControllerInterface) *App { - Router(rootpath, c) - Router(path.Join(rootpath, ":objectId"), c) - return BeeApp -} - -// AutoRouter adds defined controller handler to BeeApp. -// it's same to App.AutoRouter. -// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, -// visit the url /main/list to exec List function or /main/page to exec Page function. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AutoRouter(c ControllerInterface) *App { - BeeApp.Handlers.AddAuto(c) - return BeeApp -} - -// AutoPrefix adds controller handler to BeeApp with prefix. -// it's same to App.AutoRouterWithPrefix. -// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, -// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AutoPrefix(prefix string, c ControllerInterface) *App { - BeeApp.Handlers.AddAutoPrefix(prefix, c) - return BeeApp -} - -// Get used to register router for Get method -// usage: -// beego.Get("/", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Get(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Get(rootpath, f) - return BeeApp -} - -// Post used to register router for Post method -// usage: -// beego.Post("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Post(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Post(rootpath, f) - return BeeApp -} - -// Delete used to register router for Delete method -// usage: -// beego.Delete("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Delete(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Delete(rootpath, f) - return BeeApp -} - -// Put used to register router for Put method -// usage: -// beego.Put("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Put(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Put(rootpath, f) - return BeeApp -} - -// Head used to register router for Head method -// usage: -// beego.Head("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Head(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Head(rootpath, f) - return BeeApp -} - -// Options used to register router for Options method -// usage: -// beego.Options("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Options(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Options(rootpath, f) - return BeeApp -} - -// Patch used to register router for Patch method -// usage: -// beego.Patch("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Patch(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Patch(rootpath, f) - return BeeApp -} - -// Any used to register router for all methods -// usage: -// beego.Any("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Any(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Any(rootpath, f) - return BeeApp -} - -// Handler used to register a Handler router -// usage: -// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { -// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) -// })) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Handler(rootpath string, h http.Handler, options ...interface{}) *App { - BeeApp.Handlers.Handler(rootpath, h, options...) - return BeeApp -} - -// InsertFilter adds a FilterFunc with pattern condition and action constant. -// The pos means action constant including -// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. -// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { - BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) - return BeeApp -} diff --git a/beego.go b/beego.go deleted file mode 100644 index ef93134d..00000000 --- a/beego.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "os" - "path/filepath" - "strconv" - "strings" -) - -const ( - // VERSION represent beego web framework version. - // Deprecated: using pkg/, we will delete this in v2.1.0 - VERSION = "1.12.2" - - // DEV is for develop - // Deprecated: using pkg/, we will delete this in v2.1.0 - DEV = "dev" - // PROD is for production - // Deprecated: using pkg/, we will delete this in v2.1.0 - PROD = "prod" -) - -// M is Map shortcut -// Deprecated: using pkg/, we will delete this in v2.1.0 -type M map[string]interface{} - -// Hook function to run -type hookfunc func() error - -var ( - hooks = make([]hookfunc, 0) //hook function slice to store the hookfunc -) - -// AddAPPStartHook is used to register the hookfunc -// The hookfuncs will run in beego.Run() -// such as initiating session , starting middleware , building template, starting admin control and so on. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AddAPPStartHook(hf ...hookfunc) { - hooks = append(hooks, hf...) -} - -// Run beego application. -// beego.Run() default run on HttpPort -// beego.Run("localhost") -// beego.Run(":8089") -// beego.Run("127.0.0.1:8089") -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Run(params ...string) { - - initBeforeHTTPRun() - - if len(params) > 0 && params[0] != "" { - strs := strings.Split(params[0], ":") - if len(strs) > 0 && strs[0] != "" { - BConfig.Listen.HTTPAddr = strs[0] - } - if len(strs) > 1 && strs[1] != "" { - BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) - } - - BConfig.Listen.Domains = params - } - - BeeApp.Run() -} - -// RunWithMiddleWares Run beego application with middlewares. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func RunWithMiddleWares(addr string, mws ...MiddleWare) { - initBeforeHTTPRun() - - strs := strings.Split(addr, ":") - if len(strs) > 0 && strs[0] != "" { - BConfig.Listen.HTTPAddr = strs[0] - BConfig.Listen.Domains = []string{strs[0]} - } - if len(strs) > 1 && strs[1] != "" { - BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) - } - - BeeApp.Run(mws...) -} - -func initBeforeHTTPRun() { - //init hooks - AddAPPStartHook( - registerMime, - registerDefaultErrorHandler, - registerSession, - registerTemplate, - registerAdmin, - registerGzip, - ) - - for _, hk := range hooks { - if err := hk(); err != nil { - panic(err) - } - } -} - -// TestBeegoInit is for test package init -// Deprecated: using pkg/, we will delete this in v2.1.0 -func TestBeegoInit(ap string) { - path := filepath.Join(ap, "conf", "app.conf") - os.Chdir(ap) - InitBeegoBeforeTest(path) -} - -// InitBeegoBeforeTest is for test package init -// Deprecated: using pkg/, we will delete this in v2.1.0 -func InitBeegoBeforeTest(appConfigPath string) { - if err := LoadAppConfig(appConfigProvider, appConfigPath); err != nil { - panic(err) - } - BConfig.RunMode = "test" - initBeforeHTTPRun() -} diff --git a/build_info.go b/build_info.go deleted file mode 100644 index 59e78127..00000000 --- a/build_info.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2020 astaxie -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -var ( - // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildVersion string - // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildGitRevision string - // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildStatus string - // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildTag string - // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildTime string - - // Deprecated: using pkg/, we will delete this in v2.1.0 - GoVersion string - - // Deprecated: using pkg/, we will delete this in v2.1.0 - GitBranch string -) diff --git a/cache/README.md b/cache/README.md deleted file mode 100644 index b467760a..00000000 --- a/cache/README.md +++ /dev/null @@ -1,59 +0,0 @@ -## cache -cache is a Go cache manager. It can use many cache adapters. The repo is inspired by `database/sql` . - - -## How to install? - - go get github.com/astaxie/beego/cache - - -## What adapters are supported? - -As of now this cache support memory, Memcache and Redis. - - -## How to use it? - -First you must import it - - import ( - "github.com/astaxie/beego/cache" - ) - -Then init a Cache (example with memory adapter) - - bm, err := cache.NewCache("memory", `{"interval":60}`) - -Use it like this: - - bm.Put("astaxie", 1, 10 * time.Second) - bm.Get("astaxie") - bm.IsExist("astaxie") - bm.Delete("astaxie") - - -## Memory adapter - -Configure memory adapter like this: - - {"interval":60} - -interval means the gc time. The cache will check at each time interval, whether item has expired. - - -## Memcache adapter - -Memcache adapter use the [gomemcache](http://github.com/bradfitz/gomemcache) client. - -Configure like this: - - {"conn":"127.0.0.1:11211"} - - -## Redis adapter - -Redis adapter use the [redigo](http://github.com/gomodule/redigo) client. - -Configure like this: - - {"conn":":6039"} diff --git a/cache/cache.go b/cache/cache.go deleted file mode 100644 index 82585c4e..00000000 --- a/cache/cache.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package cache provide a Cache interface and some implement engine -// Usage: -// -// import( -// "github.com/astaxie/beego/cache" -// ) -// -// bm, err := cache.NewCache("memory", `{"interval":60}`) -// -// Use it like this: -// -// bm.Put("astaxie", 1, 10 * time.Second) -// bm.Get("astaxie") -// bm.IsExist("astaxie") -// bm.Delete("astaxie") -// -// more docs http://beego.me/docs/module/cache.md -package cache - -import ( - "fmt" - "time" -) - -// Cache interface contains all behaviors for cache adapter. -// usage: -// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go. -// c,err := cache.NewCache("file","{....}") -// c.Put("key",value, 3600 * time.Second) -// v := c.Get("key") -// -// c.Incr("counter") // now is 1 -// c.Incr("counter") // now is 2 -// count := c.Get("counter").(int) -type Cache interface { - // get cached value by key. - Get(key string) interface{} - // GetMulti is a batch version of Get. - GetMulti(keys []string) []interface{} - // set cached value with key and expire time. - Put(key string, val interface{}, timeout time.Duration) error - // delete cached value by key. - Delete(key string) error - // increase cached int value by key, as a counter. - Incr(key string) error - // decrease cached int value by key, as a counter. - Decr(key string) error - // check if cached value exists or not. - IsExist(key string) bool - // clear all cache. - ClearAll() error - // start gc routine based on config string settings. - StartAndGC(config string) error -} - -// Instance is a function create a new Cache Instance -type Instance func() Cache - -var adapters = make(map[string]Instance) - -// Register makes a cache adapter available by the adapter name. -// If Register is called twice with the same name or if driver is nil, -// it panics. -func Register(name string, adapter Instance) { - if adapter == nil { - panic("cache: Register adapter is nil") - } - if _, ok := adapters[name]; ok { - panic("cache: Register called twice for adapter " + name) - } - adapters[name] = adapter -} - -// NewCache Create a new cache driver by adapter name and config string. -// config need to be correct JSON as string: {"interval":360}. -// it will start gc automatically. -func NewCache(adapterName, config string) (adapter Cache, err error) { - instanceFunc, ok := adapters[adapterName] - if !ok { - err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) - return - } - adapter = instanceFunc() - err = adapter.StartAndGC(config) - if err != nil { - adapter = nil - } - return -} diff --git a/cache/cache_test.go b/cache/cache_test.go deleted file mode 100644 index 470c0a43..00000000 --- a/cache/cache_test.go +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "os" - "sync" - "testing" - "time" -) - -func TestCacheIncr(t *testing.T) { - bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - //timeoutDuration := 10 * time.Second - - bm.Put("edwardhey", 0, time.Second*20) - wg := sync.WaitGroup{} - wg.Add(10) - for i := 0; i < 10; i++ { - go func() { - defer wg.Done() - bm.Incr("edwardhey") - }() - } - wg.Wait() - if bm.Get("edwardhey").(int) != 10 { - t.Error("Incr err") - } -} - -func TestCache(t *testing.T) { - bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - - time.Sleep(30 * time.Second) - - if bm.IsExist("astaxie") { - t.Error("check err") - } - - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test GetMulti - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } - - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } -} - -func TestFileCache(t *testing.T) { - bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } - - //test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } - - os.RemoveAll("cache") -} diff --git a/cache/conv.go b/cache/conv.go deleted file mode 100644 index 87800586..00000000 --- a/cache/conv.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "fmt" - "strconv" -) - -// GetString convert interface to string. -func GetString(v interface{}) string { - switch result := v.(type) { - case string: - return result - case []byte: - return string(result) - default: - if v != nil { - return fmt.Sprint(result) - } - } - return "" -} - -// GetInt convert interface to int. -func GetInt(v interface{}) int { - switch result := v.(type) { - case int: - return result - case int32: - return int(result) - case int64: - return int(result) - default: - if d := GetString(v); d != "" { - value, _ := strconv.Atoi(d) - return value - } - } - return 0 -} - -// GetInt64 convert interface to int64. -func GetInt64(v interface{}) int64 { - switch result := v.(type) { - case int: - return int64(result) - case int32: - return int64(result) - case int64: - return result - default: - - if d := GetString(v); d != "" { - value, _ := strconv.ParseInt(d, 10, 64) - return value - } - } - return 0 -} - -// GetFloat64 convert interface to float64. -func GetFloat64(v interface{}) float64 { - switch result := v.(type) { - case float64: - return result - default: - if d := GetString(v); d != "" { - value, _ := strconv.ParseFloat(d, 64) - return value - } - } - return 0 -} - -// GetBool convert interface to bool. -func GetBool(v interface{}) bool { - switch result := v.(type) { - case bool: - return result - default: - if d := GetString(v); d != "" { - value, _ := strconv.ParseBool(d) - return value - } - } - return false -} diff --git a/cache/conv_test.go b/cache/conv_test.go deleted file mode 100644 index b90e224a..00000000 --- a/cache/conv_test.go +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "testing" -) - -func TestGetString(t *testing.T) { - var t1 = "test1" - if "test1" != GetString(t1) { - t.Error("get string from string error") - } - var t2 = []byte("test2") - if "test2" != GetString(t2) { - t.Error("get string from byte array error") - } - var t3 = 1 - if "1" != GetString(t3) { - t.Error("get string from int error") - } - var t4 int64 = 1 - if "1" != GetString(t4) { - t.Error("get string from int64 error") - } - var t5 = 1.1 - if "1.1" != GetString(t5) { - t.Error("get string from float64 error") - } - - if "" != GetString(nil) { - t.Error("get string from nil error") - } -} - -func TestGetInt(t *testing.T) { - var t1 = 1 - if 1 != GetInt(t1) { - t.Error("get int from int error") - } - var t2 int32 = 32 - if 32 != GetInt(t2) { - t.Error("get int from int32 error") - } - var t3 int64 = 64 - if 64 != GetInt(t3) { - t.Error("get int from int64 error") - } - var t4 = "128" - if 128 != GetInt(t4) { - t.Error("get int from num string error") - } - if 0 != GetInt(nil) { - t.Error("get int from nil error") - } -} - -func TestGetInt64(t *testing.T) { - var i int64 = 1 - var t1 = 1 - if i != GetInt64(t1) { - t.Error("get int64 from int error") - } - var t2 int32 = 1 - if i != GetInt64(t2) { - t.Error("get int64 from int32 error") - } - var t3 int64 = 1 - if i != GetInt64(t3) { - t.Error("get int64 from int64 error") - } - var t4 = "1" - if i != GetInt64(t4) { - t.Error("get int64 from num string error") - } - if 0 != GetInt64(nil) { - t.Error("get int64 from nil") - } -} - -func TestGetFloat64(t *testing.T) { - var f = 1.11 - var t1 float32 = 1.11 - if f != GetFloat64(t1) { - t.Error("get float64 from float32 error") - } - var t2 = 1.11 - if f != GetFloat64(t2) { - t.Error("get float64 from float64 error") - } - var t3 = "1.11" - if f != GetFloat64(t3) { - t.Error("get float64 from string error") - } - - var f2 float64 = 1 - var t4 = 1 - if f2 != GetFloat64(t4) { - t.Error("get float64 from int error") - } - - if 0 != GetFloat64(nil) { - t.Error("get float64 from nil error") - } -} - -func TestGetBool(t *testing.T) { - var t1 = true - if !GetBool(t1) { - t.Error("get bool from bool error") - } - var t2 = "true" - if !GetBool(t2) { - t.Error("get bool from string error") - } - if GetBool(nil) { - t.Error("get bool from nil error") - } -} - -func byteArrayEquals(a []byte, b []byte) bool { - if len(a) != len(b) { - return false - } - for i, v := range a { - if v != b[i] { - return false - } - } - return true -} diff --git a/cache/file.go b/cache/file.go deleted file mode 100644 index 6f12d3ee..00000000 --- a/cache/file.go +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "bytes" - "crypto/md5" - "encoding/gob" - "encoding/hex" - "encoding/json" - "fmt" - "io" - "io/ioutil" - "os" - "path/filepath" - "reflect" - "strconv" - "time" -) - -// FileCacheItem is basic unit of file cache adapter. -// it contains data and expire time. -type FileCacheItem struct { - Data interface{} - Lastaccess time.Time - Expired time.Time -} - -// FileCache Config -var ( - FileCachePath = "cache" // cache directory - FileCacheFileSuffix = ".bin" // cache file suffix - FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files. - FileCacheEmbedExpiry time.Duration // cache expire time, default is no expire forever. -) - -// FileCache is cache adapter for file storage. -type FileCache struct { - CachePath string - FileSuffix string - DirectoryLevel int - EmbedExpiry int -} - -// NewFileCache Create new file cache with no config. -// the level and expiry need set in method StartAndGC as config string. -func NewFileCache() Cache { - // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} - return &FileCache{} -} - -// StartAndGC will start and begin gc for file cache. -// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"} -func (fc *FileCache) StartAndGC(config string) error { - - cfg := make(map[string]string) - err := json.Unmarshal([]byte(config), &cfg) - if err != nil { - return err - } - if _, ok := cfg["CachePath"]; !ok { - cfg["CachePath"] = FileCachePath - } - if _, ok := cfg["FileSuffix"]; !ok { - cfg["FileSuffix"] = FileCacheFileSuffix - } - if _, ok := cfg["DirectoryLevel"]; !ok { - cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel) - } - if _, ok := cfg["EmbedExpiry"]; !ok { - cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) - } - fc.CachePath = cfg["CachePath"] - fc.FileSuffix = cfg["FileSuffix"] - fc.DirectoryLevel, _ = strconv.Atoi(cfg["DirectoryLevel"]) - fc.EmbedExpiry, _ = strconv.Atoi(cfg["EmbedExpiry"]) - - fc.Init() - return nil -} - -// Init will make new dir for file cache if not exist. -func (fc *FileCache) Init() { - if ok, _ := exists(fc.CachePath); !ok { // todo : error handle - _ = os.MkdirAll(fc.CachePath, os.ModePerm) // todo : error handle - } -} - -// get cached file name. it's md5 encoded. -func (fc *FileCache) getCacheFileName(key string) string { - m := md5.New() - io.WriteString(m, key) - keyMd5 := hex.EncodeToString(m.Sum(nil)) - cachePath := fc.CachePath - switch fc.DirectoryLevel { - case 2: - cachePath = filepath.Join(cachePath, keyMd5[0:2], keyMd5[2:4]) - case 1: - cachePath = filepath.Join(cachePath, keyMd5[0:2]) - } - - if ok, _ := exists(cachePath); !ok { // todo : error handle - _ = os.MkdirAll(cachePath, os.ModePerm) // todo : error handle - } - - return filepath.Join(cachePath, fmt.Sprintf("%s%s", keyMd5, fc.FileSuffix)) -} - -// Get value from file cache. -// if non-exist or expired, return empty string. -func (fc *FileCache) Get(key string) interface{} { - fileData, err := FileGetContents(fc.getCacheFileName(key)) - if err != nil { - return "" - } - var to FileCacheItem - GobDecode(fileData, &to) - if to.Expired.Before(time.Now()) { - return "" - } - return to.Data -} - -// GetMulti gets values from file cache. -// if non-exist or expired, return empty string. -func (fc *FileCache) GetMulti(keys []string) []interface{} { - var rc []interface{} - for _, key := range keys { - rc = append(rc, fc.Get(key)) - } - return rc -} - -// Put value into file cache. -// timeout means how long to keep this file, unit of ms. -// if timeout equals fc.EmbedExpiry(default is 0), cache this item forever. -func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error { - gob.Register(val) - - item := FileCacheItem{Data: val} - if timeout == time.Duration(fc.EmbedExpiry) { - item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years - } else { - item.Expired = time.Now().Add(timeout) - } - item.Lastaccess = time.Now() - data, err := GobEncode(item) - if err != nil { - return err - } - return FilePutContents(fc.getCacheFileName(key), data) -} - -// Delete file cache value. -func (fc *FileCache) Delete(key string) error { - filename := fc.getCacheFileName(key) - if ok, _ := exists(filename); ok { - return os.Remove(filename) - } - return nil -} - -// Incr will increase cached int value. -// fc value is saving forever unless Delete. -func (fc *FileCache) Incr(key string) error { - data := fc.Get(key) - var incr int - if reflect.TypeOf(data).Name() != "int" { - incr = 0 - } else { - incr = data.(int) + 1 - } - fc.Put(key, incr, time.Duration(fc.EmbedExpiry)) - return nil -} - -// Decr will decrease cached int value. -func (fc *FileCache) Decr(key string) error { - data := fc.Get(key) - var decr int - if reflect.TypeOf(data).Name() != "int" || data.(int)-1 <= 0 { - decr = 0 - } else { - decr = data.(int) - 1 - } - fc.Put(key, decr, time.Duration(fc.EmbedExpiry)) - return nil -} - -// IsExist check value is exist. -func (fc *FileCache) IsExist(key string) bool { - ret, _ := exists(fc.getCacheFileName(key)) - return ret -} - -// ClearAll will clean cached files. -// not implemented. -func (fc *FileCache) ClearAll() error { - return nil -} - -// check file exist. -func exists(path string) (bool, error) { - _, err := os.Stat(path) - if err == nil { - return true, nil - } - if os.IsNotExist(err) { - return false, nil - } - return false, err -} - -// FileGetContents Get bytes to file. -// if non-exist, create this file. -func FileGetContents(filename string) (data []byte, e error) { - return ioutil.ReadFile(filename) -} - -// FilePutContents Put bytes to file. -// if non-exist, create this file. -func FilePutContents(filename string, content []byte) error { - return ioutil.WriteFile(filename, content, os.ModePerm) -} - -// GobEncode Gob encodes file cache item. -func GobEncode(data interface{}) ([]byte, error) { - buf := bytes.NewBuffer(nil) - enc := gob.NewEncoder(buf) - err := enc.Encode(data) - if err != nil { - return nil, err - } - return buf.Bytes(), err -} - -// GobDecode Gob decodes file cache item. -func GobDecode(data []byte, to *FileCacheItem) error { - buf := bytes.NewBuffer(data) - dec := gob.NewDecoder(buf) - return dec.Decode(&to) -} - -func init() { - Register("file", NewFileCache) -} diff --git a/cache/memcache/memcache.go b/cache/memcache/memcache.go deleted file mode 100644 index 19116bfa..00000000 --- a/cache/memcache/memcache.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package memcache for cache provider -// -// depend on github.com/bradfitz/gomemcache/memcache -// -// go install github.com/bradfitz/gomemcache/memcache -// -// Usage: -// import( -// _ "github.com/astaxie/beego/cache/memcache" -// "github.com/astaxie/beego/cache" -// ) -// -// bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`) -// -// more docs http://beego.me/docs/module/cache.md -package memcache - -import ( - "encoding/json" - "errors" - "strings" - "time" - - "github.com/astaxie/beego/cache" - "github.com/bradfitz/gomemcache/memcache" -) - -// Cache Memcache adapter. -type Cache struct { - conn *memcache.Client - conninfo []string -} - -// NewMemCache create new memcache adapter. -func NewMemCache() cache.Cache { - return &Cache{} -} - -// Get get value from memcache. -func (rc *Cache) Get(key string) interface{} { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - if item, err := rc.conn.Get(key); err == nil { - return item.Value - } - return nil -} - -// GetMulti get value from memcache. -func (rc *Cache) GetMulti(keys []string) []interface{} { - size := len(keys) - var rv []interface{} - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - for i := 0; i < size; i++ { - rv = append(rv, err) - } - return rv - } - } - mv, err := rc.conn.GetMulti(keys) - if err == nil { - for _, v := range mv { - rv = append(rv, v.Value) - } - return rv - } - for i := 0; i < size; i++ { - rv = append(rv, err) - } - return rv -} - -// Put put value to memcache. -func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - item := memcache.Item{Key: key, Expiration: int32(timeout / time.Second)} - if v, ok := val.([]byte); ok { - item.Value = v - } else if str, ok := val.(string); ok { - item.Value = []byte(str) - } else { - return errors.New("val only support string and []byte") - } - return rc.conn.Set(&item) -} - -// Delete delete value in memcache. -func (rc *Cache) Delete(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return rc.conn.Delete(key) -} - -// Incr increase counter. -func (rc *Cache) Incr(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Increment(key, 1) - return err -} - -// Decr decrease counter. -func (rc *Cache) Decr(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Decrement(key, 1) - return err -} - -// IsExist check value exists in memcache. -func (rc *Cache) IsExist(key string) bool { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return false - } - } - _, err := rc.conn.Get(key) - return err == nil -} - -// ClearAll clear all cached in memcache. -func (rc *Cache) ClearAll() error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return rc.conn.FlushAll() -} - -// StartAndGC start memcache adapter. -// config string is like {"conn":"connection info"}. -// if connecting error, return. -func (rc *Cache) StartAndGC(config string) error { - var cf map[string]string - json.Unmarshal([]byte(config), &cf) - if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") - } - rc.conninfo = strings.Split(cf["conn"], ";") - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return nil -} - -// connect to memcache and keep the connection. -func (rc *Cache) connectInit() error { - rc.conn = memcache.New(rc.conninfo...) - return nil -} - -func init() { - cache.Register("memcache", NewMemCache) -} diff --git a/cache/memcache/memcache_test.go b/cache/memcache/memcache_test.go deleted file mode 100644 index d9129b69..00000000 --- a/cache/memcache/memcache_test.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package memcache - -import ( - _ "github.com/bradfitz/gomemcache/memcache" - - "strconv" - "testing" - "time" - - "github.com/astaxie/beego/cache" -) - -func TestMemcacheCache(t *testing.T) { - bm, err := cache.NewCache("memcache", `{"conn": "127.0.0.1:11211"}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - time.Sleep(11 * time.Second) - - if bm.IsExist("astaxie") { - t.Error("check err") - } - if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { - t.Error("get err") - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie").([]byte); string(v) != "author" { - t.Error("get err") - } - - //test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" { - t.Error("GetMulti ERROR") - } - if string(vv[1].([]byte)) != "author1" && string(vv[1].([]byte)) != "author" { - t.Error("GetMulti ERROR") - } - - // test clear all - if err = bm.ClearAll(); err != nil { - t.Error("clear all err") - } -} diff --git a/cache/memory.go b/cache/memory.go deleted file mode 100644 index d8314e3c..00000000 --- a/cache/memory.go +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "encoding/json" - "errors" - "sync" - "time" -) - -var ( - // DefaultEvery means the clock time of recycling the expired cache items in memory. - DefaultEvery = 60 // 1 minute -) - -// MemoryItem store memory cache item. -type MemoryItem struct { - val interface{} - createdTime time.Time - lifespan time.Duration -} - -func (mi *MemoryItem) isExpire() bool { - // 0 means forever - if mi.lifespan == 0 { - return false - } - return time.Now().Sub(mi.createdTime) > mi.lifespan -} - -// MemoryCache is Memory cache adapter. -// it contains a RW locker for safe map storage. -type MemoryCache struct { - sync.RWMutex - dur time.Duration - items map[string]*MemoryItem - Every int // run an expiration check Every clock time -} - -// NewMemoryCache returns a new MemoryCache. -func NewMemoryCache() Cache { - cache := MemoryCache{items: make(map[string]*MemoryItem)} - return &cache -} - -// Get cache from memory. -// if non-existed or expired, return nil. -func (bc *MemoryCache) Get(name string) interface{} { - bc.RLock() - defer bc.RUnlock() - if itm, ok := bc.items[name]; ok { - if itm.isExpire() { - return nil - } - return itm.val - } - return nil -} - -// GetMulti gets caches from memory. -// if non-existed or expired, return nil. -func (bc *MemoryCache) GetMulti(names []string) []interface{} { - var rc []interface{} - for _, name := range names { - rc = append(rc, bc.Get(name)) - } - return rc -} - -// Put cache to memory. -// if lifespan is 0, it will be forever till restart. -func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error { - bc.Lock() - defer bc.Unlock() - bc.items[name] = &MemoryItem{ - val: value, - createdTime: time.Now(), - lifespan: lifespan, - } - return nil -} - -// Delete cache in memory. -func (bc *MemoryCache) Delete(name string) error { - bc.Lock() - defer bc.Unlock() - if _, ok := bc.items[name]; !ok { - return errors.New("key not exist") - } - delete(bc.items, name) - if _, ok := bc.items[name]; ok { - return errors.New("delete key error") - } - return nil -} - -// Incr increase cache counter in memory. -// it supports int,int32,int64,uint,uint32,uint64. -func (bc *MemoryCache) Incr(key string) error { - bc.Lock() - defer bc.Unlock() - itm, ok := bc.items[key] - if !ok { - return errors.New("key not exist") - } - switch val := itm.val.(type) { - case int: - itm.val = val + 1 - case int32: - itm.val = val + 1 - case int64: - itm.val = val + 1 - case uint: - itm.val = val + 1 - case uint32: - itm.val = val + 1 - case uint64: - itm.val = val + 1 - default: - return errors.New("item val is not (u)int (u)int32 (u)int64") - } - return nil -} - -// Decr decrease counter in memory. -func (bc *MemoryCache) Decr(key string) error { - bc.Lock() - defer bc.Unlock() - itm, ok := bc.items[key] - if !ok { - return errors.New("key not exist") - } - switch val := itm.val.(type) { - case int: - itm.val = val - 1 - case int64: - itm.val = val - 1 - case int32: - itm.val = val - 1 - case uint: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - case uint32: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - case uint64: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - default: - return errors.New("item val is not int int64 int32") - } - return nil -} - -// IsExist check cache exist in memory. -func (bc *MemoryCache) IsExist(name string) bool { - bc.RLock() - defer bc.RUnlock() - if v, ok := bc.items[name]; ok { - return !v.isExpire() - } - return false -} - -// ClearAll will delete all cache in memory. -func (bc *MemoryCache) ClearAll() error { - bc.Lock() - defer bc.Unlock() - bc.items = make(map[string]*MemoryItem) - return nil -} - -// StartAndGC start memory cache. it will check expiration in every clock time. -func (bc *MemoryCache) StartAndGC(config string) error { - var cf map[string]int - json.Unmarshal([]byte(config), &cf) - if _, ok := cf["interval"]; !ok { - cf = make(map[string]int) - cf["interval"] = DefaultEvery - } - dur := time.Duration(cf["interval"]) * time.Second - bc.Every = cf["interval"] - bc.dur = dur - go bc.vacuum() - return nil -} - -// check expiration. -func (bc *MemoryCache) vacuum() { - bc.RLock() - every := bc.Every - bc.RUnlock() - - if every < 1 { - return - } - for { - <-time.After(bc.dur) - bc.RLock() - if bc.items == nil { - bc.RUnlock() - return - } - bc.RUnlock() - if keys := bc.expiredKeys(); len(keys) != 0 { - bc.clearItems(keys) - } - } -} - -// expiredKeys returns key list which are expired. -func (bc *MemoryCache) expiredKeys() (keys []string) { - bc.RLock() - defer bc.RUnlock() - for key, itm := range bc.items { - if itm.isExpire() { - keys = append(keys, key) - } - } - return -} - -// clearItems removes all the items which key in keys. -func (bc *MemoryCache) clearItems(keys []string) { - bc.Lock() - defer bc.Unlock() - for _, key := range keys { - delete(bc.items, key) - } -} - -func init() { - Register("memory", NewMemoryCache) -} diff --git a/cache/redis/redis.go b/cache/redis/redis.go deleted file mode 100644 index d8737b3c..00000000 --- a/cache/redis/redis.go +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package redis for cache provider -// -// depend on github.com/gomodule/redigo/redis -// -// go install github.com/gomodule/redigo/redis -// -// Usage: -// import( -// _ "github.com/astaxie/beego/cache/redis" -// "github.com/astaxie/beego/cache" -// ) -// -// bm, err := cache.NewCache("redis", `{"conn":"127.0.0.1:11211"}`) -// -// more docs http://beego.me/docs/module/cache.md -package redis - -import ( - "encoding/json" - "errors" - "fmt" - "strconv" - "time" - - "github.com/gomodule/redigo/redis" - - "github.com/astaxie/beego/cache" - "strings" -) - -var ( - // DefaultKey the collection name of redis for cache adapter. - DefaultKey = "beecacheRedis" -) - -// Cache is Redis cache adapter. -type Cache struct { - p *redis.Pool // redis connection pool - conninfo string - dbNum int - key string - password string - maxIdle int - - //the timeout to a value less than the redis server's timeout. - timeout time.Duration -} - -// NewRedisCache create new redis cache with default collection name. -func NewRedisCache() cache.Cache { - return &Cache{key: DefaultKey} -} - -// actually do the redis cmds, args[0] must be the key name. -func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) { - if len(args) < 1 { - return nil, errors.New("missing required arguments") - } - args[0] = rc.associate(args[0]) - c := rc.p.Get() - defer c.Close() - - return c.Do(commandName, args...) -} - -// associate with config key. -func (rc *Cache) associate(originKey interface{}) string { - return fmt.Sprintf("%s:%s", rc.key, originKey) -} - -// Get cache from redis. -func (rc *Cache) Get(key string) interface{} { - if v, err := rc.do("GET", key); err == nil { - return v - } - return nil -} - -// GetMulti get cache from redis. -func (rc *Cache) GetMulti(keys []string) []interface{} { - c := rc.p.Get() - defer c.Close() - var args []interface{} - for _, key := range keys { - args = append(args, rc.associate(key)) - } - values, err := redis.Values(c.Do("MGET", args...)) - if err != nil { - return nil - } - return values -} - -// Put put cache to redis. -func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { - _, err := rc.do("SETEX", key, int64(timeout/time.Second), val) - return err -} - -// Delete delete cache in redis. -func (rc *Cache) Delete(key string) error { - _, err := rc.do("DEL", key) - return err -} - -// IsExist check cache's existence in redis. -func (rc *Cache) IsExist(key string) bool { - v, err := redis.Bool(rc.do("EXISTS", key)) - if err != nil { - return false - } - return v -} - -// Incr increase counter in redis. -func (rc *Cache) Incr(key string) error { - _, err := redis.Bool(rc.do("INCRBY", key, 1)) - return err -} - -// Decr decrease counter in redis. -func (rc *Cache) Decr(key string) error { - _, err := redis.Bool(rc.do("INCRBY", key, -1)) - return err -} - -// ClearAll clean all cache in redis. delete this redis collection. -func (rc *Cache) ClearAll() error { - cachedKeys, err := rc.Scan(rc.key + ":*") - if err != nil { - return err - } - c := rc.p.Get() - defer c.Close() - for _, str := range cachedKeys { - if _, err = c.Do("DEL", str); err != nil { - return err - } - } - return err -} - -// Scan scan all keys matching the pattern. a better choice than `keys` -func (rc *Cache) Scan(pattern string) (keys []string, err error) { - c := rc.p.Get() - defer c.Close() - var ( - cursor uint64 = 0 // start - result []interface{} - list []string - ) - for { - result, err = redis.Values(c.Do("SCAN", cursor, "MATCH", pattern, "COUNT", 1024)) - if err != nil { - return - } - list, err = redis.Strings(result[1], nil) - if err != nil { - return - } - keys = append(keys, list...) - cursor, err = redis.Uint64(result[0], nil) - if err != nil { - return - } - if cursor == 0 { // over - return - } - } -} - -// StartAndGC start redis cache adapter. -// config is like {"key":"collection key","conn":"connection info","dbNum":"0"} -// the cache item in redis are stored forever, -// so no gc operation. -func (rc *Cache) StartAndGC(config string) error { - var cf map[string]string - json.Unmarshal([]byte(config), &cf) - - if _, ok := cf["key"]; !ok { - cf["key"] = DefaultKey - } - if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") - } - - // Format redis://@: - cf["conn"] = strings.Replace(cf["conn"], "redis://", "", 1) - if i := strings.Index(cf["conn"], "@"); i > -1 { - cf["password"] = cf["conn"][0:i] - cf["conn"] = cf["conn"][i+1:] - } - - if _, ok := cf["dbNum"]; !ok { - cf["dbNum"] = "0" - } - if _, ok := cf["password"]; !ok { - cf["password"] = "" - } - if _, ok := cf["maxIdle"]; !ok { - cf["maxIdle"] = "3" - } - if _, ok := cf["timeout"]; !ok { - cf["timeout"] = "180s" - } - rc.key = cf["key"] - rc.conninfo = cf["conn"] - rc.dbNum, _ = strconv.Atoi(cf["dbNum"]) - rc.password = cf["password"] - rc.maxIdle, _ = strconv.Atoi(cf["maxIdle"]) - - if v, err := time.ParseDuration(cf["timeout"]); err == nil { - rc.timeout = v - } else { - rc.timeout = 180 * time.Second - } - - rc.connectInit() - - c := rc.p.Get() - defer c.Close() - - return c.Err() -} - -// connect to redis. -func (rc *Cache) connectInit() { - dialFunc := func() (c redis.Conn, err error) { - c, err = redis.Dial("tcp", rc.conninfo) - if err != nil { - return nil, err - } - - if rc.password != "" { - if _, err := c.Do("AUTH", rc.password); err != nil { - c.Close() - return nil, err - } - } - - _, selecterr := c.Do("SELECT", rc.dbNum) - if selecterr != nil { - c.Close() - return nil, selecterr - } - return - } - // initialize a new pool - rc.p = &redis.Pool{ - MaxIdle: rc.maxIdle, - IdleTimeout: rc.timeout, - Dial: dialFunc, - } -} - -func init() { - cache.Register("redis", NewRedisCache) -} diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go deleted file mode 100644 index 60a19180..00000000 --- a/cache/redis/redis_test.go +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package redis - -import ( - "fmt" - "testing" - "time" - - "github.com/astaxie/beego/cache" - "github.com/gomodule/redigo/redis" - "github.com/stretchr/testify/assert" -) - -func TestRedisCache(t *testing.T) { - bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - time.Sleep(11 * time.Second) - - if bm.IsExist("astaxie") { - t.Error("check err") - } - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { - t.Error("get err") - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v, _ := redis.String(bm.Get("astaxie"), err); v != "author" { - t.Error("get err") - } - - //test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[0], nil); v != "author" { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[1], nil); v != "author1" { - t.Error("GetMulti ERROR") - } - - // test clear all - if err = bm.ClearAll(); err != nil { - t.Error("clear all err") - } -} - -func TestCache_Scan(t *testing.T) { - timeoutDuration := 10 * time.Second - // init - bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) - if err != nil { - t.Error("init err") - } - // insert all - for i := 0; i < 10000; i++ { - if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { - t.Error("set Error", err) - } - } - // scan all for the first time - keys, err := bm.(*Cache).Scan(DefaultKey + ":*") - if err != nil { - t.Error("scan Error", err) - } - - assert.Equal(t, 10000, len(keys), "scan all error") - - // clear all - if err = bm.ClearAll(); err != nil { - t.Error("clear all err") - } - - // scan all for the second time - keys, err = bm.(*Cache).Scan(DefaultKey + ":*") - if err != nil { - t.Error("scan Error", err) - } - if len(keys) != 0 { - t.Error("scan all err") - } -} diff --git a/cache/ssdb/ssdb.go b/cache/ssdb/ssdb.go deleted file mode 100644 index fa2ce04b..00000000 --- a/cache/ssdb/ssdb.go +++ /dev/null @@ -1,231 +0,0 @@ -package ssdb - -import ( - "encoding/json" - "errors" - "strconv" - "strings" - "time" - - "github.com/ssdb/gossdb/ssdb" - - "github.com/astaxie/beego/cache" -) - -// Cache SSDB adapter -type Cache struct { - conn *ssdb.Client - conninfo []string -} - -//NewSsdbCache create new ssdb adapter. -func NewSsdbCache() cache.Cache { - return &Cache{} -} - -// Get get value from memcache. -func (rc *Cache) Get(key string) interface{} { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return nil - } - } - value, err := rc.conn.Get(key) - if err == nil { - return value - } - return nil -} - -// GetMulti get value from memcache. -func (rc *Cache) GetMulti(keys []string) []interface{} { - size := len(keys) - var values []interface{} - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - for i := 0; i < size; i++ { - values = append(values, err) - } - return values - } - } - res, err := rc.conn.Do("multi_get", keys) - resSize := len(res) - if err == nil { - for i := 1; i < resSize; i += 2 { - values = append(values, res[i+1]) - } - return values - } - for i := 0; i < size; i++ { - values = append(values, err) - } - return values -} - -// DelMulti get value from memcache. -func (rc *Cache) DelMulti(keys []string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Do("multi_del", keys) - return err -} - -// Put put value to memcache. only support string. -func (rc *Cache) Put(key string, value interface{}, timeout time.Duration) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - v, ok := value.(string) - if !ok { - return errors.New("value must string") - } - var resp []string - var err error - ttl := int(timeout / time.Second) - if ttl < 0 { - resp, err = rc.conn.Do("set", key, v) - } else { - resp, err = rc.conn.Do("setx", key, v, ttl) - } - if err != nil { - return err - } - if len(resp) == 2 && resp[0] == "ok" { - return nil - } - return errors.New("bad response") -} - -// Delete delete value in memcache. -func (rc *Cache) Delete(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Del(key) - return err -} - -// Incr increase counter. -func (rc *Cache) Incr(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Do("incr", key, 1) - return err -} - -// Decr decrease counter. -func (rc *Cache) Decr(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Do("incr", key, -1) - return err -} - -// IsExist check value exists in memcache. -func (rc *Cache) IsExist(key string) bool { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return false - } - } - resp, err := rc.conn.Do("exists", key) - if err != nil { - return false - } - if len(resp) == 2 && resp[1] == "1" { - return true - } - return false - -} - -// ClearAll clear all cached in memcache. -func (rc *Cache) ClearAll() error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - keyStart, keyEnd, limit := "", "", 50 - resp, err := rc.Scan(keyStart, keyEnd, limit) - for err == nil { - size := len(resp) - if size == 1 { - return nil - } - keys := []string{} - for i := 1; i < size; i += 2 { - keys = append(keys, resp[i]) - } - _, e := rc.conn.Do("multi_del", keys) - if e != nil { - return e - } - keyStart = resp[size-2] - resp, err = rc.Scan(keyStart, keyEnd, limit) - } - return err -} - -// Scan key all cached in ssdb. -func (rc *Cache) Scan(keyStart string, keyEnd string, limit int) ([]string, error) { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return nil, err - } - } - resp, err := rc.conn.Do("scan", keyStart, keyEnd, limit) - if err != nil { - return nil, err - } - return resp, nil -} - -// StartAndGC start memcache adapter. -// config string is like {"conn":"connection info"}. -// if connecting error, return. -func (rc *Cache) StartAndGC(config string) error { - var cf map[string]string - json.Unmarshal([]byte(config), &cf) - if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") - } - rc.conninfo = strings.Split(cf["conn"], ";") - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return nil -} - -// connect to memcache and keep the connection. -func (rc *Cache) connectInit() error { - conninfoArray := strings.Split(rc.conninfo[0], ":") - host := conninfoArray[0] - port, e := strconv.Atoi(conninfoArray[1]) - if e != nil { - return e - } - var err error - rc.conn, err = ssdb.Connect(host, port) - return err -} - -func init() { - cache.Register("ssdb", NewSsdbCache) -} diff --git a/cache/ssdb/ssdb_test.go b/cache/ssdb/ssdb_test.go deleted file mode 100644 index dd474960..00000000 --- a/cache/ssdb/ssdb_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package ssdb - -import ( - "strconv" - "testing" - "time" - - "github.com/astaxie/beego/cache" -) - -func TestSsdbcacheCache(t *testing.T) { - ssdb, err := cache.NewCache("ssdb", `{"conn": "127.0.0.1:8888"}`) - if err != nil { - t.Error("init err") - } - - // test put and exist - if ssdb.IsExist("ssdb") { - t.Error("check err") - } - timeoutDuration := 10 * time.Second - //timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent - if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb") { - t.Error("check err") - } - - // Get test done - if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if v := ssdb.Get("ssdb"); v != "ssdb" { - t.Error("get Error") - } - - //inc/dec test done - if err = ssdb.Put("ssdb", "2", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if err = ssdb.Incr("ssdb"); err != nil { - t.Error("incr Error", err) - } - - if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { - t.Error("get err") - } - - if err = ssdb.Decr("ssdb"); err != nil { - t.Error("decr error") - } - - // test del - if err = ssdb.Put("ssdb", "3", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { - t.Error("get err") - } - if err := ssdb.Delete("ssdb"); err == nil { - if ssdb.IsExist("ssdb") { - t.Error("delete err") - } - } - - //test string - if err = ssdb.Put("ssdb", "ssdb", -10*time.Second); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb") { - t.Error("check err") - } - if v := ssdb.Get("ssdb").(string); v != "ssdb" { - t.Error("get err") - } - - //test GetMulti done - if err = ssdb.Put("ssdb1", "ssdb1", -10*time.Second); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb1") { - t.Error("check err") - } - vv := ssdb.GetMulti([]string{"ssdb", "ssdb1"}) - if len(vv) != 2 { - t.Error("getmulti error") - } - if vv[0].(string) != "ssdb" { - t.Error("getmulti error") - } - if vv[1].(string) != "ssdb1" { - t.Error("getmulti error") - } - - // test clear all done - if err = ssdb.ClearAll(); err != nil { - t.Error("clear all err") - } - if ssdb.IsExist("ssdb") || ssdb.IsExist("ssdb1") { - t.Error("check err") - } -} diff --git a/config.go b/config.go deleted file mode 100644 index 7917528e..00000000 --- a/config.go +++ /dev/null @@ -1,557 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "crypto/tls" - "fmt" - "os" - "path/filepath" - "reflect" - "runtime" - "strings" - - "github.com/astaxie/beego/config" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/session" - "github.com/astaxie/beego/utils" -) - -// Config is the main struct for BConfig -// Deprecated: using pkg/, we will delete this in v2.1.0 -type Config struct { - AppName string //Application name - RunMode string //Running Mode: dev | prod - RouterCaseSensitive bool - ServerName string - RecoverPanic bool - RecoverFunc func(*context.Context) - CopyRequestBody bool - EnableGzip bool - MaxMemory int64 - EnableErrorsShow bool - EnableErrorsRender bool - Listen Listen - WebConfig WebConfig - Log LogConfig -} - -// Listen holds for http and https related config -// Deprecated: using pkg/, we will delete this in v2.1.0 -type Listen struct { - Graceful bool // Graceful means use graceful module to start the server - ServerTimeOut int64 - ListenTCP4 bool - EnableHTTP bool - HTTPAddr string - HTTPPort int - AutoTLS bool - Domains []string - TLSCacheDir string - EnableHTTPS bool - EnableMutualHTTPS bool - HTTPSAddr string - HTTPSPort int - HTTPSCertFile string - HTTPSKeyFile string - TrustCaFile string - ClientAuth tls.ClientAuthType - EnableAdmin bool - AdminAddr string - AdminPort int - EnableFcgi bool - EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O -} - -// WebConfig holds web related config -// Deprecated: using pkg/, we will delete this in v2.1.0 -type WebConfig struct { - AutoRender bool - EnableDocs bool - FlashName string - FlashSeparator string - DirectoryIndex bool - StaticDir map[string]string - StaticExtensionsToGzip []string - StaticCacheFileSize int - StaticCacheFileNum int - TemplateLeft string - TemplateRight string - ViewsPath string - EnableXSRF bool - XSRFKey string - XSRFExpire int - Session SessionConfig -} - -// SessionConfig holds session related config -// Deprecated: using pkg/, we will delete this in v2.1.0 -type SessionConfig struct { - SessionOn bool - SessionProvider string - SessionName string - SessionGCMaxLifetime int64 - SessionProviderConfig string - SessionCookieLifeTime int - SessionAutoSetCookie bool - SessionDomain string - SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies. - SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers - SessionNameInHTTPHeader string - SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params -} - -// LogConfig holds Log related config -// Deprecated: using pkg/, we will delete this in v2.1.0 -type LogConfig struct { - AccessLogs bool - EnableStaticLogs bool //log static files requests default: false - AccessLogsFormat string //access log format: JSON_FORMAT, APACHE_FORMAT or empty string - FileLineNum bool - Outputs map[string]string // Store Adaptor : config -} - -var ( - // BConfig is the default config for Application - // Deprecated: using pkg/, we will delete this in v2.1.0 - BConfig *Config - // AppConfig is the instance of Config, store the config information from file - // Deprecated: using pkg/, we will delete this in v2.1.0 - AppConfig *beegoAppConfig - // AppPath is the absolute path to the app - // Deprecated: using pkg/, we will delete this in v2.1.0 - AppPath string - // GlobalSessions is the instance for the session manager - // Deprecated: using pkg/, we will delete this in v2.1.0 - GlobalSessions *session.Manager - - // appConfigPath is the path to the config files - appConfigPath string - // appConfigProvider is the provider for the config, default is ini - appConfigProvider = "ini" - // WorkPath is the absolute path to project root directory - // Deprecated: using pkg/, we will delete this in v2.1.0 - WorkPath string -) - -func init() { - BConfig = newBConfig() - var err error - if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil { - panic(err) - } - WorkPath, err = os.Getwd() - if err != nil { - panic(err) - } - var filename = "app.conf" - if os.Getenv("BEEGO_RUNMODE") != "" { - filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf" - } - appConfigPath = filepath.Join(WorkPath, "conf", filename) - if configPath := os.Getenv("BEEGO_CONFIG_PATH"); configPath != "" { - appConfigPath = configPath - } - if !utils.FileExists(appConfigPath) { - appConfigPath = filepath.Join(AppPath, "conf", filename) - if !utils.FileExists(appConfigPath) { - AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()} - return - } - } - if err = parseConfig(appConfigPath); err != nil { - panic(err) - } -} - -func recoverPanic(ctx *context.Context) { - if err := recover(); err != nil { - if err == ErrAbort { - return - } - if !BConfig.RecoverPanic { - panic(err) - } - if BConfig.EnableErrorsShow { - if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { - exception(fmt.Sprint(err), ctx) - return - } - } - var stack string - logs.Critical("the request url is ", ctx.Input.URL()) - logs.Critical("Handler crashed with error", err) - for i := 1; ; i++ { - _, file, line, ok := runtime.Caller(i) - if !ok { - break - } - logs.Critical(fmt.Sprintf("%s:%d", file, line)) - stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) - } - if BConfig.RunMode == DEV && BConfig.EnableErrorsRender { - showErr(err, ctx, stack) - } - if ctx.Output.Status != 0 { - ctx.ResponseWriter.WriteHeader(ctx.Output.Status) - } else { - ctx.ResponseWriter.WriteHeader(500) - } - } -} - -func newBConfig() *Config { - return &Config{ - AppName: "beego", - RunMode: PROD, - RouterCaseSensitive: true, - ServerName: "beegoServer:" + VERSION, - RecoverPanic: true, - RecoverFunc: recoverPanic, - CopyRequestBody: false, - EnableGzip: false, - MaxMemory: 1 << 26, //64MB - EnableErrorsShow: true, - EnableErrorsRender: true, - Listen: Listen{ - Graceful: false, - ServerTimeOut: 0, - ListenTCP4: false, - EnableHTTP: true, - AutoTLS: false, - Domains: []string{}, - TLSCacheDir: ".", - HTTPAddr: "", - HTTPPort: 8080, - EnableHTTPS: false, - HTTPSAddr: "", - HTTPSPort: 10443, - HTTPSCertFile: "", - HTTPSKeyFile: "", - EnableAdmin: false, - AdminAddr: "", - AdminPort: 8088, - EnableFcgi: false, - EnableStdIo: false, - ClientAuth: tls.RequireAndVerifyClientCert, - }, - WebConfig: WebConfig{ - AutoRender: true, - EnableDocs: false, - FlashName: "BEEGO_FLASH", - FlashSeparator: "BEEGOFLASH", - DirectoryIndex: false, - StaticDir: map[string]string{"/static": "static"}, - StaticExtensionsToGzip: []string{".css", ".js"}, - StaticCacheFileSize: 1024 * 100, - StaticCacheFileNum: 1000, - TemplateLeft: "{{", - TemplateRight: "}}", - ViewsPath: "views", - EnableXSRF: false, - XSRFKey: "beegoxsrf", - XSRFExpire: 0, - Session: SessionConfig{ - SessionOn: false, - SessionProvider: "memory", - SessionName: "beegosessionID", - SessionGCMaxLifetime: 3600, - SessionProviderConfig: "", - SessionDisableHTTPOnly: false, - SessionCookieLifeTime: 0, //set cookie default is the browser life - SessionAutoSetCookie: true, - SessionDomain: "", - SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers - SessionNameInHTTPHeader: "Beegosessionid", - SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params - }, - }, - Log: LogConfig{ - AccessLogs: false, - EnableStaticLogs: false, - AccessLogsFormat: "APACHE_FORMAT", - FileLineNum: true, - Outputs: map[string]string{"console": ""}, - }, - } -} - -// now only support ini, next will support json. -func parseConfig(appConfigPath string) (err error) { - AppConfig, err = newAppConfig(appConfigProvider, appConfigPath) - if err != nil { - return err - } - return assignConfig(AppConfig) -} - -func assignConfig(ac config.Configer) error { - for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} { - assignSingleConfig(i, ac) - } - // set the run mode first - if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { - BConfig.RunMode = envRunMode - } else if runMode := ac.String("RunMode"); runMode != "" { - BConfig.RunMode = runMode - } - - if sd := ac.String("StaticDir"); sd != "" { - BConfig.WebConfig.StaticDir = map[string]string{} - sds := strings.Fields(sd) - for _, v := range sds { - if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { - BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[1] - } else { - BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[0] - } - } - } - - if sgz := ac.String("StaticExtensionsToGzip"); sgz != "" { - extensions := strings.Split(sgz, ",") - fileExts := []string{} - for _, ext := range extensions { - ext = strings.TrimSpace(ext) - if ext == "" { - continue - } - if !strings.HasPrefix(ext, ".") { - ext = "." + ext - } - fileExts = append(fileExts, ext) - } - if len(fileExts) > 0 { - BConfig.WebConfig.StaticExtensionsToGzip = fileExts - } - } - - if sfs, err := ac.Int("StaticCacheFileSize"); err == nil { - BConfig.WebConfig.StaticCacheFileSize = sfs - } - - if sfn, err := ac.Int("StaticCacheFileNum"); err == nil { - BConfig.WebConfig.StaticCacheFileNum = sfn - } - - if lo := ac.String("LogOutputs"); lo != "" { - // if lo is not nil or empty - // means user has set his own LogOutputs - // clear the default setting to BConfig.Log.Outputs - BConfig.Log.Outputs = make(map[string]string) - los := strings.Split(lo, ";") - for _, v := range los { - if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 { - BConfig.Log.Outputs[logType2Config[0]] = logType2Config[1] - } else { - continue - } - } - } - - //init log - logs.Reset() - for adaptor, config := range BConfig.Log.Outputs { - err := logs.SetLogger(adaptor, config) - if err != nil { - fmt.Fprintln(os.Stderr, fmt.Sprintf("%s with the config %q got err:%s", adaptor, config, err.Error())) - } - } - logs.SetLogFuncCall(BConfig.Log.FileLineNum) - - return nil -} - -func assignSingleConfig(p interface{}, ac config.Configer) { - pt := reflect.TypeOf(p) - if pt.Kind() != reflect.Ptr { - return - } - pt = pt.Elem() - if pt.Kind() != reflect.Struct { - return - } - pv := reflect.ValueOf(p).Elem() - - for i := 0; i < pt.NumField(); i++ { - pf := pv.Field(i) - if !pf.CanSet() { - continue - } - name := pt.Field(i).Name - switch pf.Kind() { - case reflect.String: - pf.SetString(ac.DefaultString(name, pf.String())) - case reflect.Int, reflect.Int64: - pf.SetInt(ac.DefaultInt64(name, pf.Int())) - case reflect.Bool: - pf.SetBool(ac.DefaultBool(name, pf.Bool())) - case reflect.Struct: - default: - //do nothing here - } - } - -} - -// LoadAppConfig allow developer to apply a config file -// Deprecated: using pkg/, we will delete this in v2.1.0 -func LoadAppConfig(adapterName, configPath string) error { - absConfigPath, err := filepath.Abs(configPath) - if err != nil { - return err - } - - if !utils.FileExists(absConfigPath) { - return fmt.Errorf("the target config file: %s don't exist", configPath) - } - - appConfigPath = absConfigPath - appConfigProvider = adapterName - - return parseConfig(appConfigPath) -} - -type beegoAppConfig struct { - innerConfig config.Configer -} - -func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, error) { - ac, err := config.NewConfig(appConfigProvider, appConfigPath) - if err != nil { - return nil, err - } - return &beegoAppConfig{ac}, nil -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) Set(key, val string) error { - if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { - return b.innerConfig.Set(key, val) - } - return nil -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) String(key string) string { - if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" { - return v - } - return b.innerConfig.String(key) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) Strings(key string) []string { - if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 { - return v - } - return b.innerConfig.Strings(key) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) Int(key string) (int, error) { - if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil { - return v, nil - } - return b.innerConfig.Int(key) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) Int64(key string) (int64, error) { - if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil { - return v, nil - } - return b.innerConfig.Int64(key) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) Bool(key string) (bool, error) { - if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil { - return v, nil - } - return b.innerConfig.Bool(key) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) Float(key string) (float64, error) { - if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil { - return v, nil - } - return b.innerConfig.Float(key) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { - if v := b.String(key); v != "" { - return v - } - return defaultVal -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { - if v := b.Strings(key); len(v) != 0 { - return v - } - return defaultVal -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { - if v, err := b.Int(key); err == nil { - return v - } - return defaultVal -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { - if v, err := b.Int64(key); err == nil { - return v - } - return defaultVal -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { - if v, err := b.Bool(key); err == nil { - return v - } - return defaultVal -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { - if v, err := b.Float(key); err == nil { - return v - } - return defaultVal -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) DIY(key string) (interface{}, error) { - return b.innerConfig.DIY(key) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { - return b.innerConfig.GetSection(section) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) SaveConfigFile(filename string) error { - return b.innerConfig.SaveConfigFile(filename) -} diff --git a/config/config.go b/config/config.go deleted file mode 100644 index db2e96f6..00000000 --- a/config/config.go +++ /dev/null @@ -1,259 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package config is used to parse config. -// Usage: -// import "github.com/astaxie/beego/config" -//Examples. -// -// cnf, err := config.NewConfig("ini", "config.conf") -// -// cnf APIS: -// -// cnf.Set(key, val string) error -// cnf.String(key string) string -// cnf.Strings(key string) []string -// cnf.Int(key string) (int, error) -// cnf.Int64(key string) (int64, error) -// cnf.Bool(key string) (bool, error) -// cnf.Float(key string) (float64, error) -// cnf.DefaultString(key string, defaultVal string) string -// cnf.DefaultStrings(key string, defaultVal []string) []string -// cnf.DefaultInt(key string, defaultVal int) int -// cnf.DefaultInt64(key string, defaultVal int64) int64 -// cnf.DefaultBool(key string, defaultVal bool) bool -// cnf.DefaultFloat(key string, defaultVal float64) float64 -// cnf.DIY(key string) (interface{}, error) -// cnf.GetSection(section string) (map[string]string, error) -// cnf.SaveConfigFile(filename string) error -//More docs http://beego.me/docs/module/config.md -package config - -import ( - "fmt" - "os" - "reflect" - "time" -) - -// Configer defines how to get and set value from configuration raw data. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type Configer interface { - // Deprecated: using pkg/config, we will delete this in v2.1.0 - Set(key, val string) error //support section::key type in given key when using ini type. - // Deprecated: using pkg/config, we will delete this in v2.1.0 - String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - // Deprecated: using pkg/config, we will delete this in v2.1.0 - Strings(key string) []string //get string slice - // Deprecated: using pkg/config, we will delete this in v2.1.0 - Int(key string) (int, error) - // Deprecated: using pkg/config, we will delete this in v2.1.0 - Int64(key string) (int64, error) - // Deprecated: using pkg/config, we will delete this in v2.1.0 - Bool(key string) (bool, error) - // Deprecated: using pkg/config, we will delete this in v2.1.0 - Float(key string) (float64, error) - // Deprecated: using pkg/config, we will delete this in v2.1.0 - DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - // Deprecated: using pkg/config, we will delete this in v2.1.0 - DefaultStrings(key string, defaultVal []string) []string //get string slice - // Deprecated: using pkg/config, we will delete this in v2.1.0 - DefaultInt(key string, defaultVal int) int - // Deprecated: using pkg/config, we will delete this in v2.1.0 - DefaultInt64(key string, defaultVal int64) int64 - // Deprecated: using pkg/config, we will delete this in v2.1.0 - DefaultBool(key string, defaultVal bool) bool - // Deprecated: using pkg/config, we will delete this in v2.1.0 - DefaultFloat(key string, defaultVal float64) float64 - // Deprecated: using pkg/config, we will delete this in v2.1.0 - DIY(key string) (interface{}, error) - // Deprecated: using pkg/config, we will delete this in v2.1.0 - GetSection(section string) (map[string]string, error) - // Deprecated: using pkg/config, we will delete this in v2.1.0 - SaveConfigFile(filename string) error -} - -// Config is the adapter interface for parsing config file to get raw data to Configer. -type Config interface { - Parse(key string) (Configer, error) - ParseData(data []byte) (Configer, error) -} - -var adapters = make(map[string]Config) - -// Register makes a config adapter available by the adapter name. -// If Register is called twice with the same name or if driver is nil, -// it panics. -func Register(name string, adapter Config) { - if adapter == nil { - panic("config: Register adapter is nil") - } - if _, ok := adapters[name]; ok { - panic("config: Register called twice for adapter " + name) - } - adapters[name] = adapter -} - -// NewConfig adapterName is ini/json/xml/yaml. -// filename is the config file path. -func NewConfig(adapterName, filename string) (Configer, error) { - adapter, ok := adapters[adapterName] - if !ok { - return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) - } - return adapter.Parse(filename) -} - -// NewConfigData adapterName is ini/json/xml/yaml. -// data is the config data. -func NewConfigData(adapterName string, data []byte) (Configer, error) { - adapter, ok := adapters[adapterName] - if !ok { - return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) - } - return adapter.ParseData(data) -} - -// ExpandValueEnvForMap convert all string value with environment variable. -func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} { - for k, v := range m { - switch value := v.(type) { - case string: - m[k] = ExpandValueEnv(value) - case map[string]interface{}: - m[k] = ExpandValueEnvForMap(value) - case map[string]string: - for k2, v2 := range value { - value[k2] = ExpandValueEnv(v2) - } - m[k] = value - } - } - return m -} - -// ExpandValueEnv returns value of convert with environment variable. -// -// Return environment variable if value start with "${" and end with "}". -// Return default value if environment variable is empty or not exist. -// -// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue". -// Examples: -// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable. -// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/". -// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie". -func ExpandValueEnv(value string) (realValue string) { - realValue = value - - vLen := len(value) - // 3 = ${} - if vLen < 3 { - return - } - // Need start with "${" and end with "}", then return. - if value[0] != '$' || value[1] != '{' || value[vLen-1] != '}' { - return - } - - key := "" - defaultV := "" - // value start with "${" - for i := 2; i < vLen; i++ { - if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') { - key = value[2:i] - defaultV = value[i+2 : vLen-1] // other string is default value. - break - } else if value[i] == '}' { - key = value[2:i] - break - } - } - - realValue = os.Getenv(key) - if realValue == "" { - realValue = defaultV - } - - return -} - -// ParseBool returns the boolean value represented by the string. -// -// It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On, -// 0, 0.0, f, F, FALSE, false, False, NO, no, No, N,n, OFF, off, Off. -// Any other value returns an error. -func ParseBool(val interface{}) (value bool, err error) { - if val != nil { - switch v := val.(type) { - case bool: - return v, nil - case string: - switch v { - case "1", "t", "T", "true", "TRUE", "True", "YES", "yes", "Yes", "Y", "y", "ON", "on", "On": - return true, nil - case "0", "f", "F", "false", "FALSE", "False", "NO", "no", "No", "N", "n", "OFF", "off", "Off": - return false, nil - } - case int8, int32, int64: - strV := fmt.Sprintf("%d", v) - if strV == "1" { - return true, nil - } else if strV == "0" { - return false, nil - } - case float64: - if v == 1.0 { - return true, nil - } else if v == 0.0 { - return false, nil - } - } - return false, fmt.Errorf("parsing %q: invalid syntax", val) - } - return false, fmt.Errorf("parsing : invalid syntax") -} - -// ToString converts values of any type to string. -func ToString(x interface{}) string { - switch y := x.(type) { - - // Handle dates with special logic - // This needs to come above the fmt.Stringer - // test since time.Time's have a .String() - // method - case time.Time: - return y.Format("A Monday") - - // Handle type string - case string: - return y - - // Handle type with .String() method - case fmt.Stringer: - return y.String() - - // Handle type with .Error() method - case error: - return y.Error() - - } - - // Handle named string type - if v := reflect.ValueOf(x); v.Kind() == reflect.String { - return v.String() - } - - // Fallback to fmt package for anything else like numeric types - return fmt.Sprint(x) -} diff --git a/config/config_test.go b/config/config_test.go deleted file mode 100644 index 15d6ffa6..00000000 --- a/config/config_test.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2016 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "os" - "testing" -) - -func TestExpandValueEnv(t *testing.T) { - - testCases := []struct { - item string - want string - }{ - {"", ""}, - {"$", "$"}, - {"{", "{"}, - {"{}", "{}"}, - {"${}", ""}, - {"${|}", ""}, - {"${}", ""}, - {"${{}}", ""}, - {"${{||}}", "}"}, - {"${pwd||}", ""}, - {"${pwd||}", ""}, - {"${pwd||}", ""}, - {"${pwd||}}", "}"}, - {"${pwd||{{||}}}", "{{||}}"}, - {"${GOPATH}", os.Getenv("GOPATH")}, - {"${GOPATH||}", os.Getenv("GOPATH")}, - {"${GOPATH||root}", os.Getenv("GOPATH")}, - {"${GOPATH_NOT||root}", "root"}, - {"${GOPATH_NOT||||root}", "||root"}, - } - - for _, c := range testCases { - if got := ExpandValueEnv(c.item); got != c.want { - t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got) - } - } - -} diff --git a/config/env/env.go b/config/env/env.go deleted file mode 100644 index 1a6c2527..00000000 --- a/config/env/env.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// Copyright 2017 Faissal Elamraoui. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package env is used to parse environment. -package env - -import ( - "fmt" - "os" - "strings" - - "github.com/astaxie/beego/utils" -) - -var env *utils.BeeMap - -func init() { - env = utils.NewBeeMap() - for _, e := range os.Environ() { - splits := strings.Split(e, "=") - env.Set(splits[0], os.Getenv(splits[0])) - } -} - -// Get returns a value by key. -// If the key does not exist, the default value will be returned. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func Get(key string, defVal string) string { - if val := env.Get(key); val != nil { - return val.(string) - } - return defVal -} - -// MustGet returns a value by key. -// If the key does not exist, it will return an error. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func MustGet(key string) (string, error) { - if val := env.Get(key); val != nil { - return val.(string), nil - } - return "", fmt.Errorf("no env variable with %s", key) -} - -// Set sets a value in the ENV copy. -// This does not affect the child process environment. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func Set(key string, value string) { - env.Set(key, value) -} - -// MustSet sets a value in the ENV copy and the child process environment. -// It returns an error in case the set operation failed. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func MustSet(key string, value string) error { - err := os.Setenv(key, value) - if err != nil { - return err - } - env.Set(key, value) - return nil -} - -// GetAll returns all keys/values in the current child process environment. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func GetAll() map[string]string { - items := env.Items() - envs := make(map[string]string, env.Count()) - - for key, val := range items { - switch key := key.(type) { - case string: - switch val := val.(type) { - case string: - envs[key] = val - } - } - } - return envs -} diff --git a/config/env/env_test.go b/config/env/env_test.go deleted file mode 100644 index 3f1d4dba..00000000 --- a/config/env/env_test.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// Copyright 2017 Faissal Elamraoui. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package env - -import ( - "os" - "testing" -) - -func TestEnvGet(t *testing.T) { - gopath := Get("GOPATH", "") - if gopath != os.Getenv("GOPATH") { - t.Error("expected GOPATH not empty.") - } - - noExistVar := Get("NOEXISTVAR", "foo") - if noExistVar != "foo" { - t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar) - } -} - -func TestEnvMustGet(t *testing.T) { - gopath, err := MustGet("GOPATH") - if err != nil { - t.Error(err) - } - - if gopath != os.Getenv("GOPATH") { - t.Errorf("expected GOPATH to be the same, got %s.", gopath) - } - - _, err = MustGet("NOEXISTVAR") - if err == nil { - t.Error("expected error to be non-nil") - } -} - -func TestEnvSet(t *testing.T) { - Set("MYVAR", "foo") - myVar := Get("MYVAR", "bar") - if myVar != "foo" { - t.Errorf("expected MYVAR to equal foo, got %s.", myVar) - } -} - -func TestEnvMustSet(t *testing.T) { - err := MustSet("FOO", "bar") - if err != nil { - t.Error(err) - } - - fooVar := os.Getenv("FOO") - if fooVar != "bar" { - t.Errorf("expected FOO variable to equal bar, got %s.", fooVar) - } -} - -func TestEnvGetAll(t *testing.T) { - envMap := GetAll() - if len(envMap) == 0 { - t.Error("expected environment not empty.") - } -} diff --git a/config/fake.go b/config/fake.go deleted file mode 100644 index 8093ad61..00000000 --- a/config/fake.go +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "errors" - "strconv" - "strings" -) - -type fakeConfigContainer struct { - data map[string]string -} - -func (c *fakeConfigContainer) getData(key string) string { - return c.data[strings.ToLower(key)] -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) Set(key, val string) error { - c.data[strings.ToLower(key)] = val - return nil -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) String(key string) string { - return c.getData(key) -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval - } - return v -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil - } - return strings.Split(v, ";") -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval - } - return v -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) Int(key string) (int, error) { - return strconv.Atoi(c.getData(key)) -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) - if err != nil { - return defaultval - } - return v -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) Int64(key string) (int64, error) { - return strconv.ParseInt(c.getData(key), 10, 64) -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) - if err != nil { - return defaultval - } - return v -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) Bool(key string) (bool, error) { - return ParseBool(c.getData(key)) -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) - if err != nil { - return defaultval - } - return v -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) Float(key string) (float64, error) { - return strconv.ParseFloat(c.getData(key), 64) -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) - if err != nil { - return defaultval - } - return v -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { - if v, ok := c.data[strings.ToLower(key)]; ok { - return v, nil - } - return nil, errors.New("key not find") -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) GetSection(section string) (map[string]string, error) { - return nil, errors.New("not implement in the fakeConfigContainer") -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) SaveConfigFile(filename string) error { - return errors.New("not implement in the fakeConfigContainer") -} - -var _ Configer = new(fakeConfigContainer) - -// NewFakeConfig return a fake Configer -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func NewFakeConfig() Configer { - return &fakeConfigContainer{ - data: make(map[string]string), - } -} diff --git a/config/ini.go b/config/ini.go deleted file mode 100644 index 1da293dc..00000000 --- a/config/ini.go +++ /dev/null @@ -1,524 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "bufio" - "bytes" - "errors" - "io" - "io/ioutil" - "os" - "os/user" - "path/filepath" - "strconv" - "strings" - "sync" -) - -var ( - defaultSection = "default" // default section means if some ini items not in a section, make them in default section, - bNumComment = []byte{'#'} // number signal - bSemComment = []byte{';'} // semicolon signal - bEmpty = []byte{} - bEqual = []byte{'='} // equal signal - bDQuote = []byte{'"'} // quote signal - sectionStart = []byte{'['} // section start signal - sectionEnd = []byte{']'} // section end signal - lineBreak = "\n" -) - -// IniConfig implements Config to parse ini file. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type IniConfig struct { -} - -// Parse creates a new Config and parses the file configuration from the named file. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (ini *IniConfig) Parse(name string) (Configer, error) { - return ini.parseFile(name) -} - -func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { - data, err := ioutil.ReadFile(name) - if err != nil { - return nil, err - } - - return ini.parseData(filepath.Dir(name), data) -} - -func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, error) { - cfg := &IniConfigContainer{ - data: make(map[string]map[string]string), - sectionComment: make(map[string]string), - keyComment: make(map[string]string), - RWMutex: sync.RWMutex{}, - } - cfg.Lock() - defer cfg.Unlock() - - var comment bytes.Buffer - buf := bufio.NewReader(bytes.NewBuffer(data)) - // check the BOM - head, err := buf.Peek(3) - if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 { - for i := 1; i <= 3; i++ { - buf.ReadByte() - } - } - section := defaultSection - tmpBuf := bytes.NewBuffer(nil) - for { - tmpBuf.Reset() - - shouldBreak := false - for { - tmp, isPrefix, err := buf.ReadLine() - if err == io.EOF { - shouldBreak = true - break - } - - //It might be a good idea to throw a error on all unknonw errors? - if _, ok := err.(*os.PathError); ok { - return nil, err - } - - tmpBuf.Write(tmp) - if isPrefix { - continue - } - - if !isPrefix { - break - } - } - if shouldBreak { - break - } - - line := tmpBuf.Bytes() - line = bytes.TrimSpace(line) - if bytes.Equal(line, bEmpty) { - continue - } - var bComment []byte - switch { - case bytes.HasPrefix(line, bNumComment): - bComment = bNumComment - case bytes.HasPrefix(line, bSemComment): - bComment = bSemComment - } - if bComment != nil { - line = bytes.TrimLeft(line, string(bComment)) - // Need append to a new line if multi-line comments. - if comment.Len() > 0 { - comment.WriteByte('\n') - } - comment.Write(line) - continue - } - - if bytes.HasPrefix(line, sectionStart) && bytes.HasSuffix(line, sectionEnd) { - section = strings.ToLower(string(line[1 : len(line)-1])) // section name case insensitive - if comment.Len() > 0 { - cfg.sectionComment[section] = comment.String() - comment.Reset() - } - if _, ok := cfg.data[section]; !ok { - cfg.data[section] = make(map[string]string) - } - continue - } - - if _, ok := cfg.data[section]; !ok { - cfg.data[section] = make(map[string]string) - } - keyValue := bytes.SplitN(line, bEqual, 2) - - key := string(bytes.TrimSpace(keyValue[0])) // key name case insensitive - key = strings.ToLower(key) - - // handle include "other.conf" - if len(keyValue) == 1 && strings.HasPrefix(key, "include") { - - includefiles := strings.Fields(key) - if includefiles[0] == "include" && len(includefiles) == 2 { - - otherfile := strings.Trim(includefiles[1], "\"") - if !filepath.IsAbs(otherfile) { - otherfile = filepath.Join(dir, otherfile) - } - - i, err := ini.parseFile(otherfile) - if err != nil { - return nil, err - } - - for sec, dt := range i.data { - if _, ok := cfg.data[sec]; !ok { - cfg.data[sec] = make(map[string]string) - } - for k, v := range dt { - cfg.data[sec][k] = v - } - } - - for sec, comm := range i.sectionComment { - cfg.sectionComment[sec] = comm - } - - for k, comm := range i.keyComment { - cfg.keyComment[k] = comm - } - - continue - } - } - - if len(keyValue) != 2 { - return nil, errors.New("read the content error: \"" + string(line) + "\", should key = val") - } - val := bytes.TrimSpace(keyValue[1]) - if bytes.HasPrefix(val, bDQuote) { - val = bytes.Trim(val, `"`) - } - - cfg.data[section][key] = ExpandValueEnv(string(val)) - if comment.Len() > 0 { - cfg.keyComment[section+"."+key] = comment.String() - comment.Reset() - } - - } - return cfg, nil -} - -// ParseData parse ini the data -// When include other.conf,other.conf is either absolute directory -// or under beego in default temporary directory(/tmp/beego[-username]). -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (ini *IniConfig) ParseData(data []byte) (Configer, error) { - dir := "beego" - currentUser, err := user.Current() - if err == nil { - dir = "beego-" + currentUser.Username - } - dir = filepath.Join(os.TempDir(), dir) - if err = os.MkdirAll(dir, os.ModePerm); err != nil { - return nil, err - } - - return ini.parseData(dir, data) -} - -// IniConfigContainer A Config represents the ini configuration. -// When set and get value, support key as section:name type. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type IniConfigContainer struct { - data map[string]map[string]string // section=> key:val - sectionComment map[string]string // section : comment - keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment. - sync.RWMutex -} - -// Bool returns the boolean value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) Bool(key string) (bool, error) { - return ParseBool(c.getdata(key)) -} - -// DefaultBool returns the boolean value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) - if err != nil { - return defaultval - } - return v -} - -// Int returns the integer value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) Int(key string) (int, error) { - return strconv.Atoi(c.getdata(key)) -} - -// DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) - if err != nil { - return defaultval - } - return v -} - -// Int64 returns the int64 value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) Int64(key string) (int64, error) { - return strconv.ParseInt(c.getdata(key), 10, 64) -} - -// DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) - if err != nil { - return defaultval - } - return v -} - -// Float returns the float value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) Float(key string) (float64, error) { - return strconv.ParseFloat(c.getdata(key), 64) -} - -// DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) - if err != nil { - return defaultval - } - return v -} - -// String returns the string value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) String(key string) string { - return c.getdata(key) -} - -// DefaultString returns the string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval - } - return v -} - -// Strings returns the []string value for a given key. -// Return nil if config value does not exist or is empty. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil - } - return strings.Split(v, ";") -} - -// DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval - } - return v -} - -// GetSection returns map for the given section -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) { - if v, ok := c.data[section]; ok { - return v, nil - } - return nil, errors.New("not exist section") -} - -// SaveConfigFile save the config into file. -// -// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Function. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { - // Write configuration file by filename. - f, err := os.Create(filename) - if err != nil { - return err - } - defer f.Close() - - // Get section or key comments. Fixed #1607 - getCommentStr := func(section, key string) string { - var ( - comment string - ok bool - ) - if len(key) == 0 { - comment, ok = c.sectionComment[section] - } else { - comment, ok = c.keyComment[section+"."+key] - } - - if ok { - // Empty comment - if len(comment) == 0 || len(strings.TrimSpace(comment)) == 0 { - return string(bNumComment) - } - prefix := string(bNumComment) - // Add the line head character "#" - return prefix + strings.Replace(comment, lineBreak, lineBreak+prefix, -1) - } - return "" - } - - buf := bytes.NewBuffer(nil) - // Save default section at first place - if dt, ok := c.data[defaultSection]; ok { - for key, val := range dt { - if key != " " { - // Write key comments. - if v := getCommentStr(defaultSection, key); len(v) > 0 { - if _, err = buf.WriteString(v + lineBreak); err != nil { - return err - } - } - - // Write key and value. - if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil { - return err - } - } - } - - // Put a line between sections. - if _, err = buf.WriteString(lineBreak); err != nil { - return err - } - } - // Save named sections - for section, dt := range c.data { - if section != defaultSection { - // Write section comments. - if v := getCommentStr(section, ""); len(v) > 0 { - if _, err = buf.WriteString(v + lineBreak); err != nil { - return err - } - } - - // Write section name. - if _, err = buf.WriteString(string(sectionStart) + section + string(sectionEnd) + lineBreak); err != nil { - return err - } - - for key, val := range dt { - if key != " " { - // Write key comments. - if v := getCommentStr(section, key); len(v) > 0 { - if _, err = buf.WriteString(v + lineBreak); err != nil { - return err - } - } - - // Write key and value. - if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil { - return err - } - } - } - - // Put a line between sections. - if _, err = buf.WriteString(lineBreak); err != nil { - return err - } - } - } - _, err = buf.WriteTo(f) - return err -} - -// Set writes a new value for key. -// if write to one section, the key need be "section::key". -// if the section is not existed, it panics. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) Set(key, value string) error { - c.Lock() - defer c.Unlock() - if len(key) == 0 { - return errors.New("key is empty") - } - - var ( - section, k string - sectionKey = strings.Split(strings.ToLower(key), "::") - ) - - if len(sectionKey) >= 2 { - section = sectionKey[0] - k = sectionKey[1] - } else { - section = defaultSection - k = sectionKey[0] - } - - if _, ok := c.data[section]; !ok { - c.data[section] = make(map[string]string) - } - c.data[section][k] = value - return nil -} - -// DIY returns the raw value by a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) { - if v, ok := c.data[strings.ToLower(key)]; ok { - return v, nil - } - return v, errors.New("key not find") -} - -// section.key or key -func (c *IniConfigContainer) getdata(key string) string { - if len(key) == 0 { - return "" - } - c.RLock() - defer c.RUnlock() - - var ( - section, k string - sectionKey = strings.Split(strings.ToLower(key), "::") - ) - if len(sectionKey) >= 2 { - section = sectionKey[0] - k = sectionKey[1] - } else { - section = defaultSection - k = sectionKey[0] - } - if v, ok := c.data[section]; ok { - if vv, ok := v[k]; ok { - return vv - } - } - return "" -} - -func init() { - Register("ini", &IniConfig{}) -} diff --git a/config/ini_test.go b/config/ini_test.go deleted file mode 100644 index ffcdb294..00000000 --- a/config/ini_test.go +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "fmt" - "io/ioutil" - "os" - "strings" - "testing" -) - -func TestIni(t *testing.T) { - - var ( - inicontext = ` -;comment one -#comment two -appname = beeapi -httpport = 8080 -mysqlport = 3600 -PI = 3.1415976 -runmode = "dev" -autorender = false -copyrequestbody = true -session= on -cookieon= off -newreg = OFF -needlogin = ON -enableSession = Y -enableCookie = N -flag = 1 -path1 = ${GOPATH} -path2 = ${GOPATH||/home/go} -[demo] -key1="asta" -key2 = "xie" -CaseInsensitive = true -peers = one;two;three -password = ${GOPATH} -` - - keyValue = map[string]interface{}{ - "appname": "beeapi", - "httpport": 8080, - "mysqlport": int64(3600), - "pi": 3.1415976, - "runmode": "dev", - "autorender": false, - "copyrequestbody": true, - "session": true, - "cookieon": false, - "newreg": false, - "needlogin": true, - "enableSession": true, - "enableCookie": false, - "flag": true, - "path1": os.Getenv("GOPATH"), - "path2": os.Getenv("GOPATH"), - "demo::key1": "asta", - "demo::key2": "xie", - "demo::CaseInsensitive": true, - "demo::peers": []string{"one", "two", "three"}, - "demo::password": os.Getenv("GOPATH"), - "null": "", - "demo2::key1": "", - "error": "", - "emptystrings": []string{}, - } - ) - - f, err := os.Create("testini.conf") - if err != nil { - t.Fatal(err) - } - _, err = f.WriteString(inicontext) - if err != nil { - f.Close() - t.Fatal(err) - } - f.Close() - defer os.Remove("testini.conf") - iniconf, err := NewConfig("ini", "testini.conf") - if err != nil { - t.Fatal(err) - } - for k, v := range keyValue { - var err error - var value interface{} - switch v.(type) { - case int: - value, err = iniconf.Int(k) - case int64: - value, err = iniconf.Int64(k) - case float64: - value, err = iniconf.Float(k) - case bool: - value, err = iniconf.Bool(k) - case []string: - value = iniconf.Strings(k) - case string: - value = iniconf.String(k) - default: - value, err = iniconf.DIY(k) - } - if err != nil { - t.Fatalf("get key %q value fail,err %s", k, err) - } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { - t.Fatalf("get key %q value, want %v got %v .", k, v, value) - } - - } - if err = iniconf.Set("name", "astaxie"); err != nil { - t.Fatal(err) - } - if iniconf.String("name") != "astaxie" { - t.Fatal("get name error") - } - -} - -func TestIniSave(t *testing.T) { - - const ( - inicontext = ` -app = app -;comment one -#comment two -# comment three -appname = beeapi -httpport = 8080 -# DB Info -# enable db -[dbinfo] -# db type name -# suport mysql,sqlserver -name = mysql -` - - saveResult = ` -app=app -#comment one -#comment two -# comment three -appname=beeapi -httpport=8080 - -# DB Info -# enable db -[dbinfo] -# db type name -# suport mysql,sqlserver -name=mysql -` - ) - cfg, err := NewConfigData("ini", []byte(inicontext)) - if err != nil { - t.Fatal(err) - } - name := "newIniConfig.ini" - if err := cfg.SaveConfigFile(name); err != nil { - t.Fatal(err) - } - defer os.Remove(name) - - if data, err := ioutil.ReadFile(name); err != nil { - t.Fatal(err) - } else { - cfgData := string(data) - datas := strings.Split(saveResult, "\n") - for _, line := range datas { - if !strings.Contains(cfgData, line+"\n") { - t.Fatalf("different after save ini config file. need contains %q", line) - } - } - - } -} diff --git a/config/json.go b/config/json.go deleted file mode 100644 index 74a50d34..00000000 --- a/config/json.go +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "os" - "strconv" - "strings" - "sync" -) - -// JSONConfig is a json config parser and implements Config interface. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type JSONConfig struct { -} - -// Parse returns a ConfigContainer with parsed json config map. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (js *JSONConfig) Parse(filename string) (Configer, error) { - file, err := os.Open(filename) - if err != nil { - return nil, err - } - defer file.Close() - content, err := ioutil.ReadAll(file) - if err != nil { - return nil, err - } - - return js.ParseData(content) -} - -// ParseData returns a ConfigContainer with json string -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (js *JSONConfig) ParseData(data []byte) (Configer, error) { - x := &JSONConfigContainer{ - data: make(map[string]interface{}), - } - err := json.Unmarshal(data, &x.data) - if err != nil { - var wrappingArray []interface{} - err2 := json.Unmarshal(data, &wrappingArray) - if err2 != nil { - return nil, err - } - x.data["rootArray"] = wrappingArray - } - - x.data = ExpandValueEnvForMap(x.data) - - return x, nil -} - -// JSONConfigContainer A Config represents the json configuration. -// Only when get value, support key as section:name type. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type JSONConfigContainer struct { - data map[string]interface{} - sync.RWMutex -} - -// Bool returns the boolean value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) Bool(key string) (bool, error) { - val := c.getData(key) - if val != nil { - return ParseBool(val) - } - return false, fmt.Errorf("not exist key: %q", key) -} - -// DefaultBool return the bool value if has no error -// otherwise return the defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool { - if v, err := c.Bool(key); err == nil { - return v - } - return defaultval -} - -// Int returns the integer value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) Int(key string) (int, error) { - val := c.getData(key) - if val != nil { - if v, ok := val.(float64); ok { - return int(v), nil - } else if v, ok := val.(string); ok { - return strconv.Atoi(v) - } - return 0, errors.New("not valid value") - } - return 0, errors.New("not exist key:" + key) -} - -// DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int { - if v, err := c.Int(key); err == nil { - return v - } - return defaultval -} - -// Int64 returns the int64 value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) Int64(key string) (int64, error) { - val := c.getData(key) - if val != nil { - if v, ok := val.(float64); ok { - return int64(v), nil - } - return 0, errors.New("not int64 value") - } - return 0, errors.New("not exist key:" + key) -} - -// DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - if v, err := c.Int64(key); err == nil { - return v - } - return defaultval -} - -// Float returns the float value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) Float(key string) (float64, error) { - val := c.getData(key) - if val != nil { - if v, ok := val.(float64); ok { - return v, nil - } - return 0.0, errors.New("not float64 value") - } - return 0.0, errors.New("not exist key:" + key) -} - -// DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - if v, err := c.Float(key); err == nil { - return v - } - return defaultval -} - -// String returns the string value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) String(key string) string { - val := c.getData(key) - if val != nil { - if v, ok := val.(string); ok { - return v - } - } - return "" -} - -// DefaultString returns the string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { - // TODO FIXME should not use "" to replace non existence - if v := c.String(key); v != "" { - return v - } - return defaultval -} - -// Strings returns the []string value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) Strings(key string) []string { - stringVal := c.String(key) - if stringVal == "" { - return nil - } - return strings.Split(c.String(key), ";") -} - -// DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { - if v := c.Strings(key); v != nil { - return v - } - return defaultval -} - -// GetSection returns map for the given section -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) { - if v, ok := c.data[section]; ok { - return v.(map[string]string), nil - } - return nil, errors.New("nonexist section " + section) -} - -// SaveConfigFile save the config into file -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) { - // Write configuration file by filename. - f, err := os.Create(filename) - if err != nil { - return err - } - defer f.Close() - b, err := json.MarshalIndent(c.data, "", " ") - if err != nil { - return err - } - _, err = f.Write(b) - return err -} - -// Set writes a new value for key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) Set(key, val string) error { - c.Lock() - defer c.Unlock() - c.data[key] = val - return nil -} - -// DIY returns the raw value by a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) { - val := c.getData(key) - if val != nil { - return val, nil - } - return nil, errors.New("not exist key") -} - -// section.key or key -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) getData(key string) interface{} { - if len(key) == 0 { - return nil - } - - c.RLock() - defer c.RUnlock() - - sectionKeys := strings.Split(key, "::") - if len(sectionKeys) >= 2 { - curValue, ok := c.data[sectionKeys[0]] - if !ok { - return nil - } - for _, key := range sectionKeys[1:] { - if v, ok := curValue.(map[string]interface{}); ok { - if curValue, ok = v[key]; !ok { - return nil - } - } - } - return curValue - } - if v, ok := c.data[key]; ok { - return v - } - return nil -} - -func init() { - Register("json", &JSONConfig{}) -} diff --git a/config/json_test.go b/config/json_test.go deleted file mode 100644 index 16f42409..00000000 --- a/config/json_test.go +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "fmt" - "os" - "testing" -) - -func TestJsonStartsWithArray(t *testing.T) { - - const jsoncontextwitharray = `[ - { - "url": "user", - "serviceAPI": "http://www.test.com/user" - }, - { - "url": "employee", - "serviceAPI": "http://www.test.com/employee" - } -]` - f, err := os.Create("testjsonWithArray.conf") - if err != nil { - t.Fatal(err) - } - _, err = f.WriteString(jsoncontextwitharray) - if err != nil { - f.Close() - t.Fatal(err) - } - f.Close() - defer os.Remove("testjsonWithArray.conf") - jsonconf, err := NewConfig("json", "testjsonWithArray.conf") - if err != nil { - t.Fatal(err) - } - rootArray, err := jsonconf.DIY("rootArray") - if err != nil { - t.Error("array does not exist as element") - } - rootArrayCasted := rootArray.([]interface{}) - if rootArrayCasted == nil { - t.Error("array from root is nil") - } else { - elem := rootArrayCasted[0].(map[string]interface{}) - if elem["url"] != "user" || elem["serviceAPI"] != "http://www.test.com/user" { - t.Error("array[0] values are not valid") - } - - elem2 := rootArrayCasted[1].(map[string]interface{}) - if elem2["url"] != "employee" || elem2["serviceAPI"] != "http://www.test.com/employee" { - t.Error("array[1] values are not valid") - } - } -} - -func TestJson(t *testing.T) { - - var ( - jsoncontext = `{ -"appname": "beeapi", -"testnames": "foo;bar", -"httpport": 8080, -"mysqlport": 3600, -"PI": 3.1415976, -"runmode": "dev", -"autorender": false, -"copyrequestbody": true, -"session": "on", -"cookieon": "off", -"newreg": "OFF", -"needlogin": "ON", -"enableSession": "Y", -"enableCookie": "N", -"flag": 1, -"path1": "${GOPATH}", -"path2": "${GOPATH||/home/go}", -"database": { - "host": "host", - "port": "port", - "database": "database", - "username": "username", - "password": "${GOPATH}", - "conns":{ - "maxconnection":12, - "autoconnect":true, - "connectioninfo":"info", - "root": "${GOPATH}" - } - } -}` - keyValue = map[string]interface{}{ - "appname": "beeapi", - "testnames": []string{"foo", "bar"}, - "httpport": 8080, - "mysqlport": int64(3600), - "PI": 3.1415976, - "runmode": "dev", - "autorender": false, - "copyrequestbody": true, - "session": true, - "cookieon": false, - "newreg": false, - "needlogin": true, - "enableSession": true, - "enableCookie": false, - "flag": true, - "path1": os.Getenv("GOPATH"), - "path2": os.Getenv("GOPATH"), - "database::host": "host", - "database::port": "port", - "database::database": "database", - "database::password": os.Getenv("GOPATH"), - "database::conns::maxconnection": 12, - "database::conns::autoconnect": true, - "database::conns::connectioninfo": "info", - "database::conns::root": os.Getenv("GOPATH"), - "unknown": "", - } - ) - - f, err := os.Create("testjson.conf") - if err != nil { - t.Fatal(err) - } - _, err = f.WriteString(jsoncontext) - if err != nil { - f.Close() - t.Fatal(err) - } - f.Close() - defer os.Remove("testjson.conf") - jsonconf, err := NewConfig("json", "testjson.conf") - if err != nil { - t.Fatal(err) - } - - for k, v := range keyValue { - var err error - var value interface{} - switch v.(type) { - case int: - value, err = jsonconf.Int(k) - case int64: - value, err = jsonconf.Int64(k) - case float64: - value, err = jsonconf.Float(k) - case bool: - value, err = jsonconf.Bool(k) - case []string: - value = jsonconf.Strings(k) - case string: - value = jsonconf.String(k) - default: - value, err = jsonconf.DIY(k) - } - if err != nil { - t.Fatalf("get key %q value fatal,%v err %s", k, v, err) - } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { - t.Fatalf("get key %q value, want %v got %v .", k, v, value) - } - - } - if err = jsonconf.Set("name", "astaxie"); err != nil { - t.Fatal(err) - } - if jsonconf.String("name") != "astaxie" { - t.Fatal("get name error") - } - - if db, err := jsonconf.DIY("database"); err != nil { - t.Fatal(err) - } else if m, ok := db.(map[string]interface{}); !ok { - t.Log(db) - t.Fatal("db not map[string]interface{}") - } else { - if m["host"].(string) != "host" { - t.Fatal("get host err") - } - } - - if _, err := jsonconf.Int("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting an Int") - } - - if _, err := jsonconf.Int64("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting an Int64") - } - - if _, err := jsonconf.Float("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting a Float") - } - - if _, err := jsonconf.DIY("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting an interface{}") - } - - if val := jsonconf.String("unknown"); val != "" { - t.Error("unknown keys should return an empty string when expecting a String") - } - - if _, err := jsonconf.Bool("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting a Bool") - } - - if !jsonconf.DefaultBool("unknown", true) { - t.Error("unknown keys with default value wrong") - } -} diff --git a/config/xml/xml.go b/config/xml/xml.go deleted file mode 100644 index 1601561f..00000000 --- a/config/xml/xml.go +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package xml for config provider. -// -// depend on github.com/beego/x2j. -// -// go install github.com/beego/x2j. -// -// Usage: -// import( -// _ "github.com/astaxie/beego/config/xml" -// "github.com/astaxie/beego/config" -// ) -// -// cnf, err := config.NewConfig("xml", "config.xml") -// -//More docs http://beego.me/docs/module/config.md -package xml - -import ( - "encoding/xml" - "errors" - "fmt" - "io/ioutil" - "os" - "strconv" - "strings" - "sync" - - "github.com/astaxie/beego/config" - "github.com/beego/x2j" -) - -// Config is a xml config parser and implements Config interface. -// xml configurations should be included in tag. -// only support key/value pair as value as each item. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type Config struct{} - -// Parse returns a ConfigContainer with parsed xml config map. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (xc *Config) Parse(filename string) (config.Configer, error) { - context, err := ioutil.ReadFile(filename) - if err != nil { - return nil, err - } - - return xc.ParseData(context) -} - -// ParseData xml data -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (xc *Config) ParseData(data []byte) (config.Configer, error) { - x := &ConfigContainer{data: make(map[string]interface{})} - - d, err := x2j.DocToMap(string(data)) - if err != nil { - return nil, err - } - - x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{})) - - return x, nil -} - -// ConfigContainer A Config represents the xml configuration. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type ConfigContainer struct { - data map[string]interface{} - sync.Mutex -} - -// Bool returns the boolean value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Bool(key string) (bool, error) { - if v := c.data[key]; v != nil { - return config.ParseBool(v) - } - return false, fmt.Errorf("not exist key: %q", key) -} - -// DefaultBool return the bool value if has no error -// otherwise return the defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) - if err != nil { - return defaultval - } - return v -} - -// Int returns the integer value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Int(key string) (int, error) { - return strconv.Atoi(c.data[key].(string)) -} - -// DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) - if err != nil { - return defaultval - } - return v -} - -// Int64 returns the int64 value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Int64(key string) (int64, error) { - return strconv.ParseInt(c.data[key].(string), 10, 64) -} - -// DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) - if err != nil { - return defaultval - } - return v - -} - -// Float returns the float value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Float(key string) (float64, error) { - return strconv.ParseFloat(c.data[key].(string), 64) -} - -// DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) - if err != nil { - return defaultval - } - return v -} - -// String returns the string value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) String(key string) string { - if v, ok := c.data[key].(string); ok { - return v - } - return "" -} - -// DefaultString returns the string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval - } - return v -} - -// Strings returns the []string value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil - } - return strings.Split(v, ";") -} - -// DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval - } - return v -} - -// GetSection returns map for the given section -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { - if v, ok := c.data[section].(map[string]interface{}); ok { - mapstr := make(map[string]string) - for k, val := range v { - mapstr[k] = config.ToString(val) - } - return mapstr, nil - } - return nil, fmt.Errorf("section '%s' not found", section) -} - -// SaveConfigFile save the config into file -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { - // Write configuration file by filename. - f, err := os.Create(filename) - if err != nil { - return err - } - defer f.Close() - b, err := xml.MarshalIndent(c.data, " ", " ") - if err != nil { - return err - } - _, err = f.Write(b) - return err -} - -// Set writes a new value for key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Set(key, val string) error { - c.Lock() - defer c.Unlock() - c.data[key] = val - return nil -} - -// DIY returns the raw value by a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { - if v, ok := c.data[key]; ok { - return v, nil - } - return nil, errors.New("not exist key") -} - -func init() { - config.Register("xml", &Config{}) -} diff --git a/config/xml/xml_test.go b/config/xml/xml_test.go deleted file mode 100644 index 346c866e..00000000 --- a/config/xml/xml_test.go +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package xml - -import ( - "fmt" - "os" - "testing" - - "github.com/astaxie/beego/config" -) - -func TestXML(t *testing.T) { - - var ( - //xml parse should incluce in tags - xmlcontext = ` - -beeapi -8080 -3600 -3.1415976 -dev -false -true -${GOPATH} -${GOPATH||/home/go} - -1 -MySection - - -` - keyValue = map[string]interface{}{ - "appname": "beeapi", - "httpport": 8080, - "mysqlport": int64(3600), - "PI": 3.1415976, - "runmode": "dev", - "autorender": false, - "copyrequestbody": true, - "path1": os.Getenv("GOPATH"), - "path2": os.Getenv("GOPATH"), - "error": "", - "emptystrings": []string{}, - } - ) - - f, err := os.Create("testxml.conf") - if err != nil { - t.Fatal(err) - } - _, err = f.WriteString(xmlcontext) - if err != nil { - f.Close() - t.Fatal(err) - } - f.Close() - defer os.Remove("testxml.conf") - - xmlconf, err := config.NewConfig("xml", "testxml.conf") - if err != nil { - t.Fatal(err) - } - - var xmlsection map[string]string - xmlsection, err = xmlconf.GetSection("mysection") - if err != nil { - t.Fatal(err) - } - - if len(xmlsection) == 0 { - t.Error("section should not be empty") - } - - for k, v := range keyValue { - - var ( - value interface{} - err error - ) - - switch v.(type) { - case int: - value, err = xmlconf.Int(k) - case int64: - value, err = xmlconf.Int64(k) - case float64: - value, err = xmlconf.Float(k) - case bool: - value, err = xmlconf.Bool(k) - case []string: - value = xmlconf.Strings(k) - case string: - value = xmlconf.String(k) - default: - value, err = xmlconf.DIY(k) - } - if err != nil { - t.Errorf("get key %q value fatal,%v err %s", k, v, err) - } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { - t.Errorf("get key %q value, want %v got %v .", k, v, value) - } - - } - - if err = xmlconf.Set("name", "astaxie"); err != nil { - t.Fatal(err) - } - if xmlconf.String("name") != "astaxie" { - t.Fatal("get name error") - } -} diff --git a/config/yaml/yaml.go b/config/yaml/yaml.go deleted file mode 100644 index 725f905b..00000000 --- a/config/yaml/yaml.go +++ /dev/null @@ -1,337 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package yaml for config provider -// -// depend on github.com/beego/goyaml2 -// -// go install github.com/beego/goyaml2 -// -// Usage: -// import( -// _ "github.com/astaxie/beego/config/yaml" -// "github.com/astaxie/beego/config" -// ) -// -// cnf, err := config.NewConfig("yaml", "config.yaml") -// -//More docs http://beego.me/docs/module/config.md -package yaml - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "log" - "os" - "strings" - "sync" - - "github.com/astaxie/beego/config" - "github.com/beego/goyaml2" -) - -// Config is a yaml config parser and implements Config interface. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type Config struct{} - -// Parse returns a ConfigContainer with parsed yaml config map. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (yaml *Config) Parse(filename string) (y config.Configer, err error) { - cnf, err := ReadYmlReader(filename) - if err != nil { - return - } - y = &ConfigContainer{ - data: cnf, - } - return -} - -// ParseData parse yaml data -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (yaml *Config) ParseData(data []byte) (config.Configer, error) { - cnf, err := parseYML(data) - if err != nil { - return nil, err - } - - return &ConfigContainer{ - data: cnf, - }, nil -} - -// ReadYmlReader Read yaml file to map. -// if json like, use json package, unless goyaml2 package. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { - buf, err := ioutil.ReadFile(path) - if err != nil { - return - } - - return parseYML(buf) -} - -// parseYML parse yaml formatted []byte to map. -func parseYML(buf []byte) (cnf map[string]interface{}, err error) { - if len(buf) < 3 { - return - } - - if string(buf[0:1]) == "{" { - log.Println("Look like a Json, try json umarshal") - err = json.Unmarshal(buf, &cnf) - if err == nil { - log.Println("It is Json Map") - return - } - } - - data, err := goyaml2.Read(bytes.NewReader(buf)) - if err != nil { - log.Println("Goyaml2 ERR>", string(buf), err) - return - } - - if data == nil { - log.Println("Goyaml2 output nil? Pls report bug\n" + string(buf)) - return - } - cnf, ok := data.(map[string]interface{}) - if !ok { - log.Println("Not a Map? >> ", string(buf), data) - cnf = nil - } - cnf = config.ExpandValueEnvForMap(cnf) - return -} - -// ConfigContainer A Config represents the yaml configuration. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type ConfigContainer struct { - data map[string]interface{} - sync.RWMutex -} - -// Bool returns the boolean value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Bool(key string) (bool, error) { - v, err := c.getData(key) - if err != nil { - return false, err - } - return config.ParseBool(v) -} - -// DefaultBool return the bool value if has no error -// otherwise return the defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) - if err != nil { - return defaultval - } - return v -} - -// Int returns the integer value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Int(key string) (int, error) { - if v, err := c.getData(key); err != nil { - return 0, err - } else if vv, ok := v.(int); ok { - return vv, nil - } else if vv, ok := v.(int64); ok { - return int(vv), nil - } - return 0, errors.New("not int value") -} - -// DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) - if err != nil { - return defaultval - } - return v -} - -// Int64 returns the int64 value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Int64(key string) (int64, error) { - if v, err := c.getData(key); err != nil { - return 0, err - } else if vv, ok := v.(int64); ok { - return vv, nil - } - return 0, errors.New("not bool value") -} - -// DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) - if err != nil { - return defaultval - } - return v -} - -// Float returns the float value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Float(key string) (float64, error) { - if v, err := c.getData(key); err != nil { - return 0.0, err - } else if vv, ok := v.(float64); ok { - return vv, nil - } else if vv, ok := v.(int); ok { - return float64(vv), nil - } else if vv, ok := v.(int64); ok { - return float64(vv), nil - } - return 0.0, errors.New("not float64 value") -} - -// DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) - if err != nil { - return defaultval - } - return v -} - -// String returns the string value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) String(key string) string { - if v, err := c.getData(key); err == nil { - if vv, ok := v.(string); ok { - return vv - } - } - return "" -} - -// DefaultString returns the string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval - } - return v -} - -// Strings returns the []string value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil - } - return strings.Split(v, ";") -} - -// DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval - } - return v -} - -// GetSection returns map for the given section -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { - - if v, ok := c.data[section]; ok { - return v.(map[string]string), nil - } - return nil, errors.New("not exist section") -} - -// SaveConfigFile save the config into file -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { - // Write configuration file by filename. - f, err := os.Create(filename) - if err != nil { - return err - } - defer f.Close() - err = goyaml2.Write(f, c.data) - return err -} - -// Set writes a new value for key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Set(key, val string) error { - c.Lock() - defer c.Unlock() - c.data[key] = val - return nil -} - -// DIY returns the raw value by a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { - return c.getData(key) -} - -func (c *ConfigContainer) getData(key string) (interface{}, error) { - - if len(key) == 0 { - return nil, errors.New("key is empty") - } - c.RLock() - defer c.RUnlock() - - keys := strings.Split(key, ".") - tmpData := c.data - for idx, k := range keys { - if v, ok := tmpData[k]; ok { - switch v.(type) { - case map[string]interface{}: - { - tmpData = v.(map[string]interface{}) - if idx == len(keys)-1 { - return tmpData, nil - } - } - default: - { - return v, nil - } - - } - } - } - return nil, fmt.Errorf("not exist key %q", key) -} - -func init() { - config.Register("yaml", &Config{}) -} diff --git a/config/yaml/yaml_test.go b/config/yaml/yaml_test.go deleted file mode 100644 index 49cc1d1e..00000000 --- a/config/yaml/yaml_test.go +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package yaml - -import ( - "fmt" - "os" - "testing" - - "github.com/astaxie/beego/config" -) - -func TestYaml(t *testing.T) { - - var ( - yamlcontext = ` -"appname": beeapi -"httpport": 8080 -"mysqlport": 3600 -"PI": 3.1415976 -"runmode": dev -"autorender": false -"copyrequestbody": true -"PATH": GOPATH -"path1": ${GOPATH} -"path2": ${GOPATH||/home/go} -"empty": "" -` - - keyValue = map[string]interface{}{ - "appname": "beeapi", - "httpport": 8080, - "mysqlport": int64(3600), - "PI": 3.1415976, - "runmode": "dev", - "autorender": false, - "copyrequestbody": true, - "PATH": "GOPATH", - "path1": os.Getenv("GOPATH"), - "path2": os.Getenv("GOPATH"), - "error": "", - "emptystrings": []string{}, - } - ) - f, err := os.Create("testyaml.conf") - if err != nil { - t.Fatal(err) - } - _, err = f.WriteString(yamlcontext) - if err != nil { - f.Close() - t.Fatal(err) - } - f.Close() - defer os.Remove("testyaml.conf") - yamlconf, err := config.NewConfig("yaml", "testyaml.conf") - if err != nil { - t.Fatal(err) - } - - if yamlconf.String("appname") != "beeapi" { - t.Fatal("appname not equal to beeapi") - } - - for k, v := range keyValue { - - var ( - value interface{} - err error - ) - - switch v.(type) { - case int: - value, err = yamlconf.Int(k) - case int64: - value, err = yamlconf.Int64(k) - case float64: - value, err = yamlconf.Float(k) - case bool: - value, err = yamlconf.Bool(k) - case []string: - value = yamlconf.Strings(k) - case string: - value = yamlconf.String(k) - default: - value, err = yamlconf.DIY(k) - } - if err != nil { - t.Errorf("get key %q value fatal,%v err %s", k, v, err) - } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { - t.Errorf("get key %q value, want %v got %v .", k, v, value) - } - - } - - if err = yamlconf.Set("name", "astaxie"); err != nil { - t.Fatal(err) - } - if yamlconf.String("name") != "astaxie" { - t.Fatal("get name error") - } - -} diff --git a/config_test.go b/config_test.go deleted file mode 100644 index 5f71f1c3..00000000 --- a/config_test.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "encoding/json" - "reflect" - "testing" - - "github.com/astaxie/beego/config" -) - -func TestDefaults(t *testing.T) { - if BConfig.WebConfig.FlashName != "BEEGO_FLASH" { - t.Errorf("FlashName was not set to default.") - } - - if BConfig.WebConfig.FlashSeparator != "BEEGOFLASH" { - t.Errorf("FlashName was not set to default.") - } -} - -func TestAssignConfig_01(t *testing.T) { - _BConfig := &Config{} - _BConfig.AppName = "beego_test" - jcf := &config.JSONConfig{} - ac, _ := jcf.ParseData([]byte(`{"AppName":"beego_json"}`)) - assignSingleConfig(_BConfig, ac) - if _BConfig.AppName != "beego_json" { - t.Log(_BConfig) - t.FailNow() - } -} - -func TestAssignConfig_02(t *testing.T) { - _BConfig := &Config{} - bs, _ := json.Marshal(newBConfig()) - - jsonMap := M{} - json.Unmarshal(bs, &jsonMap) - - configMap := M{} - for k, v := range jsonMap { - if reflect.TypeOf(v).Kind() == reflect.Map { - for k1, v1 := range v.(M) { - if reflect.TypeOf(v1).Kind() == reflect.Map { - for k2, v2 := range v1.(M) { - configMap[k2] = v2 - } - } else { - configMap[k1] = v1 - } - } - } else { - configMap[k] = v - } - } - configMap["MaxMemory"] = 1024 - configMap["Graceful"] = true - configMap["XSRFExpire"] = 32 - configMap["SessionProviderConfig"] = "file" - configMap["FileLineNum"] = true - - jcf := &config.JSONConfig{} - bs, _ = json.Marshal(configMap) - ac, _ := jcf.ParseData(bs) - - for _, i := range []interface{}{_BConfig, &_BConfig.Listen, &_BConfig.WebConfig, &_BConfig.Log, &_BConfig.WebConfig.Session} { - assignSingleConfig(i, ac) - } - - if _BConfig.MaxMemory != 1024 { - t.Log(_BConfig.MaxMemory) - t.FailNow() - } - - if !_BConfig.Listen.Graceful { - t.Log(_BConfig.Listen.Graceful) - t.FailNow() - } - - if _BConfig.WebConfig.XSRFExpire != 32 { - t.Log(_BConfig.WebConfig.XSRFExpire) - t.FailNow() - } - - if _BConfig.WebConfig.Session.SessionProviderConfig != "file" { - t.Log(_BConfig.WebConfig.Session.SessionProviderConfig) - t.FailNow() - } - - if !_BConfig.Log.FileLineNum { - t.Log(_BConfig.Log.FileLineNum) - t.FailNow() - } - -} - -func TestAssignConfig_03(t *testing.T) { - jcf := &config.JSONConfig{} - ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`)) - ac.Set("AppName", "test_app") - ac.Set("RunMode", "online") - ac.Set("StaticDir", "download:down download2:down2") - ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") - ac.Set("StaticCacheFileSize", "87456") - ac.Set("StaticCacheFileNum", "1254") - assignConfig(ac) - - t.Logf("%#v", BConfig) - - if BConfig.AppName != "test_app" { - t.FailNow() - } - - if BConfig.RunMode != "online" { - t.FailNow() - } - if BConfig.WebConfig.StaticDir["/download"] != "down" { - t.FailNow() - } - if BConfig.WebConfig.StaticDir["/download2"] != "down2" { - t.FailNow() - } - if BConfig.WebConfig.StaticCacheFileSize != 87456 { - t.FailNow() - } - if BConfig.WebConfig.StaticCacheFileNum != 1254 { - t.FailNow() - } - if len(BConfig.WebConfig.StaticExtensionsToGzip) != 5 { - t.FailNow() - } -} diff --git a/context/acceptencoder.go b/context/acceptencoder.go deleted file mode 100644 index b4e2492c..00000000 --- a/context/acceptencoder.go +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2015 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import ( - "bytes" - "compress/flate" - "compress/gzip" - "compress/zlib" - "io" - "net/http" - "os" - "strconv" - "strings" - "sync" -) - -var ( - //Default size==20B same as nginx - defaultGzipMinLength = 20 - //Content will only be compressed if content length is either unknown or greater than gzipMinLength. - gzipMinLength = defaultGzipMinLength - //The compression level used for deflate compression. (0-9). - gzipCompressLevel int - //List of HTTP methods to compress. If not set, only GET requests are compressed. - includedMethods map[string]bool - getMethodOnly bool -) - -// InitGzip init the gzipcompress -func InitGzip(minLength, compressLevel int, methods []string) { - if minLength >= 0 { - gzipMinLength = minLength - } - gzipCompressLevel = compressLevel - if gzipCompressLevel < flate.NoCompression || gzipCompressLevel > flate.BestCompression { - gzipCompressLevel = flate.BestSpeed - } - getMethodOnly = (len(methods) == 0) || (len(methods) == 1 && strings.ToUpper(methods[0]) == "GET") - includedMethods = make(map[string]bool, len(methods)) - for _, v := range methods { - includedMethods[strings.ToUpper(v)] = true - } -} - -type resetWriter interface { - io.Writer - Reset(w io.Writer) -} - -type nopResetWriter struct { - io.Writer -} - -func (n nopResetWriter) Reset(w io.Writer) { - //do nothing -} - -type acceptEncoder struct { - name string - levelEncode func(int) resetWriter - customCompressLevelPool *sync.Pool - bestCompressionPool *sync.Pool -} - -func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter { - if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { - return nopResetWriter{wr} - } - var rwr resetWriter - switch level { - case flate.BestSpeed: - rwr = ac.customCompressLevelPool.Get().(resetWriter) - case flate.BestCompression: - rwr = ac.bestCompressionPool.Get().(resetWriter) - default: - rwr = ac.levelEncode(level) - } - rwr.Reset(wr) - return rwr -} - -func (ac acceptEncoder) put(wr resetWriter, level int) { - if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { - return - } - wr.Reset(nil) - - //notice - //compressionLevel==BestCompression DOES NOT MATTER - //sync.Pool will not memory leak - - switch level { - case gzipCompressLevel: - ac.customCompressLevelPool.Put(wr) - case flate.BestCompression: - ac.bestCompressionPool.Put(wr) - } -} - -var ( - noneCompressEncoder = acceptEncoder{"", nil, nil, nil} - gzipCompressEncoder = acceptEncoder{ - name: "gzip", - levelEncode: func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr }, - customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, gzipCompressLevel); return wr }}, - bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }}, - } - - //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed - //deflate - //The "zlib" format defined in RFC 1950 [31] in combination with - //the "deflate" compression mechanism described in RFC 1951 [29]. - deflateCompressEncoder = acceptEncoder{ - name: "deflate", - levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, - customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, gzipCompressLevel); return wr }}, - bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }}, - } -) - -var ( - encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore - "gzip": gzipCompressEncoder, - "deflate": deflateCompressEncoder, - "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip - "identity": noneCompressEncoder, // identity means none-compress - } -) - -// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) -func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { - return writeLevel(encoding, writer, file, flate.BestCompression) -} - -// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) -func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { - if encoding == "" || len(content) < gzipMinLength { - _, err := writer.Write(content) - return false, "", err - } - return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel) -} - -// writeLevel reads from reader,writes to writer by specific encoding and compress level -// the compress level is defined by deflate package -func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) { - var outputWriter resetWriter - var err error - var ce = noneCompressEncoder - - if cf, ok := encoderMap[encoding]; ok { - ce = cf - } - encoding = ce.name - outputWriter = ce.encode(writer, level) - defer ce.put(outputWriter, level) - - _, err = io.Copy(outputWriter, reader) - if err != nil { - return false, "", err - } - - switch outputWriter.(type) { - case io.WriteCloser: - outputWriter.(io.WriteCloser).Close() - } - return encoding != "", encoding, nil -} - -// ParseEncoding will extract the right encoding for response -// the Accept-Encoding's sec is here: -// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3 -func ParseEncoding(r *http.Request) string { - if r == nil { - return "" - } - if (getMethodOnly && r.Method == "GET") || includedMethods[r.Method] { - return parseEncoding(r) - } - return "" -} - -type q struct { - name string - value float64 -} - -func parseEncoding(r *http.Request) string { - acceptEncoding := r.Header.Get("Accept-Encoding") - if acceptEncoding == "" { - return "" - } - var lastQ q - for _, v := range strings.Split(acceptEncoding, ",") { - v = strings.TrimSpace(v) - if v == "" { - continue - } - vs := strings.Split(v, ";") - var cf acceptEncoder - var ok bool - if cf, ok = encoderMap[vs[0]]; !ok { - continue - } - if len(vs) == 1 { - return cf.name - } - if len(vs) == 2 { - f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64) - if f == 0 { - continue - } - if f > lastQ.value { - lastQ = q{cf.name, f} - } - } - } - return lastQ.name -} diff --git a/context/acceptencoder_test.go b/context/acceptencoder_test.go deleted file mode 100644 index e3d61e27..00000000 --- a/context/acceptencoder_test.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2015 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import ( - "net/http" - "testing" -) - -func Test_ExtractEncoding(t *testing.T) { - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,deflate"}}}) != "gzip" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate,gzip"}}}) != "deflate" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate"}}}) != "deflate" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate;q=0.3"}}}) != "gzip" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0,deflate"}}}) != "deflate" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate;q=0.5,gzip;q=0.5,identity"}}}) != "" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"*"}}}) != "gzip" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x,gzip,deflate"}}}) != "gzip" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,x,deflate"}}}) != "gzip" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x,deflate"}}}) != "deflate" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x"}}}) != "" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x;q=0.8"}}}) != "gzip" { - t.Fail() - } -} diff --git a/context/context.go b/context/context.go deleted file mode 100644 index 7c161ac0..00000000 --- a/context/context.go +++ /dev/null @@ -1,263 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package context provide the context utils -// Usage: -// -// import "github.com/astaxie/beego/context" -// -// ctx := context.Context{Request:req,ResponseWriter:rw} -// -// more docs http://beego.me/docs/module/context.md -package context - -import ( - "bufio" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "errors" - "fmt" - "net" - "net/http" - "strconv" - "strings" - "time" - - "github.com/astaxie/beego/utils" -) - -//commonly used mime-types -const ( - ApplicationJSON = "application/json" - ApplicationXML = "application/xml" - ApplicationYAML = "application/x-yaml" - TextXML = "text/xml" -) - -// NewContext return the Context with Input and Output -func NewContext() *Context { - return &Context{ - Input: NewInput(), - Output: NewOutput(), - } -} - -// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. -// BeegoInput and BeegoOutput provides some api to operate request and response more easily. -type Context struct { - Input *BeegoInput - Output *BeegoOutput - Request *http.Request - ResponseWriter *Response - _xsrfToken string -} - -// Reset init Context, BeegoInput and BeegoOutput -func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { - ctx.Request = r - if ctx.ResponseWriter == nil { - ctx.ResponseWriter = &Response{} - } - ctx.ResponseWriter.reset(rw) - ctx.Input.Reset(ctx) - ctx.Output.Reset(ctx) - ctx._xsrfToken = "" -} - -// Redirect does redirection to localurl with http header status code. -func (ctx *Context) Redirect(status int, localurl string) { - http.Redirect(ctx.ResponseWriter, ctx.Request, localurl, status) -} - -// Abort stops this request. -// if beego.ErrorMaps exists, panic body. -func (ctx *Context) Abort(status int, body string) { - ctx.Output.SetStatus(status) - panic(body) -} - -// WriteString Write string to response body. -// it sends response body. -func (ctx *Context) WriteString(content string) { - ctx.ResponseWriter.Write([]byte(content)) -} - -// GetCookie Get cookie from request by a given key. -// It's alias of BeegoInput.Cookie. -func (ctx *Context) GetCookie(key string) string { - return ctx.Input.Cookie(key) -} - -// SetCookie Set cookie for response. -// It's alias of BeegoOutput.Cookie. -func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { - ctx.Output.Cookie(name, value, others...) -} - -// GetSecureCookie Get secure cookie from request by a given key. -func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { - val := ctx.Input.Cookie(key) - if val == "" { - return "", false - } - - parts := strings.SplitN(val, "|", 3) - - if len(parts) != 3 { - return "", false - } - - vs := parts[0] - timestamp := parts[1] - sig := parts[2] - - h := hmac.New(sha256.New, []byte(Secret)) - fmt.Fprintf(h, "%s%s", vs, timestamp) - - if fmt.Sprintf("%02x", h.Sum(nil)) != sig { - return "", false - } - res, _ := base64.URLEncoding.DecodeString(vs) - return string(res), true -} - -// SetSecureCookie Set Secure cookie for response. -func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { - vs := base64.URLEncoding.EncodeToString([]byte(value)) - timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) - h := hmac.New(sha256.New, []byte(Secret)) - fmt.Fprintf(h, "%s%s", vs, timestamp) - sig := fmt.Sprintf("%02x", h.Sum(nil)) - cookie := strings.Join([]string{vs, timestamp, sig}, "|") - ctx.Output.Cookie(name, cookie, others...) -} - -// XSRFToken creates a xsrf token string and returns. -func (ctx *Context) XSRFToken(key string, expire int64) string { - if ctx._xsrfToken == "" { - token, ok := ctx.GetSecureCookie(key, "_xsrf") - if !ok { - token = string(utils.RandomCreateBytes(32)) - ctx.SetSecureCookie(key, "_xsrf", token, expire, "", "", true, true) - } - ctx._xsrfToken = token - } - return ctx._xsrfToken -} - -// CheckXSRFCookie checks xsrf token in this request is valid or not. -// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" -// or in form field value named as "_xsrf". -func (ctx *Context) CheckXSRFCookie() bool { - token := ctx.Input.Query("_xsrf") - if token == "" { - token = ctx.Request.Header.Get("X-Xsrftoken") - } - if token == "" { - token = ctx.Request.Header.Get("X-Csrftoken") - } - if token == "" { - ctx.Abort(422, "422") - return false - } - if ctx._xsrfToken != token { - ctx.Abort(417, "417") - return false - } - return true -} - -// RenderMethodResult renders the return value of a controller method to the output -func (ctx *Context) RenderMethodResult(result interface{}) { - if result != nil { - renderer, ok := result.(Renderer) - if !ok { - err, ok := result.(error) - if ok { - renderer = errorRenderer(err) - } else { - renderer = jsonRenderer(result) - } - } - renderer.Render(ctx) - } -} - -//Response is a wrapper for the http.ResponseWriter -//started set to true if response was written to then don't execute other handler -type Response struct { - http.ResponseWriter - Started bool - Status int - Elapsed time.Duration -} - -func (r *Response) reset(rw http.ResponseWriter) { - r.ResponseWriter = rw - r.Status = 0 - r.Started = false -} - -// Write writes the data to the connection as part of an HTTP reply, -// and sets `started` to true. -// started means the response has sent out. -func (r *Response) Write(p []byte) (int, error) { - r.Started = true - return r.ResponseWriter.Write(p) -} - -// WriteHeader sends an HTTP response header with status code, -// and sets `started` to true. -func (r *Response) WriteHeader(code int) { - if r.Status > 0 { - //prevent multiple response.WriteHeader calls - return - } - r.Status = code - r.Started = true - r.ResponseWriter.WriteHeader(code) -} - -// Hijack hijacker for http -func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { - hj, ok := r.ResponseWriter.(http.Hijacker) - if !ok { - return nil, nil, errors.New("webserver doesn't support hijacking") - } - return hj.Hijack() -} - -// Flush http.Flusher -func (r *Response) Flush() { - if f, ok := r.ResponseWriter.(http.Flusher); ok { - f.Flush() - } -} - -// CloseNotify http.CloseNotifier -func (r *Response) CloseNotify() <-chan bool { - if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { - return cn.CloseNotify() - } - return nil -} - -// Pusher http.Pusher -func (r *Response) Pusher() (pusher http.Pusher) { - if pusher, ok := r.ResponseWriter.(http.Pusher); ok { - return pusher - } - return nil -} diff --git a/context/context_test.go b/context/context_test.go deleted file mode 100644 index e81e8191..00000000 --- a/context/context_test.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2016 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import ( - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestXsrfReset_01(t *testing.T) { - r := &http.Request{} - c := NewContext() - c.Request = r - c.ResponseWriter = &Response{} - c.ResponseWriter.reset(httptest.NewRecorder()) - c.Output.Reset(c) - c.Input.Reset(c) - c.XSRFToken("key", 16) - if c._xsrfToken == "" { - t.FailNow() - } - token := c._xsrfToken - c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r) - if c._xsrfToken != "" { - t.FailNow() - } - c.XSRFToken("key", 16) - if c._xsrfToken == "" { - t.FailNow() - } - if token == c._xsrfToken { - t.FailNow() - } - - ck := c.ResponseWriter.Header().Get("Set-Cookie") - assert.True(t, strings.Contains(ck, "Secure")) - assert.True(t, strings.Contains(ck, "HttpOnly")) -} diff --git a/context/input.go b/context/input.go deleted file mode 100644 index c2c1c63d..00000000 --- a/context/input.go +++ /dev/null @@ -1,709 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import ( - "bytes" - "compress/gzip" - "errors" - "io" - "io/ioutil" - "net" - "net/http" - "net/url" - "reflect" - "regexp" - "strconv" - "strings" - "sync" - - "github.com/astaxie/beego/session" -) - -// Regexes for checking the accept headers -// TODO make sure these are correct -var ( - acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`) - acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`) - acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`) - acceptsYAMLRegex = regexp.MustCompile(`(application/x-yaml)(?:,|$)`) - maxParam = 50 -) - -// BeegoInput operates the http request header, data, cookie and body. -// it also contains router params and current session. -type BeegoInput struct { - Context *Context - CruSession session.Store - pnames []string - pvalues []string - data map[interface{}]interface{} // store some values in this context when calling context in filter or controller. - dataLock sync.RWMutex - RequestBody []byte - RunMethod string - RunController reflect.Type -} - -// NewInput return BeegoInput generated by Context. -func NewInput() *BeegoInput { - return &BeegoInput{ - pnames: make([]string, 0, maxParam), - pvalues: make([]string, 0, maxParam), - data: make(map[interface{}]interface{}), - } -} - -// Reset init the BeegoInput -func (input *BeegoInput) Reset(ctx *Context) { - input.Context = ctx - input.CruSession = nil - input.pnames = input.pnames[:0] - input.pvalues = input.pvalues[:0] - input.dataLock.Lock() - input.data = nil - input.dataLock.Unlock() - input.RequestBody = []byte{} -} - -// Protocol returns request protocol name, such as HTTP/1.1 . -func (input *BeegoInput) Protocol() string { - return input.Context.Request.Proto -} - -// URI returns full request url with query string, fragment. -func (input *BeegoInput) URI() string { - return input.Context.Request.RequestURI -} - -// URL returns request url path (without query string, fragment). -func (input *BeegoInput) URL() string { - return input.Context.Request.URL.EscapedPath() -} - -// Site returns base site url as scheme://domain type. -func (input *BeegoInput) Site() string { - return input.Scheme() + "://" + input.Domain() -} - -// Scheme returns request scheme as "http" or "https". -func (input *BeegoInput) Scheme() string { - if scheme := input.Header("X-Forwarded-Proto"); scheme != "" { - return scheme - } - if input.Context.Request.URL.Scheme != "" { - return input.Context.Request.URL.Scheme - } - if input.Context.Request.TLS == nil { - return "http" - } - return "https" -} - -// Domain returns host name. -// Alias of Host method. -func (input *BeegoInput) Domain() string { - return input.Host() -} - -// Host returns host name. -// if no host info in request, return localhost. -func (input *BeegoInput) Host() string { - if input.Context.Request.Host != "" { - if hostPart, _, err := net.SplitHostPort(input.Context.Request.Host); err == nil { - return hostPart - } - return input.Context.Request.Host - } - return "localhost" -} - -// Method returns http request method. -func (input *BeegoInput) Method() string { - return input.Context.Request.Method -} - -// Is returns boolean of this request is on given method, such as Is("POST"). -func (input *BeegoInput) Is(method string) bool { - return input.Method() == method -} - -// IsGet Is this a GET method request? -func (input *BeegoInput) IsGet() bool { - return input.Is("GET") -} - -// IsPost Is this a POST method request? -func (input *BeegoInput) IsPost() bool { - return input.Is("POST") -} - -// IsHead Is this a Head method request? -func (input *BeegoInput) IsHead() bool { - return input.Is("HEAD") -} - -// IsOptions Is this a OPTIONS method request? -func (input *BeegoInput) IsOptions() bool { - return input.Is("OPTIONS") -} - -// IsPut Is this a PUT method request? -func (input *BeegoInput) IsPut() bool { - return input.Is("PUT") -} - -// IsDelete Is this a DELETE method request? -func (input *BeegoInput) IsDelete() bool { - return input.Is("DELETE") -} - -// IsPatch Is this a PATCH method request? -func (input *BeegoInput) IsPatch() bool { - return input.Is("PATCH") -} - -// IsAjax returns boolean of this request is generated by ajax. -func (input *BeegoInput) IsAjax() bool { - return input.Header("X-Requested-With") == "XMLHttpRequest" -} - -// IsSecure returns boolean of this request is in https. -func (input *BeegoInput) IsSecure() bool { - return input.Scheme() == "https" -} - -// IsWebsocket returns boolean of this request is in webSocket. -func (input *BeegoInput) IsWebsocket() bool { - return input.Header("Upgrade") == "websocket" -} - -// IsUpload returns boolean of whether file uploads in this request or not.. -func (input *BeegoInput) IsUpload() bool { - return strings.Contains(input.Header("Content-Type"), "multipart/form-data") -} - -// AcceptsHTML Checks if request accepts html response -func (input *BeegoInput) AcceptsHTML() bool { - return acceptsHTMLRegex.MatchString(input.Header("Accept")) -} - -// AcceptsXML Checks if request accepts xml response -func (input *BeegoInput) AcceptsXML() bool { - return acceptsXMLRegex.MatchString(input.Header("Accept")) -} - -// AcceptsJSON Checks if request accepts json response -func (input *BeegoInput) AcceptsJSON() bool { - return acceptsJSONRegex.MatchString(input.Header("Accept")) -} - -// AcceptsYAML Checks if request accepts json response -func (input *BeegoInput) AcceptsYAML() bool { - return acceptsYAMLRegex.MatchString(input.Header("Accept")) -} - -// IP returns request client ip. -// if in proxy, return first proxy id. -// if error, return RemoteAddr. -func (input *BeegoInput) IP() string { - ips := input.Proxy() - if len(ips) > 0 && ips[0] != "" { - rip, _, err := net.SplitHostPort(ips[0]) - if err != nil { - rip = ips[0] - } - return rip - } - if ip, _, err := net.SplitHostPort(input.Context.Request.RemoteAddr); err == nil { - return ip - } - return input.Context.Request.RemoteAddr -} - -// Proxy returns proxy client ips slice. -func (input *BeegoInput) Proxy() []string { - if ips := input.Header("X-Forwarded-For"); ips != "" { - return strings.Split(ips, ",") - } - return []string{} -} - -// Referer returns http referer header. -func (input *BeegoInput) Referer() string { - return input.Header("Referer") -} - -// Refer returns http referer header. -func (input *BeegoInput) Refer() string { - return input.Referer() -} - -// SubDomains returns sub domain string. -// if aa.bb.domain.com, returns aa.bb . -func (input *BeegoInput) SubDomains() string { - parts := strings.Split(input.Host(), ".") - if len(parts) >= 3 { - return strings.Join(parts[:len(parts)-2], ".") - } - return "" -} - -// Port returns request client port. -// when error or empty, return 80. -func (input *BeegoInput) Port() int { - if _, portPart, err := net.SplitHostPort(input.Context.Request.Host); err == nil { - port, _ := strconv.Atoi(portPart) - return port - } - return 80 -} - -// UserAgent returns request client user agent string. -func (input *BeegoInput) UserAgent() string { - return input.Header("User-Agent") -} - -// ParamsLen return the length of the params -func (input *BeegoInput) ParamsLen() int { - return len(input.pnames) -} - -// Param returns router param by a given key. -func (input *BeegoInput) Param(key string) string { - for i, v := range input.pnames { - if v == key && i <= len(input.pvalues) { - // we cannot use url.PathEscape(input.pvalues[i]) - // for example, if the value is /a/b - // after url.PathEscape(input.pvalues[i]), the value is %2Fa%2Fb - // However, the value is used in ControllerRegister.ServeHTTP - // and split by "/", so function crash... - return input.pvalues[i] - } - } - return "" -} - -// Params returns the map[key]value. -func (input *BeegoInput) Params() map[string]string { - m := make(map[string]string) - for i, v := range input.pnames { - if i <= len(input.pvalues) { - m[v] = input.pvalues[i] - } - } - return m -} - -// SetParam will set the param with key and value -func (input *BeegoInput) SetParam(key, val string) { - // check if already exists - for i, v := range input.pnames { - if v == key && i <= len(input.pvalues) { - input.pvalues[i] = val - return - } - } - input.pvalues = append(input.pvalues, val) - input.pnames = append(input.pnames, key) -} - -// ResetParams clears any of the input's Params -// This function is used to clear parameters so they may be reset between filter -// passes. -func (input *BeegoInput) ResetParams() { - input.pnames = input.pnames[:0] - input.pvalues = input.pvalues[:0] -} - -// ResetData: reset data -func (input *BeegoInput) ResetData() { - input.dataLock.Lock() - input.data = nil - input.dataLock.Unlock() -} - -// ResetBody: reset body -func (input *BeegoInput) ResetBody() { - input.RequestBody = []byte{} -} - -// Clear: clear all data in input -func (input *BeegoInput) Clear() { - input.ResetParams() - input.ResetData() - input.ResetBody() - -} - -// Query returns input data item string by a given string. -func (input *BeegoInput) Query(key string) string { - if val := input.Param(key); val != "" { - return val - } - if input.Context.Request.Form == nil { - input.dataLock.Lock() - if input.Context.Request.Form == nil { - input.Context.Request.ParseForm() - } - input.dataLock.Unlock() - } - input.dataLock.RLock() - defer input.dataLock.RUnlock() - return input.Context.Request.Form.Get(key) -} - -// Header returns request header item string by a given string. -// if non-existed, return empty string. -func (input *BeegoInput) Header(key string) string { - return input.Context.Request.Header.Get(key) -} - -// Cookie returns request cookie item string by a given key. -// if non-existed, return empty string. -func (input *BeegoInput) Cookie(key string) string { - ck, err := input.Context.Request.Cookie(key) - if err != nil { - return "" - } - return ck.Value -} - -// Session returns current session item value by a given key. -// if non-existed, return nil. -func (input *BeegoInput) Session(key interface{}) interface{} { - return input.CruSession.Get(key) -} - -// CopyBody returns the raw request body data as bytes. -func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { - if input.Context.Request.Body == nil { - return []byte{} - } - - var requestbody []byte - safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory} - if input.Header("Content-Encoding") == "gzip" { - reader, err := gzip.NewReader(safe) - if err != nil { - return nil - } - requestbody, _ = ioutil.ReadAll(reader) - } else { - requestbody, _ = ioutil.ReadAll(safe) - } - - input.Context.Request.Body.Close() - bf := bytes.NewBuffer(requestbody) - input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, ioutil.NopCloser(bf), MaxMemory) - input.RequestBody = requestbody - return requestbody -} - -// Data return the implicit data in the input -func (input *BeegoInput) Data() map[interface{}]interface{} { - input.dataLock.Lock() - defer input.dataLock.Unlock() - if input.data == nil { - input.data = make(map[interface{}]interface{}) - } - return input.data -} - -// GetData returns the stored data in this context. -func (input *BeegoInput) GetData(key interface{}) interface{} { - input.dataLock.Lock() - defer input.dataLock.Unlock() - if v, ok := input.data[key]; ok { - return v - } - return nil -} - -// SetData stores data with given key in this context. -// This data are only available in this context. -func (input *BeegoInput) SetData(key, val interface{}) { - input.dataLock.Lock() - defer input.dataLock.Unlock() - if input.data == nil { - input.data = make(map[interface{}]interface{}) - } - input.data[key] = val -} - -// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type -func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { - // Parse the body depending on the content type. - if strings.Contains(input.Header("Content-Type"), "multipart/form-data") { - if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil { - return errors.New("Error parsing request body:" + err.Error()) - } - } else if err := input.Context.Request.ParseForm(); err != nil { - return errors.New("Error parsing request body:" + err.Error()) - } - return nil -} - -// Bind data from request.Form[key] to dest -// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie -// var id int beegoInput.Bind(&id, "id") id ==123 -// var isok bool beegoInput.Bind(&isok, "isok") isok ==true -// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2 -// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2] -// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array] -// user struct{Name} beegoInput.Bind(&user, "user") user == {Name:"astaxie"} -func (input *BeegoInput) Bind(dest interface{}, key string) error { - value := reflect.ValueOf(dest) - if value.Kind() != reflect.Ptr { - return errors.New("beego: non-pointer passed to Bind: " + key) - } - value = value.Elem() - if !value.CanSet() { - return errors.New("beego: non-settable variable passed to Bind: " + key) - } - typ := value.Type() - // Get real type if dest define with interface{}. - // e.g var dest interface{} dest=1.0 - if value.Kind() == reflect.Interface { - typ = value.Elem().Type() - } - rv := input.bind(key, typ) - if !rv.IsValid() { - return errors.New("beego: reflect value is empty") - } - value.Set(rv) - return nil -} - -func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value { - if input.Context.Request.Form == nil { - input.Context.Request.ParseForm() - } - rv := reflect.Zero(typ) - switch typ.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - val := input.Query(key) - if len(val) == 0 { - return rv - } - rv = input.bindInt(val, typ) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - val := input.Query(key) - if len(val) == 0 { - return rv - } - rv = input.bindUint(val, typ) - case reflect.Float32, reflect.Float64: - val := input.Query(key) - if len(val) == 0 { - return rv - } - rv = input.bindFloat(val, typ) - case reflect.String: - val := input.Query(key) - if len(val) == 0 { - return rv - } - rv = input.bindString(val, typ) - case reflect.Bool: - val := input.Query(key) - if len(val) == 0 { - return rv - } - rv = input.bindBool(val, typ) - case reflect.Slice: - rv = input.bindSlice(&input.Context.Request.Form, key, typ) - case reflect.Struct: - rv = input.bindStruct(&input.Context.Request.Form, key, typ) - case reflect.Ptr: - rv = input.bindPoint(key, typ) - case reflect.Map: - rv = input.bindMap(&input.Context.Request.Form, key, typ) - } - return rv -} - -func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value { - rv := reflect.Zero(typ) - switch typ.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - rv = input.bindInt(val, typ) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - rv = input.bindUint(val, typ) - case reflect.Float32, reflect.Float64: - rv = input.bindFloat(val, typ) - case reflect.String: - rv = input.bindString(val, typ) - case reflect.Bool: - rv = input.bindBool(val, typ) - case reflect.Slice: - rv = input.bindSlice(&url.Values{"": {val}}, "", typ) - case reflect.Struct: - rv = input.bindStruct(&url.Values{"": {val}}, "", typ) - case reflect.Ptr: - rv = input.bindPoint(val, typ) - case reflect.Map: - rv = input.bindMap(&url.Values{"": {val}}, "", typ) - } - return rv -} - -func (input *BeegoInput) bindInt(val string, typ reflect.Type) reflect.Value { - intValue, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return reflect.Zero(typ) - } - pValue := reflect.New(typ) - pValue.Elem().SetInt(intValue) - return pValue.Elem() -} - -func (input *BeegoInput) bindUint(val string, typ reflect.Type) reflect.Value { - uintValue, err := strconv.ParseUint(val, 10, 64) - if err != nil { - return reflect.Zero(typ) - } - pValue := reflect.New(typ) - pValue.Elem().SetUint(uintValue) - return pValue.Elem() -} - -func (input *BeegoInput) bindFloat(val string, typ reflect.Type) reflect.Value { - floatValue, err := strconv.ParseFloat(val, 64) - if err != nil { - return reflect.Zero(typ) - } - pValue := reflect.New(typ) - pValue.Elem().SetFloat(floatValue) - return pValue.Elem() -} - -func (input *BeegoInput) bindString(val string, typ reflect.Type) reflect.Value { - return reflect.ValueOf(val) -} - -func (input *BeegoInput) bindBool(val string, typ reflect.Type) reflect.Value { - val = strings.TrimSpace(strings.ToLower(val)) - switch val { - case "true", "on", "1": - return reflect.ValueOf(true) - } - return reflect.ValueOf(false) -} - -type sliceValue struct { - index int // Index extracted from brackets. If -1, no index was provided. - value reflect.Value // the bound value for this slice element. -} - -func (input *BeegoInput) bindSlice(params *url.Values, key string, typ reflect.Type) reflect.Value { - maxIndex := -1 - numNoIndex := 0 - sliceValues := []sliceValue{} - for reqKey, vals := range *params { - if !strings.HasPrefix(reqKey, key+"[") { - continue - } - // Extract the index, and the index where a sub-key starts. (e.g. field[0].subkey) - index := -1 - leftBracket, rightBracket := len(key), strings.Index(reqKey[len(key):], "]")+len(key) - if rightBracket > leftBracket+1 { - index, _ = strconv.Atoi(reqKey[leftBracket+1 : rightBracket]) - } - subKeyIndex := rightBracket + 1 - - // Handle the indexed case. - if index > -1 { - if index > maxIndex { - maxIndex = index - } - sliceValues = append(sliceValues, sliceValue{ - index: index, - value: input.bind(reqKey[:subKeyIndex], typ.Elem()), - }) - continue - } - - // It's an un-indexed element. (e.g. element[]) - numNoIndex += len(vals) - for _, val := range vals { - // Unindexed values can only be direct-bound. - sliceValues = append(sliceValues, sliceValue{ - index: -1, - value: input.bindValue(val, typ.Elem()), - }) - } - } - resultArray := reflect.MakeSlice(typ, maxIndex+1, maxIndex+1+numNoIndex) - for _, sv := range sliceValues { - if sv.index != -1 { - resultArray.Index(sv.index).Set(sv.value) - } else { - resultArray = reflect.Append(resultArray, sv.value) - } - } - return resultArray -} - -func (input *BeegoInput) bindStruct(params *url.Values, key string, typ reflect.Type) reflect.Value { - result := reflect.New(typ).Elem() - fieldValues := make(map[string]reflect.Value) - for reqKey, val := range *params { - var fieldName string - if strings.HasPrefix(reqKey, key+".") { - fieldName = reqKey[len(key)+1:] - } else if strings.HasPrefix(reqKey, key+"[") && reqKey[len(reqKey)-1] == ']' { - fieldName = reqKey[len(key)+1 : len(reqKey)-1] - } else { - continue - } - - if _, ok := fieldValues[fieldName]; !ok { - // Time to bind this field. Get it and make sure we can set it. - fieldValue := result.FieldByName(fieldName) - if !fieldValue.IsValid() { - continue - } - if !fieldValue.CanSet() { - continue - } - boundVal := input.bindValue(val[0], fieldValue.Type()) - fieldValue.Set(boundVal) - fieldValues[fieldName] = boundVal - } - } - - return result -} - -func (input *BeegoInput) bindPoint(key string, typ reflect.Type) reflect.Value { - return input.bind(key, typ.Elem()).Addr() -} - -func (input *BeegoInput) bindMap(params *url.Values, key string, typ reflect.Type) reflect.Value { - var ( - result = reflect.MakeMap(typ) - keyType = typ.Key() - valueType = typ.Elem() - ) - for paramName, values := range *params { - if !strings.HasPrefix(paramName, key+"[") || paramName[len(paramName)-1] != ']' { - continue - } - - key := paramName[len(key)+1 : len(paramName)-1] - result.SetMapIndex(input.bindValue(key, keyType), input.bindValue(values[0], valueType)) - } - return result -} diff --git a/context/input_test.go b/context/input_test.go deleted file mode 100644 index 3a6c2e7b..00000000 --- a/context/input_test.go +++ /dev/null @@ -1,217 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import ( - "net/http" - "net/http/httptest" - "reflect" - "testing" -) - -func TestBind(t *testing.T) { - type testItem struct { - field string - empty interface{} - want interface{} - } - type Human struct { - ID int - Nick string - Pwd string - Ms bool - } - - cases := []struct { - request string - valueGp []testItem - }{ - {"/?p=str", []testItem{{"p", interface{}(""), interface{}("str")}}}, - - {"/?p=", []testItem{{"p", "", ""}}}, - {"/?p=str", []testItem{{"p", "", "str"}}}, - - {"/?p=123", []testItem{{"p", 0, 123}}}, - {"/?p=123", []testItem{{"p", uint(0), uint(123)}}}, - - {"/?p=1.0", []testItem{{"p", 0.0, 1.0}}}, - {"/?p=1", []testItem{{"p", false, true}}}, - - {"/?p=true", []testItem{{"p", false, true}}}, - {"/?p=ON", []testItem{{"p", false, true}}}, - {"/?p=on", []testItem{{"p", false, true}}}, - {"/?p=1", []testItem{{"p", false, true}}}, - {"/?p=2", []testItem{{"p", false, false}}}, - {"/?p=false", []testItem{{"p", false, false}}}, - - {"/?p[a]=1&p[b]=2&p[c]=3", []testItem{{"p", map[string]int{}, map[string]int{"a": 1, "b": 2, "c": 3}}}}, - {"/?p[a]=v1&p[b]=v2&p[c]=v3", []testItem{{"p", map[string]string{}, map[string]string{"a": "v1", "b": "v2", "c": "v3"}}}}, - - {"/?p[]=8&p[]=9&p[]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}}, - {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}}, - {"/?p[0]=8&p[1]=9&p[2]=10&p[5]=14", []testItem{{"p", []int{}, []int{8, 9, 10, 0, 0, 14}}}}, - {"/?p[0]=8.0&p[1]=9.0&p[2]=10.0", []testItem{{"p", []float64{}, []float64{8.0, 9.0, 10.0}}}}, - - {"/?p[]=10&p[]=9&p[]=8", []testItem{{"p", []string{}, []string{"10", "9", "8"}}}}, - {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []string{}, []string{"8", "9", "10"}}}}, - - {"/?p[0]=true&p[1]=false&p[2]=true&p[5]=1&p[6]=ON&p[7]=other", []testItem{{"p", []bool{}, []bool{true, false, true, false, false, true, true, false}}}}, - - {"/?human.Nick=astaxie", []testItem{{"human", Human{}, Human{Nick: "astaxie"}}}}, - {"/?human.ID=888&human.Nick=astaxie&human.Ms=true&human[Pwd]=pass", []testItem{{"human", Human{}, Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass"}}}}, - {"/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02", - []testItem{{"human", []Human{}, []Human{ - {ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"}, - {ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"}, - }}}}, - - { - "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&human.Nick=astaxie", - []testItem{ - {"id", 0, 123}, - {"isok", false, true}, - {"ft", 0.0, 1.2}, - {"ol", []int{}, []int{1, 2}}, - {"ul", []string{}, []string{"str", "array"}}, - {"human", Human{}, Human{Nick: "astaxie"}}, - }, - }, - } - for _, c := range cases { - r, _ := http.NewRequest("GET", c.request, nil) - beegoInput := NewInput() - beegoInput.Context = NewContext() - beegoInput.Context.Reset(httptest.NewRecorder(), r) - - for _, item := range c.valueGp { - got := item.empty - err := beegoInput.Bind(&got, item.field) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, item.want) { - t.Fatalf("Bind %q error,should be:\n%#v \ngot:\n%#v", item.field, item.want, got) - } - } - - } -} - -func TestSubDomain(t *testing.T) { - r, _ := http.NewRequest("GET", "http://www.example.com/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil) - beegoInput := NewInput() - beegoInput.Context = NewContext() - beegoInput.Context.Reset(httptest.NewRecorder(), r) - - subdomain := beegoInput.SubDomains() - if subdomain != "www" { - t.Fatal("Subdomain parse error, got" + subdomain) - } - - r, _ = http.NewRequest("GET", "http://localhost/", nil) - beegoInput.Context.Request = r - if beegoInput.SubDomains() != "" { - t.Fatal("Subdomain parse error, should be empty, got " + beegoInput.SubDomains()) - } - - r, _ = http.NewRequest("GET", "http://aa.bb.example.com/", nil) - beegoInput.Context.Request = r - if beegoInput.SubDomains() != "aa.bb" { - t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) - } - - /* TODO Fix this - r, _ = http.NewRequest("GET", "http://127.0.0.1/", nil) - beegoInput.Context.Request = r - if beegoInput.SubDomains() != "" { - t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) - } - */ - - r, _ = http.NewRequest("GET", "http://example.com/", nil) - beegoInput.Context.Request = r - if beegoInput.SubDomains() != "" { - t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) - } - - r, _ = http.NewRequest("GET", "http://aa.bb.cc.dd.example.com/", nil) - beegoInput.Context.Request = r - if beegoInput.SubDomains() != "aa.bb.cc.dd" { - t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) - } -} - -func TestParams(t *testing.T) { - inp := NewInput() - - inp.SetParam("p1", "val1_ver1") - inp.SetParam("p2", "val2_ver1") - inp.SetParam("p3", "val3_ver1") - if l := inp.ParamsLen(); l != 3 { - t.Fatalf("Input.ParamsLen wrong value: %d, expected %d", l, 3) - } - - if val := inp.Param("p1"); val != "val1_ver1" { - t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver1") - } - if val := inp.Param("p3"); val != "val3_ver1" { - t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val3_ver1") - } - vals := inp.Params() - expected := map[string]string{ - "p1": "val1_ver1", - "p2": "val2_ver1", - "p3": "val3_ver1", - } - if !reflect.DeepEqual(vals, expected) { - t.Fatalf("Input.Params wrong value: %s, expected %s", vals, expected) - } - - // overwriting existing params - inp.SetParam("p1", "val1_ver2") - inp.SetParam("p2", "val2_ver2") - expected = map[string]string{ - "p1": "val1_ver2", - "p2": "val2_ver2", - "p3": "val3_ver1", - } - vals = inp.Params() - if !reflect.DeepEqual(vals, expected) { - t.Fatalf("Input.Params wrong value: %s, expected %s", vals, expected) - } - - if l := inp.ParamsLen(); l != 3 { - t.Fatalf("Input.ParamsLen wrong value: %d, expected %d", l, 3) - } - - if val := inp.Param("p1"); val != "val1_ver2" { - t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2") - } - - if val := inp.Param("p2"); val != "val2_ver2" { - t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2") - } - -} -func BenchmarkQuery(b *testing.B) { - beegoInput := NewInput() - beegoInput.Context = NewContext() - beegoInput.Context.Request, _ = http.NewRequest("POST", "http://www.example.com/?q=foo", nil) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - beegoInput.Query("q") - } - }) -} diff --git a/context/output.go b/context/output.go deleted file mode 100644 index 7409e4e5..00000000 --- a/context/output.go +++ /dev/null @@ -1,413 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import ( - "bytes" - "encoding/json" - "encoding/xml" - "errors" - "fmt" - "html/template" - "io" - "mime" - "net/http" - "net/url" - "os" - "path/filepath" - "strconv" - "strings" - "time" - - yaml "gopkg.in/yaml.v2" -) - -// BeegoOutput does work for sending response header. -type BeegoOutput struct { - Context *Context - Status int - EnableGzip bool -} - -// NewOutput returns new BeegoOutput. -// it contains nothing now. -func NewOutput() *BeegoOutput { - return &BeegoOutput{} -} - -// Reset init BeegoOutput -func (output *BeegoOutput) Reset(ctx *Context) { - output.Context = ctx - output.Clear() -} - -// Clear: clear all data in output -func (output *BeegoOutput) Clear() { - output.Status = 0 -} - -// Header sets response header item string via given key. -func (output *BeegoOutput) Header(key, val string) { - output.Context.ResponseWriter.Header().Set(key, val) -} - -// Body sets response body content. -// if EnableGzip, compress content string. -// it sends out response body directly. -func (output *BeegoOutput) Body(content []byte) error { - var encoding string - var buf = &bytes.Buffer{} - if output.EnableGzip { - encoding = ParseEncoding(output.Context.Request) - } - if b, n, _ := WriteBody(encoding, buf, content); b { - output.Header("Content-Encoding", n) - output.Header("Content-Length", strconv.Itoa(buf.Len())) - } else { - output.Header("Content-Length", strconv.Itoa(len(content))) - } - // Write status code if it has been set manually - // Set it to 0 afterwards to prevent "multiple response.WriteHeader calls" - if output.Status != 0 { - output.Context.ResponseWriter.WriteHeader(output.Status) - output.Status = 0 - } else { - output.Context.ResponseWriter.Started = true - } - io.Copy(output.Context.ResponseWriter, buf) - return nil -} - -// Cookie sets cookie value via given key. -// others are ordered as cookie's max age time, path,domain, secure and httponly. -func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) { - var b bytes.Buffer - fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value)) - - //fix cookie not work in IE - if len(others) > 0 { - var maxAge int64 - - switch v := others[0].(type) { - case int: - maxAge = int64(v) - case int32: - maxAge = int64(v) - case int64: - maxAge = v - } - - switch { - case maxAge > 0: - fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(maxAge)*time.Second).UTC().Format(time.RFC1123), maxAge) - case maxAge < 0: - fmt.Fprintf(&b, "; Max-Age=0") - } - } - - // the settings below - // Path, Domain, Secure, HttpOnly - // can use nil skip set - - // default "/" - if len(others) > 1 { - if v, ok := others[1].(string); ok && len(v) > 0 { - fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v)) - } - } else { - fmt.Fprintf(&b, "; Path=%s", "/") - } - - // default empty - if len(others) > 2 { - if v, ok := others[2].(string); ok && len(v) > 0 { - fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(v)) - } - } - - // default empty - if len(others) > 3 { - var secure bool - switch v := others[3].(type) { - case bool: - secure = v - default: - if others[3] != nil { - secure = true - } - } - if secure { - fmt.Fprintf(&b, "; Secure") - } - } - - // default false. for session cookie default true - if len(others) > 4 { - if v, ok := others[4].(bool); ok && v { - fmt.Fprintf(&b, "; HttpOnly") - } - } - - output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String()) -} - -var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") - -func sanitizeName(n string) string { - return cookieNameSanitizer.Replace(n) -} - -var cookieValueSanitizer = strings.NewReplacer("\n", " ", "\r", " ", ";", " ") - -func sanitizeValue(v string) string { - return cookieValueSanitizer.Replace(v) -} - -func jsonRenderer(value interface{}) Renderer { - return rendererFunc(func(ctx *Context) { - ctx.Output.JSON(value, false, false) - }) -} - -func errorRenderer(err error) Renderer { - return rendererFunc(func(ctx *Context) { - ctx.Output.SetStatus(500) - ctx.Output.Body([]byte(err.Error())) - }) -} - -// JSON writes json to response body. -// if encoding is true, it converts utf-8 to \u0000 type. -func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error { - output.Header("Content-Type", "application/json; charset=utf-8") - var content []byte - var err error - if hasIndent { - content, err = json.MarshalIndent(data, "", " ") - } else { - content, err = json.Marshal(data) - } - if err != nil { - http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) - return err - } - if encoding { - content = []byte(stringsToJSON(string(content))) - } - return output.Body(content) -} - -// YAML writes yaml to response body. -func (output *BeegoOutput) YAML(data interface{}) error { - output.Header("Content-Type", "application/x-yaml; charset=utf-8") - var content []byte - var err error - content, err = yaml.Marshal(data) - if err != nil { - http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) - return err - } - return output.Body(content) -} - -// JSONP writes jsonp to response body. -func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { - output.Header("Content-Type", "application/javascript; charset=utf-8") - var content []byte - var err error - if hasIndent { - content, err = json.MarshalIndent(data, "", " ") - } else { - content, err = json.Marshal(data) - } - if err != nil { - http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) - return err - } - callback := output.Context.Input.Query("callback") - if callback == "" { - return errors.New(`"callback" parameter required`) - } - callback = template.JSEscapeString(callback) - callbackContent := bytes.NewBufferString(" if(window." + callback + ")" + callback) - callbackContent.WriteString("(") - callbackContent.Write(content) - callbackContent.WriteString(");\r\n") - return output.Body(callbackContent.Bytes()) -} - -// XML writes xml string to response body. -func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { - output.Header("Content-Type", "application/xml; charset=utf-8") - var content []byte - var err error - if hasIndent { - content, err = xml.MarshalIndent(data, "", " ") - } else { - content, err = xml.Marshal(data) - } - if err != nil { - http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) - return err - } - return output.Body(content) -} - -// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header -func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) { - accept := output.Context.Input.Header("Accept") - switch accept { - case ApplicationYAML: - output.YAML(data) - case ApplicationXML, TextXML: - output.XML(data, hasIndent) - default: - output.JSON(data, hasIndent, len(hasEncode) > 0 && hasEncode[0]) - } -} - -// Download forces response for download file. -// it prepares the download response header automatically. -func (output *BeegoOutput) Download(file string, filename ...string) { - // check get file error, file not found or other error. - if _, err := os.Stat(file); err != nil { - http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) - return - } - - var fName string - if len(filename) > 0 && filename[0] != "" { - fName = filename[0] - } else { - fName = filepath.Base(file) - } - //https://tools.ietf.org/html/rfc6266#section-4.3 - fn := url.PathEscape(fName) - if fName == fn { - fn = "filename=" + fn - } else { - /** - The parameters "filename" and "filename*" differ only in that - "filename*" uses the encoding defined in [RFC5987], allowing the use - of characters not present in the ISO-8859-1 character set - ([ISO-8859-1]). - */ - fn = "filename=" + fName + "; filename*=utf-8''" + fn - } - output.Header("Content-Disposition", "attachment; "+fn) - output.Header("Content-Description", "File Transfer") - output.Header("Content-Type", "application/octet-stream") - output.Header("Content-Transfer-Encoding", "binary") - output.Header("Expires", "0") - output.Header("Cache-Control", "must-revalidate") - output.Header("Pragma", "public") - http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) -} - -// ContentType sets the content type from ext string. -// MIME type is given in mime package. -func (output *BeegoOutput) ContentType(ext string) { - if !strings.HasPrefix(ext, ".") { - ext = "." + ext - } - ctype := mime.TypeByExtension(ext) - if ctype != "" { - output.Header("Content-Type", ctype) - } -} - -// SetStatus sets response status code. -// It writes response header directly. -func (output *BeegoOutput) SetStatus(status int) { - output.Status = status -} - -// IsCachable returns boolean of this request is cached. -// HTTP 304 means cached. -func (output *BeegoOutput) IsCachable() bool { - return output.Status >= 200 && output.Status < 300 || output.Status == 304 -} - -// IsEmpty returns boolean of this request is empty. -// HTTP 201,204 and 304 means empty. -func (output *BeegoOutput) IsEmpty() bool { - return output.Status == 201 || output.Status == 204 || output.Status == 304 -} - -// IsOk returns boolean of this request runs well. -// HTTP 200 means ok. -func (output *BeegoOutput) IsOk() bool { - return output.Status == 200 -} - -// IsSuccessful returns boolean of this request runs successfully. -// HTTP 2xx means ok. -func (output *BeegoOutput) IsSuccessful() bool { - return output.Status >= 200 && output.Status < 300 -} - -// IsRedirect returns boolean of this request is redirection header. -// HTTP 301,302,307 means redirection. -func (output *BeegoOutput) IsRedirect() bool { - return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307 -} - -// IsForbidden returns boolean of this request is forbidden. -// HTTP 403 means forbidden. -func (output *BeegoOutput) IsForbidden() bool { - return output.Status == 403 -} - -// IsNotFound returns boolean of this request is not found. -// HTTP 404 means not found. -func (output *BeegoOutput) IsNotFound() bool { - return output.Status == 404 -} - -// IsClientError returns boolean of this request client sends error data. -// HTTP 4xx means client error. -func (output *BeegoOutput) IsClientError() bool { - return output.Status >= 400 && output.Status < 500 -} - -// IsServerError returns boolean of this server handler errors. -// HTTP 5xx means server internal error. -func (output *BeegoOutput) IsServerError() bool { - return output.Status >= 500 && output.Status < 600 -} - -func stringsToJSON(str string) string { - var jsons bytes.Buffer - for _, r := range str { - rint := int(r) - if rint < 128 { - jsons.WriteRune(r) - } else { - jsons.WriteString("\\u") - if rint < 0x100 { - jsons.WriteString("00") - } else if rint < 0x1000 { - jsons.WriteString("0") - } - jsons.WriteString(strconv.FormatInt(int64(rint), 16)) - } - } - return jsons.String() -} - -// Session sets session item value with given key. -func (output *BeegoOutput) Session(name interface{}, value interface{}) { - output.Context.Input.CruSession.Set(name, value) -} diff --git a/context/param/conv.go b/context/param/conv.go deleted file mode 100644 index c200e008..00000000 --- a/context/param/conv.go +++ /dev/null @@ -1,78 +0,0 @@ -package param - -import ( - "fmt" - "reflect" - - beecontext "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" -) - -// ConvertParams converts http method params to values that will be passed to the method controller as arguments -func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) { - result = make([]reflect.Value, 0, len(methodParams)) - for i := 0; i < len(methodParams); i++ { - reflectValue := convertParam(methodParams[i], methodType.In(i), ctx) - result = append(result, reflectValue) - } - return -} - -func convertParam(param *MethodParam, paramType reflect.Type, ctx *beecontext.Context) (result reflect.Value) { - paramValue := getParamValue(param, ctx) - if paramValue == "" { - if param.required { - ctx.Abort(400, fmt.Sprintf("Missing parameter %s", param.name)) - } else { - paramValue = param.defaultValue - } - } - - reflectValue, err := parseValue(param, paramValue, paramType) - if err != nil { - logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %v, Error: %s", param.name, paramType, paramValue, err)) - ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %v to type %s", param.name, paramValue, paramType)) - } - - return reflectValue -} - -func getParamValue(param *MethodParam, ctx *beecontext.Context) string { - switch param.in { - case body: - return string(ctx.Input.RequestBody) - case header: - return ctx.Input.Header(param.name) - case path: - return ctx.Input.Query(":" + param.name) - default: - return ctx.Input.Query(param.name) - } -} - -func parseValue(param *MethodParam, paramValue string, paramType reflect.Type) (result reflect.Value, err error) { - if paramValue == "" { - return reflect.Zero(paramType), nil - } - parser := getParser(param, paramType) - value, err := parser.parse(paramValue, paramType) - if err != nil { - return result, err - } - - return safeConvert(reflect.ValueOf(value), paramType) -} - -func safeConvert(value reflect.Value, t reflect.Type) (result reflect.Value, err error) { - defer func() { - if r := recover(); r != nil { - var ok bool - err, ok = r.(error) - if !ok { - err = fmt.Errorf("%v", r) - } - } - }() - result = value.Convert(t) - return -} diff --git a/context/param/methodparams.go b/context/param/methodparams.go deleted file mode 100644 index cd6708a2..00000000 --- a/context/param/methodparams.go +++ /dev/null @@ -1,69 +0,0 @@ -package param - -import ( - "fmt" - "strings" -) - -//MethodParam keeps param information to be auto passed to controller methods -type MethodParam struct { - name string - in paramType - required bool - defaultValue string -} - -type paramType byte - -const ( - param paramType = iota - path - body - header -) - -//New creates a new MethodParam with name and specific options -func New(name string, opts ...MethodParamOption) *MethodParam { - return newParam(name, nil, opts) -} - -func newParam(name string, parser paramParser, opts []MethodParamOption) (param *MethodParam) { - param = &MethodParam{name: name} - for _, option := range opts { - option(param) - } - return -} - -//Make creates an array of MethodParmas or an empty array -func Make(list ...*MethodParam) []*MethodParam { - if len(list) > 0 { - return list - } - return nil -} - -func (mp *MethodParam) String() string { - options := []string{} - result := "param.New(\"" + mp.name + "\"" - if mp.required { - options = append(options, "param.IsRequired") - } - switch mp.in { - case path: - options = append(options, "param.InPath") - case body: - options = append(options, "param.InBody") - case header: - options = append(options, "param.InHeader") - } - if mp.defaultValue != "" { - options = append(options, fmt.Sprintf(`param.Default("%s")`, mp.defaultValue)) - } - if len(options) > 0 { - result += ", " - } - result += strings.Join(options, ", ") - result += ")" - return result -} diff --git a/context/param/options.go b/context/param/options.go deleted file mode 100644 index 3d5ba013..00000000 --- a/context/param/options.go +++ /dev/null @@ -1,37 +0,0 @@ -package param - -import ( - "fmt" -) - -// MethodParamOption defines a func which apply options on a MethodParam -type MethodParamOption func(*MethodParam) - -// IsRequired indicates that this param is required and can not be omitted from the http request -var IsRequired MethodParamOption = func(p *MethodParam) { - p.required = true -} - -// InHeader indicates that this param is passed via an http header -var InHeader MethodParamOption = func(p *MethodParam) { - p.in = header -} - -// InPath indicates that this param is part of the URL path -var InPath MethodParamOption = func(p *MethodParam) { - p.in = path -} - -// InBody indicates that this param is passed as an http request body -var InBody MethodParamOption = func(p *MethodParam) { - p.in = body -} - -// Default provides a default value for the http param -func Default(defaultValue interface{}) MethodParamOption { - return func(p *MethodParam) { - if defaultValue != nil { - p.defaultValue = fmt.Sprint(defaultValue) - } - } -} diff --git a/context/param/parsers.go b/context/param/parsers.go deleted file mode 100644 index 421aecf0..00000000 --- a/context/param/parsers.go +++ /dev/null @@ -1,149 +0,0 @@ -package param - -import ( - "encoding/json" - "reflect" - "strconv" - "strings" - "time" -) - -type paramParser interface { - parse(value string, toType reflect.Type) (interface{}, error) -} - -func getParser(param *MethodParam, t reflect.Type) paramParser { - switch t.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return intParser{} - case reflect.Slice: - if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string - return stringParser{} - } - if param.in == body { - return jsonParser{} - } - elemParser := getParser(param, t.Elem()) - if elemParser == (jsonParser{}) { - return elemParser - } - return sliceParser(elemParser) - case reflect.Bool: - return boolParser{} - case reflect.String: - return stringParser{} - case reflect.Float32, reflect.Float64: - return floatParser{} - case reflect.Ptr: - elemParser := getParser(param, t.Elem()) - if elemParser == (jsonParser{}) { - return elemParser - } - return ptrParser(elemParser) - default: - if t.PkgPath() == "time" && t.Name() == "Time" { - return timeParser{} - } - return jsonParser{} - } -} - -type parserFunc func(value string, toType reflect.Type) (interface{}, error) - -func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error) { - return f(value, toType) -} - -type boolParser struct { -} - -func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) { - return strconv.ParseBool(value) -} - -type stringParser struct { -} - -func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) { - return value, nil -} - -type intParser struct { -} - -func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) { - return strconv.Atoi(value) -} - -type floatParser struct { -} - -func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) { - if toType.Kind() == reflect.Float32 { - res, err := strconv.ParseFloat(value, 32) - if err != nil { - return nil, err - } - return float32(res), nil - } - return strconv.ParseFloat(value, 64) -} - -type timeParser struct { -} - -func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) { - result, err = time.Parse(time.RFC3339, value) - if err != nil { - result, err = time.Parse("2006-01-02", value) - } - return -} - -type jsonParser struct { -} - -func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) { - pResult := reflect.New(toType) - v := pResult.Interface() - err := json.Unmarshal([]byte(value), v) - if err != nil { - return nil, err - } - return pResult.Elem().Interface(), nil -} - -func sliceParser(elemParser paramParser) paramParser { - return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { - values := strings.Split(value, ",") - result := reflect.MakeSlice(toType, 0, len(values)) - elemType := toType.Elem() - for _, v := range values { - parsedValue, err := elemParser.parse(v, elemType) - if err != nil { - return nil, err - } - result = reflect.Append(result, reflect.ValueOf(parsedValue)) - } - return result.Interface(), nil - }) -} - -func ptrParser(elemParser paramParser) paramParser { - return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { - parsedValue, err := elemParser.parse(value, toType.Elem()) - if err != nil { - return nil, err - } - newValPtr := reflect.New(toType.Elem()) - newVal := reflect.Indirect(newValPtr) - convertedVal, err := safeConvert(reflect.ValueOf(parsedValue), toType.Elem()) - if err != nil { - return nil, err - } - - newVal.Set(convertedVal) - return newValPtr.Interface(), nil - }) -} diff --git a/context/param/parsers_test.go b/context/param/parsers_test.go deleted file mode 100644 index 7065a28e..00000000 --- a/context/param/parsers_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package param - -import "testing" -import "reflect" -import "time" - -type testDefinition struct { - strValue string - expectedValue interface{} - expectedParser paramParser -} - -func Test_Parsers(t *testing.T) { - - //ints - checkParser(testDefinition{"1", 1, intParser{}}, t) - checkParser(testDefinition{"-1", int64(-1), intParser{}}, t) - checkParser(testDefinition{"1", uint64(1), intParser{}}, t) - - //floats - checkParser(testDefinition{"1.0", float32(1.0), floatParser{}}, t) - checkParser(testDefinition{"-1.0", float64(-1.0), floatParser{}}, t) - - //strings - checkParser(testDefinition{"AB", "AB", stringParser{}}, t) - checkParser(testDefinition{"AB", []byte{65, 66}, stringParser{}}, t) - - //bools - checkParser(testDefinition{"true", true, boolParser{}}, t) - checkParser(testDefinition{"0", false, boolParser{}}, t) - - //timeParser - checkParser(testDefinition{"2017-05-30T13:54:53Z", time.Date(2017, 5, 30, 13, 54, 53, 0, time.UTC), timeParser{}}, t) - checkParser(testDefinition{"2017-05-30", time.Date(2017, 5, 30, 0, 0, 0, 0, time.UTC), timeParser{}}, t) - - //json - checkParser(testDefinition{`{"X": 5, "Y":"Z"}`, struct { - X int - Y string - }{5, "Z"}, jsonParser{}}, t) - - //slice in query is parsed as comma delimited - checkParser(testDefinition{`1,2`, []int{1, 2}, sliceParser(intParser{})}, t) - - //slice in body is parsed as json - checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body}) - - //pointers - var someInt = 1 - checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t) - - var someStruct = struct{ X int }{5} - checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t) - -} - -func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) { - toType := reflect.TypeOf(def.expectedValue) - var mp MethodParam - if len(methodParam) == 0 { - mp = MethodParam{} - } else { - mp = methodParam[0] - } - parser := getParser(&mp, toType) - - if reflect.TypeOf(parser) != reflect.TypeOf(def.expectedParser) { - t.Errorf("Invalid parser for value %v. Expected: %v, actual: %v", def.strValue, reflect.TypeOf(def.expectedParser).Name(), reflect.TypeOf(parser).Name()) - return - } - result, err := parser.parse(def.strValue, toType) - if err != nil { - t.Errorf("Parsing error for value %v. Expected result: %v, error: %v", def.strValue, def.expectedValue, err) - return - } - convResult, err := safeConvert(reflect.ValueOf(result), toType) - if err != nil { - t.Errorf("Conversion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err) - return - } - if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) { - t.Errorf("Parsing error for value %v. Expected result: %v, actual: %v", def.strValue, def.expectedValue, result) - } -} diff --git a/context/renderer.go b/context/renderer.go deleted file mode 100644 index 36a7cb53..00000000 --- a/context/renderer.go +++ /dev/null @@ -1,12 +0,0 @@ -package context - -// Renderer defines an http response renderer -type Renderer interface { - Render(ctx *Context) -} - -type rendererFunc func(ctx *Context) - -func (f rendererFunc) Render(ctx *Context) { - f(ctx) -} diff --git a/context/response.go b/context/response.go deleted file mode 100644 index 9c3c715a..00000000 --- a/context/response.go +++ /dev/null @@ -1,27 +0,0 @@ -package context - -import ( - "strconv" - - "net/http" -) - -const ( - //BadRequest indicates http error 400 - BadRequest StatusCode = http.StatusBadRequest - - //NotFound indicates http error 404 - NotFound StatusCode = http.StatusNotFound -) - -// StatusCode sets the http response status code -type StatusCode int - -func (s StatusCode) Error() string { - return strconv.Itoa(int(s)) -} - -// Render sets the http status code -func (s StatusCode) Render(ctx *Context) { - ctx.Output.SetStatus(int(s)) -} diff --git a/controller.go b/controller.go deleted file mode 100644 index 0b4a79a8..00000000 --- a/controller.go +++ /dev/null @@ -1,774 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "bytes" - "errors" - "fmt" - "html/template" - "io" - "mime/multipart" - "net/http" - "net/url" - "os" - "reflect" - "strconv" - "strings" - - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/context/param" - "github.com/astaxie/beego/session" -) - -var ( - // ErrAbort custom error when user stop request handler manually. - // Deprecated: using pkg/, we will delete this in v2.1.0 - ErrAbort = errors.New("user stop run") - // GlobalControllerRouter store comments with controller. pkgpath+controller:comments - // Deprecated: using pkg/, we will delete this in v2.1.0 - GlobalControllerRouter = make(map[string][]ControllerComments) -) - -// ControllerFilter store the filter for controller -// Deprecated: using pkg/, we will delete this in v2.1.0 -type ControllerFilter struct { - Pattern string - Pos int - Filter FilterFunc - ReturnOnOutput bool - ResetParams bool -} - -// ControllerFilterComments store the comment for controller level filter -// Deprecated: using pkg/, we will delete this in v2.1.0 -type ControllerFilterComments struct { - Pattern string - Pos int - Filter string // NOQA - ReturnOnOutput bool - ResetParams bool -} - -// ControllerImportComments store the import comment for controller needed -// Deprecated: using pkg/, we will delete this in v2.1.0 -type ControllerImportComments struct { - ImportPath string - ImportAlias string -} - -// ControllerComments store the comment for the controller method -// Deprecated: using pkg/, we will delete this in v2.1.0 -type ControllerComments struct { - Method string - Router string - Filters []*ControllerFilter - ImportComments []*ControllerImportComments - FilterComments []*ControllerFilterComments - AllowHTTPMethods []string - Params []map[string]string - MethodParams []*param.MethodParam -} - -// ControllerCommentsSlice implements the sort interface -// Deprecated: using pkg/, we will delete this in v2.1.0 -type ControllerCommentsSlice []ControllerComments - -func (p ControllerCommentsSlice) Len() int { return len(p) } -func (p ControllerCommentsSlice) Less(i, j int) bool { return p[i].Router < p[j].Router } -func (p ControllerCommentsSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } - -// Controller defines some basic http request handler operations, such as -// http context, template and view, session and xsrf. -// Deprecated: using pkg/, we will delete this in v2.1.0 -type Controller struct { - // context data - Ctx *context.Context - Data map[interface{}]interface{} - - // route controller info - controllerName string - actionName string - methodMapping map[string]func() //method:routertree - AppController interface{} - - // template data - TplName string - ViewPath string - Layout string - LayoutSections map[string]string // the key is the section name and the value is the template name - TplPrefix string - TplExt string - EnableRender bool - - // xsrf data - _xsrfToken string - XSRFExpire int - EnableXSRF bool - - // session - CruSession session.Store -} - -// ControllerInterface is an interface to uniform all controller handler. -// Deprecated: using pkg/, we will delete this in v2.1.0 -type ControllerInterface interface { - Init(ct *context.Context, controllerName, actionName string, app interface{}) - Prepare() - Get() - Post() - Delete() - Put() - Head() - Patch() - Options() - Trace() - Finish() - Render() error - XSRFToken() string - CheckXSRFCookie() bool - HandlerFunc(fn string) bool - URLMapping() -} - -// Init generates default values of controller operations. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { - c.Layout = "" - c.TplName = "" - c.controllerName = controllerName - c.actionName = actionName - c.Ctx = ctx - c.TplExt = "tpl" - c.AppController = app - c.EnableRender = true - c.EnableXSRF = true - c.Data = ctx.Input.Data() - c.methodMapping = make(map[string]func()) -} - -// Prepare runs after Init before request function execution. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Prepare() {} - -// Finish runs after request function execution. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Finish() {} - -// Get adds a request function to handle GET request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Get() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) -} - -// Post adds a request function to handle POST request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Post() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) -} - -// Delete adds a request function to handle DELETE request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Delete() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) -} - -// Put adds a request function to handle PUT request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Put() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) -} - -// Head adds a request function to handle HEAD request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Head() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) -} - -// Patch adds a request function to handle PATCH request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Patch() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) -} - -// Options adds a request function to handle OPTIONS request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Options() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) -} - -// Trace adds a request function to handle Trace request. -// this method SHOULD NOT be overridden. -// https://tools.ietf.org/html/rfc7231#section-4.3.8 -// The TRACE method requests a remote, application-level loop-back of -// the request message. The final recipient of the request SHOULD -// reflect the message received, excluding some fields described below, -// back to the client as the message body of a 200 (OK) response with a -// Content-Type of "message/http" (Section 8.3.1 of [RFC7230]). -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Trace() { - ts := func(h http.Header) (hs string) { - for k, v := range h { - hs += fmt.Sprintf("\r\n%s: %s", k, v) - } - return - } - hs := fmt.Sprintf("\r\nTRACE %s %s%s\r\n", c.Ctx.Request.RequestURI, c.Ctx.Request.Proto, ts(c.Ctx.Request.Header)) - c.Ctx.Output.Header("Content-Type", "message/http") - c.Ctx.Output.Header("Content-Length", fmt.Sprint(len(hs))) - c.Ctx.Output.Header("Cache-Control", "no-cache, no-store, must-revalidate") - c.Ctx.WriteString(hs) -} - -// HandlerFunc call function with the name -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) HandlerFunc(fnname string) bool { - if v, ok := c.methodMapping[fnname]; ok { - v() - return true - } - return false -} - -// URLMapping register the internal Controller router. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) URLMapping() {} - -// Mapping the method to function -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Mapping(method string, fn func()) { - c.methodMapping[method] = fn -} - -// Render sends the response with rendered template bytes as text/html type. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Render() error { - if !c.EnableRender { - return nil - } - rb, err := c.RenderBytes() - if err != nil { - return err - } - - if c.Ctx.ResponseWriter.Header().Get("Content-Type") == "" { - c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8") - } - - return c.Ctx.Output.Body(rb) -} - -// RenderString returns the rendered template string. Do not send out response. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) RenderString() (string, error) { - b, e := c.RenderBytes() - return string(b), e -} - -// RenderBytes returns the bytes of rendered template string. Do not send out response. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) RenderBytes() ([]byte, error) { - buf, err := c.renderTemplate() - //if the controller has set layout, then first get the tplName's content set the content to the layout - if err == nil && c.Layout != "" { - c.Data["LayoutContent"] = template.HTML(buf.String()) - - if c.LayoutSections != nil { - for sectionName, sectionTpl := range c.LayoutSections { - if sectionTpl == "" { - c.Data[sectionName] = "" - continue - } - buf.Reset() - err = ExecuteViewPathTemplate(&buf, sectionTpl, c.viewPath(), c.Data) - if err != nil { - return nil, err - } - c.Data[sectionName] = template.HTML(buf.String()) - } - } - - buf.Reset() - ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath(), c.Data) - } - return buf.Bytes(), err -} - -func (c *Controller) renderTemplate() (bytes.Buffer, error) { - var buf bytes.Buffer - if c.TplName == "" { - c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt - } - if c.TplPrefix != "" { - c.TplName = c.TplPrefix + c.TplName - } - if BConfig.RunMode == DEV { - buildFiles := []string{c.TplName} - if c.Layout != "" { - buildFiles = append(buildFiles, c.Layout) - if c.LayoutSections != nil { - for _, sectionTpl := range c.LayoutSections { - if sectionTpl == "" { - continue - } - buildFiles = append(buildFiles, sectionTpl) - } - } - } - BuildTemplate(c.viewPath(), buildFiles...) - } - return buf, ExecuteViewPathTemplate(&buf, c.TplName, c.viewPath(), c.Data) -} - -func (c *Controller) viewPath() string { - if c.ViewPath == "" { - return BConfig.WebConfig.ViewsPath - } - return c.ViewPath -} - -// Redirect sends the redirection response to url with status code. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Redirect(url string, code int) { - LogAccess(c.Ctx, nil, code) - c.Ctx.Redirect(code, url) -} - -// SetData set the data depending on the accepted -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) SetData(data interface{}) { - accept := c.Ctx.Input.Header("Accept") - switch accept { - case context.ApplicationYAML: - c.Data["yaml"] = data - case context.ApplicationXML, context.TextXML: - c.Data["xml"] = data - default: - c.Data["json"] = data - } -} - -// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Abort(code string) { - status, err := strconv.Atoi(code) - if err != nil { - status = 200 - } - c.CustomAbort(status, code) -} - -// CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) CustomAbort(status int, body string) { - // first panic from ErrorMaps, it is user defined error functions. - if _, ok := ErrorMaps[body]; ok { - c.Ctx.Output.Status = status - panic(body) - } - // last panic user string - c.Ctx.ResponseWriter.WriteHeader(status) - c.Ctx.ResponseWriter.Write([]byte(body)) - panic(ErrAbort) -} - -// StopRun makes panic of USERSTOPRUN error and go to recover function if defined. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) StopRun() { - panic(ErrAbort) -} - -// URLFor does another controller handler in this request function. -// it goes to this controller method if endpoint is not clear. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) URLFor(endpoint string, values ...interface{}) string { - if len(endpoint) == 0 { - return "" - } - if endpoint[0] == '.' { - return URLFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...) - } - return URLFor(endpoint, values...) -} - -// ServeJSON sends a json response with encoding charset. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) ServeJSON(encoding ...bool) { - var ( - hasIndent = BConfig.RunMode != PROD - hasEncoding = len(encoding) > 0 && encoding[0] - ) - - c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding) -} - -// ServeJSONP sends a jsonp response. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) ServeJSONP() { - hasIndent := BConfig.RunMode != PROD - c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent) -} - -// ServeXML sends xml response. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) ServeXML() { - hasIndent := BConfig.RunMode != PROD - c.Ctx.Output.XML(c.Data["xml"], hasIndent) -} - -// ServeYAML sends yaml response. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) ServeYAML() { - c.Ctx.Output.YAML(c.Data["yaml"]) -} - -// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) ServeFormatted(encoding ...bool) { - hasIndent := BConfig.RunMode != PROD - hasEncoding := len(encoding) > 0 && encoding[0] - c.Ctx.Output.ServeFormatted(c.Data, hasIndent, hasEncoding) -} - -// Input returns the input data map from POST or PUT request body and query string. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Input() url.Values { - if c.Ctx.Request.Form == nil { - c.Ctx.Request.ParseForm() - } - return c.Ctx.Request.Form -} - -// ParseForm maps input data map to obj struct. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) ParseForm(obj interface{}) error { - return ParseForm(c.Input(), obj) -} - -// GetString returns the input value by key string or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetString(key string, def ...string) string { - if v := c.Ctx.Input.Query(key); v != "" { - return v - } - if len(def) > 0 { - return def[0] - } - return "" -} - -// GetStrings returns the input string slice by key string or the default value while it's present and input is blank -// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetStrings(key string, def ...[]string) []string { - var defv []string - if len(def) > 0 { - defv = def[0] - } - - if f := c.Input(); f == nil { - return defv - } else if vs := f[key]; len(vs) > 0 { - return vs - } - - return defv -} - -// GetInt returns input as an int or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetInt(key string, def ...int) (int, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - return strconv.Atoi(strv) -} - -// GetInt8 return input as an int8 or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - i64, err := strconv.ParseInt(strv, 10, 8) - return int8(i64), err -} - -// GetUint8 return input as an uint8 or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - u64, err := strconv.ParseUint(strv, 10, 8) - return uint8(u64), err -} - -// GetInt16 returns input as an int16 or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - i64, err := strconv.ParseInt(strv, 10, 16) - return int16(i64), err -} - -// GetUint16 returns input as an uint16 or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - u64, err := strconv.ParseUint(strv, 10, 16) - return uint16(u64), err -} - -// GetInt32 returns input as an int32 or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - i64, err := strconv.ParseInt(strv, 10, 32) - return int32(i64), err -} - -// GetUint32 returns input as an uint32 or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - u64, err := strconv.ParseUint(strv, 10, 32) - return uint32(u64), err -} - -// GetInt64 returns input value as int64 or the default value while it's present and input is blank. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - return strconv.ParseInt(strv, 10, 64) -} - -// GetUint64 returns input value as uint64 or the default value while it's present and input is blank. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - return strconv.ParseUint(strv, 10, 64) -} - -// GetBool returns input value as bool or the default value while it's present and input is blank. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetBool(key string, def ...bool) (bool, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - return strconv.ParseBool(strv) -} - -// GetFloat returns input value as float64 or the default value while it's present and input is blank. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetFloat(key string, def ...float64) (float64, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - return strconv.ParseFloat(strv, 64) -} - -// GetFile returns the file data in file upload field named as key. -// it returns the first one of multi-uploaded files. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, error) { - return c.Ctx.Request.FormFile(key) -} - -// GetFiles return multi-upload files -// files, err:=c.GetFiles("myfiles") -// if err != nil { -// http.Error(w, err.Error(), http.StatusNoContent) -// return -// } -// for i, _ := range files { -// //for each fileheader, get a handle to the actual file -// file, err := files[i].Open() -// defer file.Close() -// if err != nil { -// http.Error(w, err.Error(), http.StatusInternalServerError) -// return -// } -// //create destination file making sure the path is writeable. -// dst, err := os.Create("upload/" + files[i].Filename) -// defer dst.Close() -// if err != nil { -// http.Error(w, err.Error(), http.StatusInternalServerError) -// return -// } -// //copy the uploaded file to the destination file -// if _, err := io.Copy(dst, file); err != nil { -// http.Error(w, err.Error(), http.StatusInternalServerError) -// return -// } -// } -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { - if files, ok := c.Ctx.Request.MultipartForm.File[key]; ok { - return files, nil - } - return nil, http.ErrMissingFile -} - -// SaveToFile saves uploaded file to new path. -// it only operates the first one of mutil-upload form file field. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) SaveToFile(fromfile, tofile string) error { - file, _, err := c.Ctx.Request.FormFile(fromfile) - if err != nil { - return err - } - defer file.Close() - f, err := os.OpenFile(tofile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) - if err != nil { - return err - } - defer f.Close() - io.Copy(f, file) - return nil -} - -// StartSession starts session and load old session data info this controller. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) StartSession() session.Store { - if c.CruSession == nil { - c.CruSession = c.Ctx.Input.CruSession - } - return c.CruSession -} - -// SetSession puts value into session. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) SetSession(name interface{}, value interface{}) { - if c.CruSession == nil { - c.StartSession() - } - c.CruSession.Set(name, value) -} - -// GetSession gets value from session. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetSession(name interface{}) interface{} { - if c.CruSession == nil { - c.StartSession() - } - return c.CruSession.Get(name) -} - -// DelSession removes value from session. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) DelSession(name interface{}) { - if c.CruSession == nil { - c.StartSession() - } - c.CruSession.Delete(name) -} - -// SessionRegenerateID regenerates session id for this session. -// the session data have no changes. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) SessionRegenerateID() { - if c.CruSession != nil { - c.CruSession.SessionRelease(c.Ctx.ResponseWriter) - } - c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request) - c.Ctx.Input.CruSession = c.CruSession -} - -// DestroySession cleans session data and session cookie. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) DestroySession() { - c.Ctx.Input.CruSession.Flush() - c.Ctx.Input.CruSession = nil - GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request) -} - -// IsAjax returns this request is ajax or not. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) IsAjax() bool { - return c.Ctx.Input.IsAjax() -} - -// GetSecureCookie returns decoded cookie value from encoded browser cookie values. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) { - return c.Ctx.GetSecureCookie(Secret, key) -} - -// SetSecureCookie puts value into cookie after encoded the value. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) { - c.Ctx.SetSecureCookie(Secret, name, value, others...) -} - -// XSRFToken creates a CSRF token string and returns. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) XSRFToken() string { - if c._xsrfToken == "" { - expire := int64(BConfig.WebConfig.XSRFExpire) - if c.XSRFExpire > 0 { - expire = int64(c.XSRFExpire) - } - c._xsrfToken = c.Ctx.XSRFToken(BConfig.WebConfig.XSRFKey, expire) - } - return c._xsrfToken -} - -// CheckXSRFCookie checks xsrf token in this request is valid or not. -// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" -// or in form field value named as "_xsrf". -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) CheckXSRFCookie() bool { - if !c.EnableXSRF { - return true - } - return c.Ctx.CheckXSRFCookie() -} - -// XSRFFormHTML writes an input field contains xsrf token value. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) XSRFFormHTML() string { - return `` -} - -// GetControllerAndAction gets the executing controller name and action name. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetControllerAndAction() (string, string) { - return c.controllerName, c.actionName -} diff --git a/controller_test.go b/controller_test.go deleted file mode 100644 index 1e53416d..00000000 --- a/controller_test.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "math" - "strconv" - "testing" - - "github.com/astaxie/beego/context" - "os" - "path/filepath" -) - -func TestGetInt(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt("age") - if val != 40 { - t.Errorf("TestGetInt expect 40,get %T,%v", val, val) - } -} - -func TestGetInt8(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt8("age") - if val != 40 { - t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val) - } - //Output: int8 -} - -func TestGetInt16(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt16("age") - if val != 40 { - t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val) - } -} - -func TestGetInt32(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt32("age") - if val != 40 { - t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val) - } -} - -func TestGetInt64(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt64("age") - if val != 40 { - t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val) - } -} - -func TestGetUint8(t *testing.T) { - i := context.NewInput() - i.SetParam("age", strconv.FormatUint(math.MaxUint8, 10)) - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetUint8("age") - if val != math.MaxUint8 { - t.Errorf("TestGetUint8 expect %v,get %T,%v", math.MaxUint8, val, val) - } -} - -func TestGetUint16(t *testing.T) { - i := context.NewInput() - i.SetParam("age", strconv.FormatUint(math.MaxUint16, 10)) - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetUint16("age") - if val != math.MaxUint16 { - t.Errorf("TestGetUint16 expect %v,get %T,%v", math.MaxUint16, val, val) - } -} - -func TestGetUint32(t *testing.T) { - i := context.NewInput() - i.SetParam("age", strconv.FormatUint(math.MaxUint32, 10)) - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetUint32("age") - if val != math.MaxUint32 { - t.Errorf("TestGetUint32 expect %v,get %T,%v", math.MaxUint32, val, val) - } -} - -func TestGetUint64(t *testing.T) { - i := context.NewInput() - i.SetParam("age", strconv.FormatUint(math.MaxUint64, 10)) - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetUint64("age") - if val != math.MaxUint64 { - t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val) - } -} - -func TestAdditionalViewPaths(t *testing.T) { - dir1 := "_beeTmp" - dir2 := "_beeTmp2" - defer os.RemoveAll(dir1) - defer os.RemoveAll(dir2) - - dir1file := "file1.tpl" - dir2file := "file2.tpl" - - genFile := func(dir string, name string, content string) { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) - if f, err := os.Create(filepath.Join(dir, name)); err != nil { - t.Fatal(err) - } else { - defer f.Close() - f.WriteString(content) - f.Close() - } - - } - genFile(dir1, dir1file, `
{{.Content}}
`) - genFile(dir2, dir2file, `{{.Content}}`) - - AddViewPath(dir1) - AddViewPath(dir2) - - ctrl := Controller{ - TplName: "file1.tpl", - ViewPath: dir1, - } - ctrl.Data = map[interface{}]interface{}{ - "Content": "value2", - } - if result, err := ctrl.RenderString(); err != nil { - t.Fatal(err) - } else { - if result != "
value2
" { - t.Fatalf("TestAdditionalViewPaths expect %s got %s", "
value2
", result) - } - } - - func() { - ctrl.TplName = "file2.tpl" - defer func() { - if r := recover(); r == nil { - t.Fatal("TestAdditionalViewPaths expected error") - } - }() - ctrl.RenderString() - }() - - ctrl.TplName = "file2.tpl" - ctrl.ViewPath = dir2 - ctrl.RenderString() -} diff --git a/doc.go b/doc.go deleted file mode 100644 index 72284c67..00000000 --- a/doc.go +++ /dev/null @@ -1,19 +0,0 @@ -/* -Package beego provide a MVC framework -beego: an open-source, high-performance, modular, full-stack web framework - -It is used for rapid development of RESTful APIs, web apps and backend services in Go. -beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding. - - package main - import "github.com/astaxie/beego" - - func main() { - beego.Run() - } - -more information: http://beego.me - -Deprecated: using pkg/, we will delete this in v2.1.0 -*/ -package beego diff --git a/error.go b/error.go deleted file mode 100644 index 40eea5fa..00000000 --- a/error.go +++ /dev/null @@ -1,492 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "fmt" - "html/template" - "net/http" - "reflect" - "runtime" - "strconv" - "strings" - - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/utils" -) - -const ( - errorTypeHandler = iota - errorTypeController -) - -var tpl = ` - - - - - beego application error - - - - - -
- - - - - - - - - - -
Request Method: {{.RequestMethod}}
Request URL: {{.RequestURL}}
RemoteAddr: {{.RemoteAddr }}
-
- Stack -
{{.Stack}}
-
-
- - - -` - -// render default application error page with error and stack string. -func showErr(err interface{}, ctx *context.Context, stack string) { - t, _ := template.New("beegoerrortemp").Parse(tpl) - data := map[string]string{ - "AppError": fmt.Sprintf("%s:%v", BConfig.AppName, err), - "RequestMethod": ctx.Input.Method(), - "RequestURL": ctx.Input.URI(), - "RemoteAddr": ctx.Input.IP(), - "Stack": stack, - "BeegoVersion": VERSION, - "GoVersion": runtime.Version(), - } - t.Execute(ctx.ResponseWriter, data) -} - -var errtpl = ` - - - - - {{.Title}} - - - -
-
- -
- {{.Content}} - Go Home
- -
Powered by beego {{.BeegoVersion}} -
-
-
- - -` - -type errorInfo struct { - controllerType reflect.Type - handler http.HandlerFunc - method string - errorType int -} - -// ErrorMaps holds map of http handlers for each error string. -// there is 10 kinds default error(40x and 50x) -// Deprecated: using pkg/, we will delete this in v2.1.0 -var ErrorMaps = make(map[string]*errorInfo, 10) - -// show 401 unauthorized error. -func unauthorized(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 401, - "
The page you have requested can't be authorized."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    The credentials you supplied are incorrect"+ - "
    There are errors in the website address"+ - "
", - ) -} - -// show 402 Payment Required -func paymentRequired(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 402, - "
The page you have requested Payment Required."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    The credentials you supplied are incorrect"+ - "
    There are errors in the website address"+ - "
", - ) -} - -// show 403 forbidden error. -func forbidden(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 403, - "
The page you have requested is forbidden."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    Your address may be blocked"+ - "
    The site may be disabled"+ - "
    You need to log in"+ - "
", - ) -} - -// show 422 missing xsrf token -func missingxsrf(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 422, - "
The page you have requested is forbidden."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    '_xsrf' argument missing from POST"+ - "
", - ) -} - -// show 417 invalid xsrf token -func invalidxsrf(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 417, - "
The page you have requested is forbidden."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    expected XSRF not found"+ - "
", - ) -} - -// show 404 not found error. -func notFound(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 404, - "
The page you have requested has flown the coop."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    The page has moved"+ - "
    The page no longer exists"+ - "
    You were looking for your puppy and got lost"+ - "
    You like 404 pages"+ - "
", - ) -} - -// show 405 Method Not Allowed -func methodNotAllowed(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 405, - "
The method you have requested Not Allowed."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    The method specified in the Request-Line is not allowed for the resource identified by the Request-URI"+ - "
    The response MUST include an Allow header containing a list of valid methods for the requested resource."+ - "
", - ) -} - -// show 500 internal server error. -func internalServerError(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 500, - "
The page you have requested is down right now."+ - "

    "+ - "
    Please try again later and report the error to the website administrator"+ - "
", - ) -} - -// show 501 Not Implemented. -func notImplemented(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 501, - "
The page you have requested is Not Implemented."+ - "

    "+ - "
    Please try again later and report the error to the website administrator"+ - "
", - ) -} - -// show 502 Bad Gateway. -func badGateway(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 502, - "
The page you have requested is down right now."+ - "

    "+ - "
    The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request."+ - "
    Please try again later and report the error to the website administrator"+ - "
", - ) -} - -// show 503 service unavailable error. -func serviceUnavailable(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 503, - "
The page you have requested is unavailable."+ - "
Perhaps you are here because:"+ - "

    "+ - "

    The page is overloaded"+ - "
    Please try again later."+ - "
", - ) -} - -// show 504 Gateway Timeout. -func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 504, - "
The page you have requested is unavailable"+ - "
Perhaps you are here because:"+ - "

    "+ - "

    The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI."+ - "
    Please try again later."+ - "
", - ) -} - -// show 413 Payload Too Large -func payloadTooLarge(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 413, - `
The page you have requested is unavailable. -
Perhaps you are here because:

-
    -
    The request entity is larger than limits defined by server. -
    Please change the request entity and try again. -
- `, - ) -} - -func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := M{ - "Title": http.StatusText(errCode), - "BeegoVersion": VERSION, - "Content": template.HTML(errContent), - } - t.Execute(rw, data) -} - -// ErrorHandler registers http.HandlerFunc to each http err code string. -// usage: -// beego.ErrorHandler("404",NotFound) -// beego.ErrorHandler("500",InternalServerError) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func ErrorHandler(code string, h http.HandlerFunc) *App { - ErrorMaps[code] = &errorInfo{ - errorType: errorTypeHandler, - handler: h, - method: code, - } - return BeeApp -} - -// ErrorController registers ControllerInterface to each http err code string. -// usage: -// beego.ErrorController(&controllers.ErrorController{}) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func ErrorController(c ControllerInterface) *App { - reflectVal := reflect.ValueOf(c) - rt := reflectVal.Type() - ct := reflect.Indirect(reflectVal).Type() - for i := 0; i < rt.NumMethod(); i++ { - methodName := rt.Method(i).Name - if !utils.InSlice(methodName, exceptMethod) && strings.HasPrefix(methodName, "Error") { - errName := strings.TrimPrefix(methodName, "Error") - ErrorMaps[errName] = &errorInfo{ - errorType: errorTypeController, - controllerType: ct, - method: methodName, - } - } - } - return BeeApp -} - -// Exception Write HttpStatus with errCode and Exec error handler if exist. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Exception(errCode uint64, ctx *context.Context) { - exception(strconv.FormatUint(errCode, 10), ctx) -} - -// show error string as simple text message. -// if error string is empty, show 503 or 500 error as default. -func exception(errCode string, ctx *context.Context) { - atoi := func(code string) int { - v, err := strconv.Atoi(code) - if err == nil { - return v - } - if ctx.Output.Status == 0 { - return 503 - } - return ctx.Output.Status - } - - for _, ec := range []string{errCode, "503", "500"} { - if h, ok := ErrorMaps[ec]; ok { - executeError(h, ctx, atoi(ec)) - return - } - } - //if 50x error has been removed from errorMap - ctx.ResponseWriter.WriteHeader(atoi(errCode)) - ctx.WriteString(errCode) -} - -func executeError(err *errorInfo, ctx *context.Context, code int) { - //make sure to log the error in the access log - LogAccess(ctx, nil, code) - - if err.errorType == errorTypeHandler { - ctx.ResponseWriter.WriteHeader(code) - err.handler(ctx.ResponseWriter, ctx.Request) - return - } - if err.errorType == errorTypeController { - ctx.Output.SetStatus(code) - //Invoke the request handler - vc := reflect.New(err.controllerType) - execController, ok := vc.Interface().(ControllerInterface) - if !ok { - panic("controller is not ControllerInterface") - } - //call the controller init function - execController.Init(ctx, err.controllerType.Name(), err.method, vc.Interface()) - - //call prepare function - execController.Prepare() - - execController.URLMapping() - - method := vc.MethodByName(err.method) - method.Call([]reflect.Value{}) - - //render template - if BConfig.WebConfig.AutoRender { - if err := execController.Render(); err != nil { - panic(err) - } - } - - // finish all runrouter. release resource - execController.Finish() - } -} diff --git a/error_test.go b/error_test.go deleted file mode 100644 index 378aa953..00000000 --- a/error_test.go +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2016 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "net/http" - "net/http/httptest" - "strconv" - "strings" - "testing" -) - -type errorTestController struct { - Controller -} - -const parseCodeError = "parse code error" - -func (ec *errorTestController) Get() { - errorCode, err := ec.GetInt("code") - if err != nil { - ec.Abort(parseCodeError) - } - if errorCode != 0 { - ec.CustomAbort(errorCode, ec.GetString("code")) - } - ec.Abort("404") -} - -func TestErrorCode_01(t *testing.T) { - registerDefaultErrorHandler() - for k := range ErrorMaps { - r, _ := http.NewRequest("GET", "/error?code="+k, nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/error", &errorTestController{}) - handler.ServeHTTP(w, r) - code, _ := strconv.Atoi(k) - if w.Code != code { - t.Fail() - } - if !strings.Contains(w.Body.String(), http.StatusText(code)) { - t.Fail() - } - } -} - -func TestErrorCode_02(t *testing.T) { - registerDefaultErrorHandler() - r, _ := http.NewRequest("GET", "/error?code=0", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/error", &errorTestController{}) - handler.ServeHTTP(w, r) - if w.Code != 404 { - t.Fail() - } -} - -func TestErrorCode_03(t *testing.T) { - registerDefaultErrorHandler() - r, _ := http.NewRequest("GET", "/error?code=panic", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/error", &errorTestController{}) - handler.ServeHTTP(w, r) - if w.Code != 200 { - t.Fail() - } - if w.Body.String() != parseCodeError { - t.Fail() - } -} diff --git a/filter.go b/filter.go deleted file mode 100644 index 8596d288..00000000 --- a/filter.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import "github.com/astaxie/beego/context" - -// FilterFunc defines a filter function which is invoked before the controller handler is executed. -// Deprecated: using pkg/, we will delete this in v2.1.0 -type FilterFunc func(*context.Context) - -// FilterRouter defines a filter operation which is invoked before the controller handler is executed. -// It can match the URL against a pattern, and execute a filter function -// when a request with a matching URL arrives. -// Deprecated: using pkg/, we will delete this in v2.1.0 -type FilterRouter struct { - filterFunc FilterFunc - tree *Tree - pattern string - returnOnOutput bool - resetParams bool -} - -// ValidRouter checks if the current request is matched by this filter. -// If the request is matched, the values of the URL parameters defined -// by the filter pattern are also returned. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { - isOk := f.tree.Match(url, ctx) - if isOk != nil { - if b, ok := isOk.(bool); ok { - return b - } - } - return false -} diff --git a/filter_test.go b/filter_test.go deleted file mode 100644 index 4ca4d2b8..00000000 --- a/filter_test.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/astaxie/beego/context" -) - -var FilterUser = func(ctx *context.Context) { - ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first"))) -} - -func TestFilter(t *testing.T) { - r, _ := http.NewRequest("GET", "/person/asta/Xie", nil) - w := httptest.NewRecorder() - handler := NewControllerRegister() - handler.InsertFilter("/person/:last/:first", BeforeRouter, FilterUser) - handler.Add("/person/:last/:first", &TestController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != "i am astaXie" { - t.Errorf("user define func can't run") - } -} - -var FilterAdminUser = func(ctx *context.Context) { - ctx.Output.Body([]byte("i am admin")) -} - -// Filter pattern /admin/:all -// all url like /admin/ /admin/xie will all get filter - -func TestPatternTwo(t *testing.T) { - r, _ := http.NewRequest("GET", "/admin/", nil) - w := httptest.NewRecorder() - handler := NewControllerRegister() - handler.InsertFilter("/admin/?:all", BeforeRouter, FilterAdminUser) - handler.ServeHTTP(w, r) - if w.Body.String() != "i am admin" { - t.Errorf("filter /admin/ can't run") - } -} - -func TestPatternThree(t *testing.T) { - r, _ := http.NewRequest("GET", "/admin/astaxie", nil) - w := httptest.NewRecorder() - handler := NewControllerRegister() - handler.InsertFilter("/admin/:all", BeforeRouter, FilterAdminUser) - handler.ServeHTTP(w, r) - if w.Body.String() != "i am admin" { - t.Errorf("filter /admin/astaxie can't run") - } -} diff --git a/flash.go b/flash.go deleted file mode 100644 index fe3fb974..00000000 --- a/flash.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "fmt" - "net/url" - "strings" -) - -// FlashData is a tools to maintain data when using across request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -type FlashData struct { - Data map[string]string -} - -// NewFlash return a new empty FlashData struct. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NewFlash() *FlashData { - return &FlashData{ - Data: make(map[string]string), - } -} - -// Set message to flash -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (fd *FlashData) Set(key string, msg string, args ...interface{}) { - if len(args) == 0 { - fd.Data[key] = msg - } else { - fd.Data[key] = fmt.Sprintf(msg, args...) - } -} - -// Success writes success message to flash. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (fd *FlashData) Success(msg string, args ...interface{}) { - if len(args) == 0 { - fd.Data["success"] = msg - } else { - fd.Data["success"] = fmt.Sprintf(msg, args...) - } -} - -// Notice writes notice message to flash. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (fd *FlashData) Notice(msg string, args ...interface{}) { - if len(args) == 0 { - fd.Data["notice"] = msg - } else { - fd.Data["notice"] = fmt.Sprintf(msg, args...) - } -} - -// Warning writes warning message to flash. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (fd *FlashData) Warning(msg string, args ...interface{}) { - if len(args) == 0 { - fd.Data["warning"] = msg - } else { - fd.Data["warning"] = fmt.Sprintf(msg, args...) - } -} - -// Error writes error message to flash. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (fd *FlashData) Error(msg string, args ...interface{}) { - if len(args) == 0 { - fd.Data["error"] = msg - } else { - fd.Data["error"] = fmt.Sprintf(msg, args...) - } -} - -// Store does the saving operation of flash data. -// the data are encoded and saved in cookie. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (fd *FlashData) Store(c *Controller) { - c.Data["flash"] = fd.Data - var flashValue string - for key, value := range fd.Data { - flashValue += "\x00" + key + "\x23" + BConfig.WebConfig.FlashSeparator + "\x23" + value + "\x00" - } - c.Ctx.SetCookie(BConfig.WebConfig.FlashName, url.QueryEscape(flashValue), 0, "/") -} - -// ReadFromRequest parsed flash data from encoded values in cookie. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func ReadFromRequest(c *Controller) *FlashData { - flash := NewFlash() - if cookie, err := c.Ctx.Request.Cookie(BConfig.WebConfig.FlashName); err == nil { - v, _ := url.QueryUnescape(cookie.Value) - vals := strings.Split(v, "\x00") - for _, v := range vals { - if len(v) > 0 { - kv := strings.Split(v, "\x23"+BConfig.WebConfig.FlashSeparator+"\x23") - if len(kv) == 2 { - flash.Data[kv[0]] = kv[1] - } - } - } - //read one time then delete it - c.Ctx.SetCookie(BConfig.WebConfig.FlashName, "", -1, "/") - } - c.Data["flash"] = flash.Data - return flash -} diff --git a/flash_test.go b/flash_test.go deleted file mode 100644 index d5e9608d..00000000 --- a/flash_test.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -type TestFlashController struct { - Controller -} - -func (t *TestFlashController) TestWriteFlash() { - flash := NewFlash() - flash.Notice("TestFlashString") - flash.Store(&t.Controller) - // we choose to serve json because we don't want to load a template html file - t.ServeJSON(true) -} - -func TestFlashHeader(t *testing.T) { - // create fake GET request - r, _ := http.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - - // setup the handler - handler := NewControllerRegister() - handler.Add("/", &TestFlashController{}, "get:TestWriteFlash") - handler.ServeHTTP(w, r) - - // get the Set-Cookie value - sc := w.Header().Get("Set-Cookie") - // match for the expected header - res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00") - // validate the assertion - if !res { - t.Errorf("TestFlashHeader() unable to validate flash message") - } -} diff --git a/fs.go b/fs.go deleted file mode 100644 index 3300813d..00000000 --- a/fs.go +++ /dev/null @@ -1,77 +0,0 @@ -package beego - -import ( - "net/http" - "os" - "path/filepath" -) - -// Deprecated: using pkg/, we will delete this in v2.1.0 -type FileSystem struct { -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (d FileSystem) Open(name string) (http.File, error) { - return os.Open(name) -} - -// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or -// directory in the tree, including root. All errors that arise visiting files -// and directories are filtered by walkFn. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error { - - f, err := fs.Open(root) - if err != nil { - return err - } - info, err := f.Stat() - if err != nil { - err = walkFn(root, nil, err) - } else { - err = walk(fs, root, info, walkFn) - } - if err == filepath.SkipDir { - return nil - } - return err -} - -// walk recursively descends path, calling walkFn. -func walk(fs http.FileSystem, path string, info os.FileInfo, walkFn filepath.WalkFunc) error { - var err error - if !info.IsDir() { - return walkFn(path, info, nil) - } - - dir, err := fs.Open(path) - if err != nil { - if err1 := walkFn(path, info, err); err1 != nil { - return err1 - } - return err - } - defer dir.Close() - dirs, err := dir.Readdir(-1) - err1 := walkFn(path, info, err) - // If err != nil, walk can't walk into this directory. - // err1 != nil means walkFn want walk to skip this directory or stop walking. - // Therefore, if one of err and err1 isn't nil, walk will return. - if err != nil || err1 != nil { - // The caller's behavior is controlled by the return value, which is decided - // by walkFn. walkFn may ignore err and return nil. - // If walkFn returns SkipDir, it will be handled by the caller. - // So walk should return whatever walkFn returns. - return err1 - } - - for _, fileInfo := range dirs { - filename := filepath.Join(path, fileInfo.Name()) - if err = walk(fs, filename, fileInfo, walkFn); err != nil { - if !fileInfo.IsDir() || err != filepath.SkipDir { - return err - } - } - } - return nil -} diff --git a/go.mod b/go.mod index e1b9fcc2..91bd9aef 100644 --- a/go.mod +++ b/go.mod @@ -36,6 +36,7 @@ require ( golang.org/x/tools v0.0.0-20200117065230-39095c1d176c google.golang.org/grpc v1.31.0 // indirect gopkg.in/yaml.v2 v2.2.8 + honnef.co/go/tools v0.0.1-2020.1.5 // indirect ) replace golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 => github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85 diff --git a/go.sum b/go.sum index 1666981d..95babc92 100644 --- a/go.sum +++ b/go.sum @@ -93,6 +93,7 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= @@ -100,6 +101,7 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj2vyii2bbUNDw3kt9VxK2EY= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -159,6 +161,7 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFBS8= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644 h1:X+yvsM2yrEktyI+b2qND5gpH8YhURn0k8OCaeRnkINo= github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644/go.mod h1:nkxAfR/5quYxwPZhyDxgasBMnRtBZd0FCEpawpjMUFg= github.com/siddontang/go v0.0.0-20170517070808-cb568a3e5cc0 h1:QIF48X1cihydXibm+4wfAc0r/qyPyuFiPFRNphdMpEE= @@ -182,16 +185,23 @@ github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c/go.mod h1:Z4AUp2K github.com/ugorji/go v0.0.0-20171122102828-84cb69a8af83/go.mod h1:hnLbHMwcvSihnDhEfx2/BzKp2xb0Y+ErdfYcrs9tkJQ= github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b h1:0Ve0/CCjiAiyKddUMUn3RwIGlq2iTW4GuVzyoKBYO/8= github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b/go.mod h1:Q12BUT7DqIlHRmgv3RskH+UCM/4eqVMgI0EMmlSpAXc= +github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/gopher-lua v0.0.0-20171031051903-609c9cd26973/go.mod h1:aEV29XrmTYFr3CiRxZeGHpkvbwq+prZduBqMaascyCU= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -204,12 +214,15 @@ golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g= golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4= +golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -220,6 +233,7 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= @@ -231,11 +245,18 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200117065230-39095c1d176c h1:FodBYPZKH5tAN2O60HlglMwXGAeV/4k+NKbli79M/2c= golang.org/x/tools v0.0.0-20200117065230-39095c1d176c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200815165600-90abf76919f3 h1:0aScV/0rLmANzEYIhjCOi2pTvDyhZNduBUMD2q3iqs4= +golang.org/x/tools v0.0.0-20200815165600-90abf76919f3/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= @@ -254,9 +275,11 @@ google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyz google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= @@ -268,4 +291,7 @@ gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc h1:/hemPrYIhOhy8zYrNj+069zDB68us2sMGsfkFJO0iZs= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2020.1.5 h1:nI5egYTGJakVyOryqLs1cQO5dO0ksin5XXs2pspk75k= +honnef.co/go/tools v0.0.1-2020.1.5/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= diff --git a/grace/grace.go b/grace/grace.go deleted file mode 100644 index 39d067fd..00000000 --- a/grace/grace.go +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package grace use to hot reload -// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ -// -// Usage: -// -// import( -// "log" -// "net/http" -// "os" -// -// "github.com/astaxie/beego/grace" -// ) -// -// func handler(w http.ResponseWriter, r *http.Request) { -// w.Write([]byte("WORLD!")) -// } -// -// func main() { -// mux := http.NewServeMux() -// mux.HandleFunc("/hello", handler) -// -// err := grace.ListenAndServe("localhost:8080", mux) -// if err != nil { -// log.Println(err) -// } -// log.Println("Server on 8080 stopped") -// os.Exit(0) -// } -package grace - -import ( - "flag" - "net/http" - "os" - "strings" - "sync" - "syscall" - "time" -) - -const ( - // PreSignal is the position to add filter before signal - // Deprecated: using pkg/grace, we will delete this in v2.1.0 - PreSignal = iota - // PostSignal is the position to add filter after signal - // Deprecated: using pkg/grace, we will delete this in v2.1.0 - PostSignal - // StateInit represent the application inited - // Deprecated: using pkg/grace, we will delete this in v2.1.0 - StateInit - // StateRunning represent the application is running - // Deprecated: using pkg/grace, we will delete this in v2.1.0 - StateRunning - // StateShuttingDown represent the application is shutting down - // Deprecated: using pkg/grace, we will delete this in v2.1.0 - StateShuttingDown - // StateTerminate represent the application is killed - // Deprecated: using pkg/grace, we will delete this in v2.1.0 - StateTerminate -) - -var ( - regLock *sync.Mutex - runningServers map[string]*Server - runningServersOrder []string - socketPtrOffsetMap map[string]uint - runningServersForked bool - - // DefaultReadTimeOut is the HTTP read timeout - DefaultReadTimeOut time.Duration - // DefaultWriteTimeOut is the HTTP Write timeout - DefaultWriteTimeOut time.Duration - // DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit - DefaultMaxHeaderBytes int - // DefaultTimeout is the shutdown server's timeout. default is 60s - DefaultTimeout = 60 * time.Second - - isChild bool - socketOrder string - - hookableSignals []os.Signal -) - -func init() { - flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)") - flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started") - - regLock = &sync.Mutex{} - runningServers = make(map[string]*Server) - runningServersOrder = []string{} - socketPtrOffsetMap = make(map[string]uint) - - hookableSignals = []os.Signal{ - syscall.SIGHUP, - syscall.SIGINT, - syscall.SIGTERM, - } -} - -// NewServer returns a new graceServer. -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func NewServer(addr string, handler http.Handler) (srv *Server) { - regLock.Lock() - defer regLock.Unlock() - - if !flag.Parsed() { - flag.Parse() - } - if len(socketOrder) > 0 { - for i, addr := range strings.Split(socketOrder, ",") { - socketPtrOffsetMap[addr] = uint(i) - } - } else { - socketPtrOffsetMap[addr] = uint(len(runningServersOrder)) - } - - srv = &Server{ - sigChan: make(chan os.Signal), - isChild: isChild, - SignalHooks: map[int]map[os.Signal][]func(){ - PreSignal: { - syscall.SIGHUP: {}, - syscall.SIGINT: {}, - syscall.SIGTERM: {}, - }, - PostSignal: { - syscall.SIGHUP: {}, - syscall.SIGINT: {}, - syscall.SIGTERM: {}, - }, - }, - state: StateInit, - Network: "tcp", - terminalChan: make(chan error), //no cache channel - } - srv.Server = &http.Server{ - Addr: addr, - ReadTimeout: DefaultReadTimeOut, - WriteTimeout: DefaultWriteTimeOut, - MaxHeaderBytes: DefaultMaxHeaderBytes, - Handler: handler, - } - - runningServersOrder = append(runningServersOrder, addr) - runningServers[addr] = srv - return srv -} - -// ListenAndServe refer http.ListenAndServe -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func ListenAndServe(addr string, handler http.Handler) error { - server := NewServer(addr, handler) - return server.ListenAndServe() -} - -// ListenAndServeTLS refer http.ListenAndServeTLS -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { - server := NewServer(addr, handler) - return server.ListenAndServeTLS(certFile, keyFile) -} diff --git a/grace/server.go b/grace/server.go deleted file mode 100644 index cd659f82..00000000 --- a/grace/server.go +++ /dev/null @@ -1,362 +0,0 @@ -package grace - -import ( - "context" - "crypto/tls" - "crypto/x509" - "fmt" - "io/ioutil" - "log" - "net" - "net/http" - "os" - "os/exec" - "os/signal" - "strings" - "syscall" - "time" -) - -// Server embedded http.Server -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -type Server struct { - *http.Server - ln net.Listener - SignalHooks map[int]map[os.Signal][]func() - sigChan chan os.Signal - isChild bool - state uint8 - Network string - terminalChan chan error -} - -// Serve accepts incoming connections on the Listener l, -// creating a new service goroutine for each. -// The service goroutines read requests and then call srv.Handler to reply to them. -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func (srv *Server) Serve() (err error) { - srv.state = StateRunning - defer func() { srv.state = StateTerminate }() - - // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS - // immediately return ErrServerClosed. Make sure the program doesn't exit - // and waits instead for Shutdown to return. - if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed { - log.Println(syscall.Getpid(), "Server.Serve() error:", err) - return err - } - - log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.") - // wait for Shutdown to return - if shutdownErr := <-srv.terminalChan; shutdownErr != nil { - return shutdownErr - } - return -} - -// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve -// to handle requests on incoming connections. If srv.Addr is blank, ":http" is -// used. -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func (srv *Server) ListenAndServe() (err error) { - addr := srv.Addr - if addr == "" { - addr = ":http" - } - - go srv.handleSignals() - - srv.ln, err = srv.getListener(addr) - if err != nil { - log.Println(err) - return err - } - - if srv.isChild { - process, err := os.FindProcess(os.Getppid()) - if err != nil { - log.Println(err) - return err - } - err = process.Signal(syscall.SIGTERM) - if err != nil { - return err - } - } - - log.Println(os.Getpid(), srv.Addr) - return srv.Serve() -} - -// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls -// Serve to handle requests on incoming TLS connections. -// -// Filenames containing a certificate and matching private key for the server must -// be provided. If the certificate is signed by a certificate authority, the -// certFile should be the concatenation of the server's certificate followed by the -// CA's certificate. -// -// If srv.Addr is blank, ":https" is used. -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { - addr := srv.Addr - if addr == "" { - addr = ":https" - } - - if srv.TLSConfig == nil { - srv.TLSConfig = &tls.Config{} - } - if srv.TLSConfig.NextProtos == nil { - srv.TLSConfig.NextProtos = []string{"http/1.1"} - } - - srv.TLSConfig.Certificates = make([]tls.Certificate, 1) - srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return - } - - go srv.handleSignals() - - ln, err := srv.getListener(addr) - if err != nil { - log.Println(err) - return err - } - srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) - - if srv.isChild { - process, err := os.FindProcess(os.Getppid()) - if err != nil { - log.Println(err) - return err - } - err = process.Signal(syscall.SIGTERM) - if err != nil { - return err - } - } - - log.Println(os.Getpid(), srv.Addr) - return srv.Serve() -} - -// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls -// Serve to handle requests on incoming mutual TLS connections. -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { - addr := srv.Addr - if addr == "" { - addr = ":https" - } - - if srv.TLSConfig == nil { - srv.TLSConfig = &tls.Config{} - } - if srv.TLSConfig.NextProtos == nil { - srv.TLSConfig.NextProtos = []string{"http/1.1"} - } - - srv.TLSConfig.Certificates = make([]tls.Certificate, 1) - srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return - } - srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert - pool := x509.NewCertPool() - data, err := ioutil.ReadFile(trustFile) - if err != nil { - log.Println(err) - return err - } - pool.AppendCertsFromPEM(data) - srv.TLSConfig.ClientCAs = pool - log.Println("Mutual HTTPS") - go srv.handleSignals() - - ln, err := srv.getListener(addr) - if err != nil { - log.Println(err) - return err - } - srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) - - if srv.isChild { - process, err := os.FindProcess(os.Getppid()) - if err != nil { - log.Println(err) - return err - } - err = process.Signal(syscall.SIGTERM) - if err != nil { - return err - } - } - - log.Println(os.Getpid(), srv.Addr) - return srv.Serve() -} - -// getListener either opens a new socket to listen on, or takes the acceptor socket -// it got passed when restarted. -func (srv *Server) getListener(laddr string) (l net.Listener, err error) { - if srv.isChild { - var ptrOffset uint - if len(socketPtrOffsetMap) > 0 { - ptrOffset = socketPtrOffsetMap[laddr] - log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr]) - } - - f := os.NewFile(uintptr(3+ptrOffset), "") - l, err = net.FileListener(f) - if err != nil { - err = fmt.Errorf("net.FileListener error: %v", err) - return - } - } else { - l, err = net.Listen(srv.Network, laddr) - if err != nil { - err = fmt.Errorf("net.Listen error: %v", err) - return - } - } - return -} - -type tcpKeepAliveListener struct { - *net.TCPListener -} - -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - tc, err := ln.AcceptTCP() - if err != nil { - return - } - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(3 * time.Minute) - return tc, nil -} - -// handleSignals listens for os Signals and calls any hooked in function that the -// user had registered with the signal. -func (srv *Server) handleSignals() { - var sig os.Signal - - signal.Notify( - srv.sigChan, - hookableSignals..., - ) - - pid := syscall.Getpid() - for { - sig = <-srv.sigChan - srv.signalHooks(PreSignal, sig) - switch sig { - case syscall.SIGHUP: - log.Println(pid, "Received SIGHUP. forking.") - err := srv.fork() - if err != nil { - log.Println("Fork err:", err) - } - case syscall.SIGINT: - log.Println(pid, "Received SIGINT.") - srv.shutdown() - case syscall.SIGTERM: - log.Println(pid, "Received SIGTERM.") - srv.shutdown() - default: - log.Printf("Received %v: nothing i care about...\n", sig) - } - srv.signalHooks(PostSignal, sig) - } -} - -func (srv *Server) signalHooks(ppFlag int, sig os.Signal) { - if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet { - return - } - for _, f := range srv.SignalHooks[ppFlag][sig] { - f() - } -} - -// shutdown closes the listener so that no new connections are accepted. it also -// starts a goroutine that will serverTimeout (stop all running requests) the server -// after DefaultTimeout. -func (srv *Server) shutdown() { - if srv.state != StateRunning { - return - } - - srv.state = StateShuttingDown - log.Println(syscall.Getpid(), "Waiting for connections to finish...") - ctx := context.Background() - if DefaultTimeout >= 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) - defer cancel() - } - srv.terminalChan <- srv.Server.Shutdown(ctx) -} - -func (srv *Server) fork() (err error) { - regLock.Lock() - defer regLock.Unlock() - if runningServersForked { - return - } - runningServersForked = true - - var files = make([]*os.File, len(runningServers)) - var orderArgs = make([]string, len(runningServers)) - for _, srvPtr := range runningServers { - f, _ := srvPtr.ln.(*net.TCPListener).File() - files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f - orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr - } - - log.Println(files) - path := os.Args[0] - var args []string - if len(os.Args) > 1 { - for _, arg := range os.Args[1:] { - if arg == "-graceful" { - break - } - args = append(args, arg) - } - } - args = append(args, "-graceful") - if len(runningServers) > 1 { - args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ","))) - log.Println(args) - } - cmd := exec.Command(path, args...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.ExtraFiles = files - err = cmd.Start() - if err != nil { - log.Fatalf("Restart: Failed to launch, error: %v", err) - } - - return -} - -// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) { - if ppFlag != PreSignal && ppFlag != PostSignal { - err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal") - return - } - for _, s := range hookableSignals { - if s == sig { - srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f) - return - } - } - err = fmt.Errorf("Signal '%v' is not supported", sig) - return -} diff --git a/hooks.go b/hooks.go deleted file mode 100644 index 49c42d5a..00000000 --- a/hooks.go +++ /dev/null @@ -1,104 +0,0 @@ -package beego - -import ( - "encoding/json" - "mime" - "net/http" - "path/filepath" - - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/session" -) - -// register MIME type with content type -func registerMime() error { - for k, v := range mimemaps { - mime.AddExtensionType(k, v) - } - return nil -} - -// register default error http handlers, 404,401,403,500 and 503. -func registerDefaultErrorHandler() error { - m := map[string]func(http.ResponseWriter, *http.Request){ - "401": unauthorized, - "402": paymentRequired, - "403": forbidden, - "404": notFound, - "405": methodNotAllowed, - "500": internalServerError, - "501": notImplemented, - "502": badGateway, - "503": serviceUnavailable, - "504": gatewayTimeout, - "417": invalidxsrf, - "422": missingxsrf, - "413": payloadTooLarge, - } - for e, h := range m { - if _, ok := ErrorMaps[e]; !ok { - ErrorHandler(e, h) - } - } - return nil -} - -func registerSession() error { - if BConfig.WebConfig.Session.SessionOn { - var err error - sessionConfig := AppConfig.String("sessionConfig") - conf := new(session.ManagerConfig) - if sessionConfig == "" { - conf.CookieName = BConfig.WebConfig.Session.SessionName - conf.EnableSetCookie = BConfig.WebConfig.Session.SessionAutoSetCookie - conf.Gclifetime = BConfig.WebConfig.Session.SessionGCMaxLifetime - conf.Secure = BConfig.Listen.EnableHTTPS - conf.CookieLifeTime = BConfig.WebConfig.Session.SessionCookieLifeTime - conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig) - conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly - conf.Domain = BConfig.WebConfig.Session.SessionDomain - conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader - conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader - conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery - } else { - if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil { - return err - } - } - if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, conf); err != nil { - return err - } - go GlobalSessions.GC() - } - return nil -} - -func registerTemplate() error { - defer lockViewPaths() - if err := AddViewPath(BConfig.WebConfig.ViewsPath); err != nil { - if BConfig.RunMode == DEV { - logs.Warn(err) - } - return err - } - return nil -} - -func registerAdmin() error { - if BConfig.Listen.EnableAdmin { - go beeAdminApp.Run() - } - return nil -} - -func registerGzip() error { - if BConfig.EnableGzip { - context.InitGzip( - AppConfig.DefaultInt("gzipMinLength", -1), - AppConfig.DefaultInt("gzipCompressLevel", -1), - AppConfig.DefaultStrings("includedMethods", []string{"GET"}), - ) - } - return nil -} diff --git a/httplib/README.md b/httplib/README.md deleted file mode 100644 index 97df8e6b..00000000 --- a/httplib/README.md +++ /dev/null @@ -1,97 +0,0 @@ -# httplib -httplib is an libs help you to curl remote url. - -# How to use? - -## GET -you can use Get to crawl data. - - import "github.com/astaxie/beego/httplib" - - str, err := httplib.Get("http://beego.me/").String() - if err != nil { - // error - } - fmt.Println(str) - -## POST -POST data to remote url - - req := httplib.Post("http://beego.me/") - req.Param("username","astaxie") - req.Param("password","123456") - str, err := req.String() - if err != nil { - // error - } - fmt.Println(str) - -## Set timeout - -The default timeout is `60` seconds, function prototype: - - SetTimeout(connectTimeout, readWriteTimeout time.Duration) - -Example: - - // GET - httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) - - // POST - httplib.Post("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) - - -## Debug - -If you want to debug the request info, set the debug on - - httplib.Get("http://beego.me/").Debug(true) - -## Set HTTP Basic Auth - - str, err := Get("http://beego.me/").SetBasicAuth("user", "passwd").String() - if err != nil { - // error - } - fmt.Println(str) - -## Set HTTPS - -If request url is https, You can set the client support TSL: - - httplib.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) - -More info about the `tls.Config` please visit http://golang.org/pkg/crypto/tls/#Config - -## Set HTTP Version - -some servers need to specify the protocol version of HTTP - - httplib.Get("http://beego.me/").SetProtocolVersion("HTTP/1.1") - -## Set Cookie - -some http request need setcookie. So set it like this: - - cookie := &http.Cookie{} - cookie.Name = "username" - cookie.Value = "astaxie" - httplib.Get("http://beego.me/").SetCookie(cookie) - -## Upload file - -httplib support mutil file upload, use `req.PostFile()` - - req := httplib.Post("http://beego.me/") - req.Param("username","astaxie") - req.PostFile("uploadfile1", "httplib.pdf") - str, err := req.String() - if err != nil { - // error - } - fmt.Println(str) - - -See godoc for further documentation and examples. - -* [godoc.org/github.com/astaxie/beego/httplib](https://godoc.org/github.com/astaxie/beego/httplib) diff --git a/httplib/httplib.go b/httplib/httplib.go deleted file mode 100644 index 8ae95641..00000000 --- a/httplib/httplib.go +++ /dev/null @@ -1,697 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package httplib is used as http.Client -// Usage: -// -// import "github.com/astaxie/beego/httplib" -// -// b := httplib.Post("http://beego.me/") -// b.Param("username","astaxie") -// b.Param("password","123456") -// b.PostFile("uploadfile1", "httplib.pdf") -// b.PostFile("uploadfile2", "httplib.txt") -// str, err := b.String() -// if err != nil { -// t.Fatal(err) -// } -// fmt.Println(str) -// -// more docs http://beego.me/docs/module/httplib.md -package httplib - -import ( - "bytes" - "compress/gzip" - "crypto/tls" - "encoding/json" - "encoding/xml" - "io" - "io/ioutil" - "log" - "mime/multipart" - "net" - "net/http" - "net/http/cookiejar" - "net/http/httputil" - "net/url" - "os" - "path" - "strings" - "sync" - "time" - - "gopkg.in/yaml.v2" -) - -var defaultSetting = BeegoHTTPSettings{ - UserAgent: "beegoServer", - ConnectTimeout: 60 * time.Second, - ReadWriteTimeout: 60 * time.Second, - Gzip: true, - DumpBody: true, -} - -var defaultCookieJar http.CookieJar -var settingMutex sync.Mutex - -// createDefaultCookie creates a global cookiejar to store cookies. -func createDefaultCookie() { - settingMutex.Lock() - defer settingMutex.Unlock() - defaultCookieJar, _ = cookiejar.New(nil) -} - -// SetDefaultSetting Overwrite default settings -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func SetDefaultSetting(setting BeegoHTTPSettings) { - settingMutex.Lock() - defer settingMutex.Unlock() - defaultSetting = setting -} - -// NewBeegoRequest return *BeegoHttpRequest with specific method -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { - var resp http.Response - u, err := url.Parse(rawurl) - if err != nil { - log.Println("Httplib:", err) - } - req := http.Request{ - URL: u, - Method: method, - Header: make(http.Header), - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - } - return &BeegoHTTPRequest{ - url: rawurl, - req: &req, - params: map[string][]string{}, - files: map[string]string{}, - setting: defaultSetting, - resp: &resp, - } -} - -// Get returns *BeegoHttpRequest with GET method. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func Get(url string) *BeegoHTTPRequest { - return NewBeegoRequest(url, "GET") -} - -// Post returns *BeegoHttpRequest with POST method. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func Post(url string) *BeegoHTTPRequest { - return NewBeegoRequest(url, "POST") -} - -// Put returns *BeegoHttpRequest with PUT method. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func Put(url string) *BeegoHTTPRequest { - return NewBeegoRequest(url, "PUT") -} - -// Delete returns *BeegoHttpRequest DELETE method. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func Delete(url string) *BeegoHTTPRequest { - return NewBeegoRequest(url, "DELETE") -} - -// Head returns *BeegoHttpRequest with HEAD method. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func Head(url string) *BeegoHTTPRequest { - return NewBeegoRequest(url, "HEAD") -} - -// BeegoHTTPSettings is the http.Client setting -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -type BeegoHTTPSettings struct { - ShowDebug bool - UserAgent string - ConnectTimeout time.Duration - ReadWriteTimeout time.Duration - TLSClientConfig *tls.Config - Proxy func(*http.Request) (*url.URL, error) - Transport http.RoundTripper - CheckRedirect func(req *http.Request, via []*http.Request) error - EnableCookie bool - Gzip bool - DumpBody bool - Retries int // if set to -1 means will retry forever - RetryDelay time.Duration -} - -// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -type BeegoHTTPRequest struct { - url string - req *http.Request - params map[string][]string - files map[string]string - setting BeegoHTTPSettings - resp *http.Response - body []byte - dump []byte -} - -// GetRequest return the request object -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) GetRequest() *http.Request { - return b.req -} - -// Setting Change request settings -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest { - b.setting = setting - return b -} - -// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest { - b.req.SetBasicAuth(username, password) - return b -} - -// SetEnableCookie sets enable/disable cookiejar -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest { - b.setting.EnableCookie = enable - return b -} - -// SetUserAgent sets User-Agent header field -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { - b.setting.UserAgent = useragent - return b -} - -// Debug sets show debug or not when executing request. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { - b.setting.ShowDebug = isdebug - return b -} - -// Retries sets Retries times. -// default is 0 means no retried. -// -1 means retried forever. -// others means retried times. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { - b.setting.Retries = times - return b -} - -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { - b.setting.RetryDelay = delay - return b -} - -// DumpBody setting whether need to Dump the Body. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { - b.setting.DumpBody = isdump - return b -} - -// DumpRequest return the DumpRequest -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) DumpRequest() []byte { - return b.dump -} - -// SetTimeout sets connect time out and read-write time out for BeegoRequest. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { - b.setting.ConnectTimeout = connectTimeout - b.setting.ReadWriteTimeout = readWriteTimeout - return b -} - -// SetTLSClientConfig sets tls connection configurations if visiting https url. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest { - b.setting.TLSClientConfig = config - return b -} - -// Header add header item string in request. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest { - b.req.Header.Set(key, value) - return b -} - -// SetHost set the request host -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { - b.req.Host = host - return b -} - -// SetProtocolVersion Set the protocol version for incoming requests. -// Client requests always use HTTP/1.1. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { - if len(vers) == 0 { - vers = "HTTP/1.1" - } - - major, minor, ok := http.ParseHTTPVersion(vers) - if ok { - b.req.Proto = vers - b.req.ProtoMajor = major - b.req.ProtoMinor = minor - } - - return b -} - -// SetCookie add cookie into request. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest { - b.req.Header.Add("Cookie", cookie.String()) - return b -} - -// SetTransport set the setting transport -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest { - b.setting.Transport = transport - return b -} - -// SetProxy set the http proxy -// example: -// -// func(req *http.Request) (*url.URL, error) { -// u, _ := url.ParseRequestURI("http://127.0.0.1:8118") -// return u, nil -// } -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest { - b.setting.Proxy = proxy - return b -} - -// SetCheckRedirect specifies the policy for handling redirects. -// -// If CheckRedirect is nil, the Client uses its default policy, -// which is to stop after 10 consecutive requests. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest { - b.setting.CheckRedirect = redirect - return b -} - -// Param adds query param in to request. -// params build query string as ?key1=value1&key2=value2... -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { - if param, ok := b.params[key]; ok { - b.params[key] = append(param, value) - } else { - b.params[key] = []string{value} - } - return b -} - -// PostFile add a post file to the request -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest { - b.files[formname] = filename - return b -} - -// Body adds request raw body. -// it supports string and []byte. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { - switch t := data.(type) { - case string: - bf := bytes.NewBufferString(t) - b.req.Body = ioutil.NopCloser(bf) - b.req.ContentLength = int64(len(t)) - case []byte: - bf := bytes.NewBuffer(t) - b.req.Body = ioutil.NopCloser(bf) - b.req.ContentLength = int64(len(t)) - } - return b -} - -// XMLBody adds request raw body encoding by XML. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { - if b.req.Body == nil && obj != nil { - byts, err := xml.Marshal(obj) - if err != nil { - return b, err - } - b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) - b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/xml") - } - return b, nil -} - -// YAMLBody adds request raw body encoding by YAML. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) { - if b.req.Body == nil && obj != nil { - byts, err := yaml.Marshal(obj) - if err != nil { - return b, err - } - b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) - b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/x+yaml") - } - return b, nil -} - -// JSONBody adds request raw body encoding by JSON. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { - if b.req.Body == nil && obj != nil { - byts, err := json.Marshal(obj) - if err != nil { - return b, err - } - b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) - b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/json") - } - return b, nil -} - -func (b *BeegoHTTPRequest) buildURL(paramBody string) { - // build GET url with query string - if b.req.Method == "GET" && len(paramBody) > 0 { - if strings.Contains(b.url, "?") { - b.url += "&" + paramBody - } else { - b.url = b.url + "?" + paramBody - } - return - } - - // build POST/PUT/PATCH url and body - if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil { - // with files - if len(b.files) > 0 { - pr, pw := io.Pipe() - bodyWriter := multipart.NewWriter(pw) - go func() { - for formname, filename := range b.files { - fileWriter, err := bodyWriter.CreateFormFile(formname, filename) - if err != nil { - log.Println("Httplib:", err) - } - fh, err := os.Open(filename) - if err != nil { - log.Println("Httplib:", err) - } - //iocopy - _, err = io.Copy(fileWriter, fh) - fh.Close() - if err != nil { - log.Println("Httplib:", err) - } - } - for k, v := range b.params { - for _, vv := range v { - bodyWriter.WriteField(k, vv) - } - } - bodyWriter.Close() - pw.Close() - }() - b.Header("Content-Type", bodyWriter.FormDataContentType()) - b.req.Body = ioutil.NopCloser(pr) - b.Header("Transfer-Encoding", "chunked") - return - } - - // with params - if len(paramBody) > 0 { - b.Header("Content-Type", "application/x-www-form-urlencoded") - b.Body(paramBody) - } - } -} - -func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { - if b.resp.StatusCode != 0 { - return b.resp, nil - } - resp, err := b.DoRequest() - if err != nil { - return nil, err - } - b.resp = resp - return resp, nil -} - -// DoRequest will do the client.Do -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { - var paramBody string - if len(b.params) > 0 { - var buf bytes.Buffer - for k, v := range b.params { - for _, vv := range v { - buf.WriteString(url.QueryEscape(k)) - buf.WriteByte('=') - buf.WriteString(url.QueryEscape(vv)) - buf.WriteByte('&') - } - } - paramBody = buf.String() - paramBody = paramBody[0 : len(paramBody)-1] - } - - b.buildURL(paramBody) - urlParsed, err := url.Parse(b.url) - if err != nil { - return nil, err - } - - b.req.URL = urlParsed - - trans := b.setting.Transport - - if trans == nil { - // create default transport - trans = &http.Transport{ - TLSClientConfig: b.setting.TLSClientConfig, - Proxy: b.setting.Proxy, - Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), - MaxIdleConnsPerHost: 100, - } - } else { - // if b.transport is *http.Transport then set the settings. - if t, ok := trans.(*http.Transport); ok { - if t.TLSClientConfig == nil { - t.TLSClientConfig = b.setting.TLSClientConfig - } - if t.Proxy == nil { - t.Proxy = b.setting.Proxy - } - if t.Dial == nil { - t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) - } - } - } - - var jar http.CookieJar - if b.setting.EnableCookie { - if defaultCookieJar == nil { - createDefaultCookie() - } - jar = defaultCookieJar - } - - client := &http.Client{ - Transport: trans, - Jar: jar, - } - - if b.setting.UserAgent != "" && b.req.Header.Get("User-Agent") == "" { - b.req.Header.Set("User-Agent", b.setting.UserAgent) - } - - if b.setting.CheckRedirect != nil { - client.CheckRedirect = b.setting.CheckRedirect - } - - if b.setting.ShowDebug { - dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody) - if err != nil { - log.Println(err.Error()) - } - b.dump = dump - } - // retries default value is 0, it will run once. - // retries equal to -1, it will run forever until success - // retries is setted, it will retries fixed times. - // Sleeps for a 400ms inbetween calls to reduce spam - for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ { - resp, err = client.Do(b.req) - if err == nil { - break - } - time.Sleep(b.setting.RetryDelay) - } - return resp, err -} - -// String returns the body string in response. -// it calls Response inner. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) String() (string, error) { - data, err := b.Bytes() - if err != nil { - return "", err - } - - return string(data), nil -} - -// Bytes returns the body []byte in response. -// it calls Response inner. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { - if b.body != nil { - return b.body, nil - } - resp, err := b.getResponse() - if err != nil { - return nil, err - } - if resp.Body == nil { - return nil, nil - } - defer resp.Body.Close() - if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" { - reader, err := gzip.NewReader(resp.Body) - if err != nil { - return nil, err - } - b.body, err = ioutil.ReadAll(reader) - return b.body, err - } - b.body, err = ioutil.ReadAll(resp.Body) - return b.body, err -} - -// ToFile saves the body data in response to one file. -// it calls Response inner. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) ToFile(filename string) error { - resp, err := b.getResponse() - if err != nil { - return err - } - if resp.Body == nil { - return nil - } - defer resp.Body.Close() - err = pathExistAndMkdir(filename) - if err != nil { - return err - } - f, err := os.Create(filename) - if err != nil { - return err - } - defer f.Close() - _, err = io.Copy(f, resp.Body) - return err -} - -//Check that the file directory exists, there is no automatically created -func pathExistAndMkdir(filename string) (err error) { - filename = path.Dir(filename) - _, err = os.Stat(filename) - if err == nil { - return nil - } - if os.IsNotExist(err) { - err = os.MkdirAll(filename, os.ModePerm) - if err == nil { - return nil - } - } - return err -} - -// ToJSON returns the map that marshals from the body bytes as json in response . -// it calls Response inner. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { - data, err := b.Bytes() - if err != nil { - return err - } - return json.Unmarshal(data, v) -} - -// ToXML returns the map that marshals from the body bytes as xml in response . -// it calls Response inner. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) ToXML(v interface{}) error { - data, err := b.Bytes() - if err != nil { - return err - } - return xml.Unmarshal(data, v) -} - -// ToYAML returns the map that marshals from the body bytes as yaml in response . -// it calls Response inner. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { - data, err := b.Bytes() - if err != nil { - return err - } - return yaml.Unmarshal(data, v) -} - -// Response executes request client gets response mannually. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Response() (*http.Response, error) { - return b.getResponse() -} - -// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { - return func(netw, addr string) (net.Conn, error) { - conn, err := net.DialTimeout(netw, addr, cTimeout) - if err != nil { - return nil, err - } - err = conn.SetDeadline(time.Now().Add(rwTimeout)) - return conn, err - } -} diff --git a/httplib/httplib_test.go b/httplib/httplib_test.go deleted file mode 100644 index f6be8571..00000000 --- a/httplib/httplib_test.go +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package httplib - -import ( - "errors" - "io/ioutil" - "net" - "net/http" - "os" - "strings" - "testing" - "time" -) - -func TestResponse(t *testing.T) { - req := Get("http://httpbin.org/get") - resp, err := req.Response() - if err != nil { - t.Fatal(err) - } - t.Log(resp) -} - -func TestDoRequest(t *testing.T) { - req := Get("https://goolnk.com/33BD2j") - retryAmount := 1 - req.Retries(1) - req.RetryDelay(1400 * time.Millisecond) - retryDelay := 1400 * time.Millisecond - - req.setting.CheckRedirect = func(redirectReq *http.Request, redirectVia []*http.Request) error { - return errors.New("Redirect triggered") - } - - startTime := time.Now().UnixNano() / int64(time.Millisecond) - - _, err := req.Response() - if err == nil { - t.Fatal("Response should have yielded an error") - } - - endTime := time.Now().UnixNano() / int64(time.Millisecond) - elapsedTime := endTime - startTime - delayedTime := int64(retryAmount) * retryDelay.Milliseconds() - - if elapsedTime < delayedTime { - t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) - } - -} - -func TestGet(t *testing.T) { - req := Get("http://httpbin.org/get") - b, err := req.Bytes() - if err != nil { - t.Fatal(err) - } - t.Log(b) - - s, err := req.String() - if err != nil { - t.Fatal(err) - } - t.Log(s) - - if string(b) != s { - t.Fatal("request data not match") - } -} - -func TestSimplePost(t *testing.T) { - v := "smallfish" - req := Post("http://httpbin.org/post") - req.Param("username", v) - - str, err := req.String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - n := strings.Index(str, v) - if n == -1 { - t.Fatal(v + " not found in post") - } -} - -//func TestPostFile(t *testing.T) { -// v := "smallfish" -// req := Post("http://httpbin.org/post") -// req.Debug(true) -// req.Param("username", v) -// req.PostFile("uploadfile", "httplib_test.go") - -// str, err := req.String() -// if err != nil { -// t.Fatal(err) -// } -// t.Log(str) - -// n := strings.Index(str, v) -// if n == -1 { -// t.Fatal(v + " not found in post") -// } -//} - -func TestSimplePut(t *testing.T) { - str, err := Put("http://httpbin.org/put").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) -} - -func TestSimpleDelete(t *testing.T) { - str, err := Delete("http://httpbin.org/delete").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) -} - -func TestSimpleDeleteParam(t *testing.T) { - str, err := Delete("http://httpbin.org/delete").Param("key", "val").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) -} - -func TestWithCookie(t *testing.T) { - v := "smallfish" - str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - str, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - n := strings.Index(str, v) - if n == -1 { - t.Fatal(v + " not found in cookie") - } -} - -func TestWithBasicAuth(t *testing.T) { - str, err := Get("http://httpbin.org/basic-auth/user/passwd").SetBasicAuth("user", "passwd").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - n := strings.Index(str, "authenticated") - if n == -1 { - t.Fatal("authenticated not found in response") - } -} - -func TestWithUserAgent(t *testing.T) { - v := "beego" - str, err := Get("http://httpbin.org/headers").SetUserAgent(v).String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - n := strings.Index(str, v) - if n == -1 { - t.Fatal(v + " not found in user-agent") - } -} - -func TestWithSetting(t *testing.T) { - v := "beego" - var setting BeegoHTTPSettings - setting.EnableCookie = true - setting.UserAgent = v - setting.Transport = &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: true, - }).DialContext, - MaxIdleConns: 50, - IdleConnTimeout: 90 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - setting.ReadWriteTimeout = 5 * time.Second - SetDefaultSetting(setting) - - str, err := Get("http://httpbin.org/get").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - n := strings.Index(str, v) - if n == -1 { - t.Fatal(v + " not found in user-agent") - } -} - -func TestToJson(t *testing.T) { - req := Get("http://httpbin.org/ip") - resp, err := req.Response() - if err != nil { - t.Fatal(err) - } - t.Log(resp) - - // httpbin will return http remote addr - type IP struct { - Origin string `json:"origin"` - } - var ip IP - err = req.ToJSON(&ip) - if err != nil { - t.Fatal(err) - } - t.Log(ip.Origin) - ips := strings.Split(ip.Origin, ",") - if len(ips) == 0 { - t.Fatal("response is not valid ip") - } - for i := range ips { - if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil { - t.Fatal("response is not valid ip") - } - } - -} - -func TestToFile(t *testing.T) { - f := "beego_testfile" - req := Get("http://httpbin.org/ip") - err := req.ToFile(f) - if err != nil { - t.Fatal(err) - } - defer os.Remove(f) - b, err := ioutil.ReadFile(f) - if n := strings.Index(string(b), "origin"); n == -1 { - t.Fatal(err) - } -} - -func TestToFileDir(t *testing.T) { - f := "./files/beego_testfile" - req := Get("http://httpbin.org/ip") - err := req.ToFile(f) - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll("./files") - b, err := ioutil.ReadFile(f) - if n := strings.Index(string(b), "origin"); n == -1 { - t.Fatal(err) - } -} - -func TestHeader(t *testing.T) { - req := Get("http://httpbin.org/headers") - req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36") - str, err := req.String() - if err != nil { - t.Fatal(err) - } - t.Log(str) -} diff --git a/logs/README.md b/logs/README.md deleted file mode 100644 index c05bcc04..00000000 --- a/logs/README.md +++ /dev/null @@ -1,72 +0,0 @@ -## logs -logs is a Go logs manager. It can use many logs adapters. The repo is inspired by `database/sql` . - - -## How to install? - - go get github.com/astaxie/beego/logs - - -## What adapters are supported? - -As of now this logs support console, file,smtp and conn. - - -## How to use it? - -First you must import it - -```golang -import ( - "github.com/astaxie/beego/logs" -) -``` - -Then init a Log (example with console adapter) - -```golang -log := logs.NewLogger(10000) -log.SetLogger("console", "") -``` - -> the first params stand for how many channel - -Use it like this: - -```golang -log.Trace("trace") -log.Info("info") -log.Warn("warning") -log.Debug("debug") -log.Critical("critical") -``` - -## File adapter - -Configure file adapter like this: - -```golang -log := NewLogger(10000) -log.SetLogger("file", `{"filename":"test.log"}`) -``` - -## Conn adapter - -Configure like this: - -```golang -log := NewLogger(1000) -log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) -log.Info("info") -``` - -## Smtp adapter - -Configure like this: - -```golang -log := NewLogger(10000) -log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`) -log.Critical("sendmail critical") -time.Sleep(time.Second * 30) -``` diff --git a/logs/accesslog.go b/logs/accesslog.go deleted file mode 100644 index 9011b602..00000000 --- a/logs/accesslog.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "bytes" - "encoding/json" - "fmt" - "strings" - "time" -) - -const ( - apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s" - apacheFormat = "APACHE_FORMAT" - jsonFormat = "JSON_FORMAT" -) - -// AccessLogRecord struct for holding access log data. -type AccessLogRecord struct { - RemoteAddr string `json:"remote_addr"` - RequestTime time.Time `json:"request_time"` - RequestMethod string `json:"request_method"` - Request string `json:"request"` - ServerProtocol string `json:"server_protocol"` - Host string `json:"host"` - Status int `json:"status"` - BodyBytesSent int64 `json:"body_bytes_sent"` - ElapsedTime time.Duration `json:"elapsed_time"` - HTTPReferrer string `json:"http_referrer"` - HTTPUserAgent string `json:"http_user_agent"` - RemoteUser string `json:"remote_user"` -} - -func (r *AccessLogRecord) json() ([]byte, error) { - buffer := &bytes.Buffer{} - encoder := json.NewEncoder(buffer) - disableEscapeHTML(encoder) - - err := encoder.Encode(r) - return buffer.Bytes(), err -} - -func disableEscapeHTML(i interface{}) { - if e, ok := i.(interface { - SetEscapeHTML(bool) - }); ok { - e.SetEscapeHTML(false) - } -} - -// AccessLog - Format and print access log. -func AccessLog(r *AccessLogRecord, format string) { - var msg string - switch format { - case apacheFormat: - timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05") - msg = fmt.Sprintf(apacheFormatPattern, r.RemoteAddr, timeFormatted, r.Request, r.Status, r.BodyBytesSent, - r.ElapsedTime.Seconds(), r.HTTPReferrer, r.HTTPUserAgent) - case jsonFormat: - fallthrough - default: - jsonData, err := r.json() - if err != nil { - msg = fmt.Sprintf(`{"Error": "%s"}`, err) - } else { - msg = string(jsonData) - } - } - beeLogger.writeMsg(levelLoggerImpl, strings.TrimSpace(msg)) -} diff --git a/logs/alils/alils.go b/logs/alils/alils.go deleted file mode 100644 index 867ff4cb..00000000 --- a/logs/alils/alils.go +++ /dev/null @@ -1,186 +0,0 @@ -package alils - -import ( - "encoding/json" - "strings" - "sync" - "time" - - "github.com/astaxie/beego/logs" - "github.com/gogo/protobuf/proto" -) - -const ( - // CacheSize set the flush size - CacheSize int = 64 - // Delimiter define the topic delimiter - Delimiter string = "##" -) - -// Config is the Config for Ali Log -type Config struct { - Project string `json:"project"` - Endpoint string `json:"endpoint"` - KeyID string `json:"key_id"` - KeySecret string `json:"key_secret"` - LogStore string `json:"log_store"` - Topics []string `json:"topics"` - Source string `json:"source"` - Level int `json:"level"` - FlushWhen int `json:"flush_when"` -} - -// aliLSWriter implements LoggerInterface. -// it writes messages in keep-live tcp connection. -type aliLSWriter struct { - store *LogStore - group []*LogGroup - withMap bool - groupMap map[string]*LogGroup - lock *sync.Mutex - Config -} - -// NewAliLS create a new Logger -func NewAliLS() logs.Logger { - alils := new(aliLSWriter) - alils.Level = logs.LevelTrace - return alils -} - -// Init parse config and init struct -func (c *aliLSWriter) Init(jsonConfig string) (err error) { - - json.Unmarshal([]byte(jsonConfig), c) - - if c.FlushWhen > CacheSize { - c.FlushWhen = CacheSize - } - - prj := &LogProject{ - Name: c.Project, - Endpoint: c.Endpoint, - AccessKeyID: c.KeyID, - AccessKeySecret: c.KeySecret, - } - - c.store, err = prj.GetLogStore(c.LogStore) - if err != nil { - return err - } - - // Create default Log Group - c.group = append(c.group, &LogGroup{ - Topic: proto.String(""), - Source: proto.String(c.Source), - Logs: make([]*Log, 0, c.FlushWhen), - }) - - // Create other Log Group - c.groupMap = make(map[string]*LogGroup) - for _, topic := range c.Topics { - - lg := &LogGroup{ - Topic: proto.String(topic), - Source: proto.String(c.Source), - Logs: make([]*Log, 0, c.FlushWhen), - } - - c.group = append(c.group, lg) - c.groupMap[topic] = lg - } - - if len(c.group) == 1 { - c.withMap = false - } else { - c.withMap = true - } - - c.lock = &sync.Mutex{} - - return nil -} - -// WriteMsg write message in connection. -// if connection is down, try to re-connect. -func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) { - - if level > c.Level { - return nil - } - - var topic string - var content string - var lg *LogGroup - if c.withMap { - - // Topic,LogGroup - strs := strings.SplitN(msg, Delimiter, 2) - if len(strs) == 2 { - pos := strings.LastIndex(strs[0], " ") - topic = strs[0][pos+1 : len(strs[0])] - content = strs[0][0:pos] + strs[1] - lg = c.groupMap[topic] - } - - // send to empty Topic - if lg == nil { - content = msg - lg = c.group[0] - } - } else { - content = msg - lg = c.group[0] - } - - c1 := &LogContent{ - Key: proto.String("msg"), - Value: proto.String(content), - } - - l := &Log{ - Time: proto.Uint32(uint32(when.Unix())), - Contents: []*LogContent{ - c1, - }, - } - - c.lock.Lock() - lg.Logs = append(lg.Logs, l) - c.lock.Unlock() - - if len(lg.Logs) >= c.FlushWhen { - c.flush(lg) - } - - return nil -} - -// Flush implementing method. empty. -func (c *aliLSWriter) Flush() { - - // flush all group - for _, lg := range c.group { - c.flush(lg) - } -} - -// Destroy destroy connection writer and close tcp listener. -func (c *aliLSWriter) Destroy() { -} - -func (c *aliLSWriter) flush(lg *LogGroup) { - - c.lock.Lock() - defer c.lock.Unlock() - err := c.store.PutLogs(lg) - if err != nil { - return - } - - lg.Logs = make([]*Log, 0, c.FlushWhen) -} - -func init() { - logs.Register(logs.AdapterAliLS, NewAliLS) -} diff --git a/logs/alils/config.go b/logs/alils/config.go deleted file mode 100755 index e8c24448..00000000 --- a/logs/alils/config.go +++ /dev/null @@ -1,13 +0,0 @@ -package alils - -const ( - version = "0.5.0" // SDK version - signatureMethod = "hmac-sha1" // Signature method - - // OffsetNewest stands for the log head offset, i.e. the offset that will be - // assigned to the next message that will be produced to the shard. - OffsetNewest = "end" - // OffsetOldest stands for the oldest offset available on the logstore for a - // shard. - OffsetOldest = "begin" -) diff --git a/logs/alils/log.pb.go b/logs/alils/log.pb.go deleted file mode 100755 index 601b0d78..00000000 --- a/logs/alils/log.pb.go +++ /dev/null @@ -1,1038 +0,0 @@ -package alils - -import ( - "fmt" - "io" - "math" - - "github.com/gogo/protobuf/proto" - github_com_gogo_protobuf_proto "github.com/gogo/protobuf/proto" -) - -// Reference imports to suppress errors if they are not otherwise used. -var _ = proto.Marshal -var _ = fmt.Errorf -var _ = math.Inf - -var ( - // ErrInvalidLengthLog invalid proto - ErrInvalidLengthLog = fmt.Errorf("proto: negative length found during unmarshaling") - // ErrIntOverflowLog overflow - ErrIntOverflowLog = fmt.Errorf("proto: integer overflow") -) - -// Log define the proto Log -type Log struct { - Time *uint32 `protobuf:"varint,1,req,name=Time" json:"Time,omitempty"` - Contents []*LogContent `protobuf:"bytes,2,rep,name=Contents" json:"Contents,omitempty"` - XXXUnrecognized []byte `json:"-"` -} - -// Reset the Log -func (m *Log) Reset() { *m = Log{} } - -// String return the Compact Log -func (m *Log) String() string { return proto.CompactTextString(m) } - -// ProtoMessage not implemented -func (*Log) ProtoMessage() {} - -// GetTime return the Log's Time -func (m *Log) GetTime() uint32 { - if m != nil && m.Time != nil { - return *m.Time - } - return 0 -} - -// GetContents return the Log's Contents -func (m *Log) GetContents() []*LogContent { - if m != nil { - return m.Contents - } - return nil -} - -// LogContent define the Log content struct -type LogContent struct { - Key *string `protobuf:"bytes,1,req,name=Key" json:"Key,omitempty"` - Value *string `protobuf:"bytes,2,req,name=Value" json:"Value,omitempty"` - XXXUnrecognized []byte `json:"-"` -} - -// Reset LogContent -func (m *LogContent) Reset() { *m = LogContent{} } - -// String return the compact text -func (m *LogContent) String() string { return proto.CompactTextString(m) } - -// ProtoMessage not implemented -func (*LogContent) ProtoMessage() {} - -// GetKey return the Key -func (m *LogContent) GetKey() string { - if m != nil && m.Key != nil { - return *m.Key - } - return "" -} - -// GetValue return the Value -func (m *LogContent) GetValue() string { - if m != nil && m.Value != nil { - return *m.Value - } - return "" -} - -// LogGroup define the logs struct -type LogGroup struct { - Logs []*Log `protobuf:"bytes,1,rep,name=Logs" json:"Logs,omitempty"` - Reserved *string `protobuf:"bytes,2,opt,name=Reserved" json:"Reserved,omitempty"` - Topic *string `protobuf:"bytes,3,opt,name=Topic" json:"Topic,omitempty"` - Source *string `protobuf:"bytes,4,opt,name=Source" json:"Source,omitempty"` - XXXUnrecognized []byte `json:"-"` -} - -// Reset LogGroup -func (m *LogGroup) Reset() { *m = LogGroup{} } - -// String return the compact text -func (m *LogGroup) String() string { return proto.CompactTextString(m) } - -// ProtoMessage not implemented -func (*LogGroup) ProtoMessage() {} - -// GetLogs return the loggroup logs -func (m *LogGroup) GetLogs() []*Log { - if m != nil { - return m.Logs - } - return nil -} - -// GetReserved return Reserved -func (m *LogGroup) GetReserved() string { - if m != nil && m.Reserved != nil { - return *m.Reserved - } - return "" -} - -// GetTopic return Topic -func (m *LogGroup) GetTopic() string { - if m != nil && m.Topic != nil { - return *m.Topic - } - return "" -} - -// GetSource return Source -func (m *LogGroup) GetSource() string { - if m != nil && m.Source != nil { - return *m.Source - } - return "" -} - -// LogGroupList define the LogGroups -type LogGroupList struct { - LogGroups []*LogGroup `protobuf:"bytes,1,rep,name=logGroups" json:"logGroups,omitempty"` - XXXUnrecognized []byte `json:"-"` -} - -// Reset LogGroupList -func (m *LogGroupList) Reset() { *m = LogGroupList{} } - -// String return compact text -func (m *LogGroupList) String() string { return proto.CompactTextString(m) } - -// ProtoMessage not implemented -func (*LogGroupList) ProtoMessage() {} - -// GetLogGroups return the LogGroups -func (m *LogGroupList) GetLogGroups() []*LogGroup { - if m != nil { - return m.LogGroups - } - return nil -} - -// Marshal the logs to byte slice -func (m *Log) Marshal() (data []byte, err error) { - size := m.Size() - data = make([]byte, size) - n, err := m.MarshalTo(data) - if err != nil { - return nil, err - } - return data[:n], nil -} - -// MarshalTo data -func (m *Log) MarshalTo(data []byte) (int, error) { - var i int - _ = i - var l int - _ = l - if m.Time == nil { - return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") - } - data[i] = 0x8 - i++ - i = encodeVarintLog(data, i, uint64(*m.Time)) - if len(m.Contents) > 0 { - for _, msg := range m.Contents { - data[i] = 0x12 - i++ - i = encodeVarintLog(data, i, uint64(msg.Size())) - n, err := msg.MarshalTo(data[i:]) - if err != nil { - return 0, err - } - i += n - } - } - if m.XXXUnrecognized != nil { - i += copy(data[i:], m.XXXUnrecognized) - } - return i, nil -} - -// Marshal LogContent -func (m *LogContent) Marshal() (data []byte, err error) { - size := m.Size() - data = make([]byte, size) - n, err := m.MarshalTo(data) - if err != nil { - return nil, err - } - return data[:n], nil -} - -// MarshalTo logcontent to data -func (m *LogContent) MarshalTo(data []byte) (int, error) { - var i int - _ = i - var l int - _ = l - if m.Key == nil { - return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") - } - data[i] = 0xa - i++ - i = encodeVarintLog(data, i, uint64(len(*m.Key))) - i += copy(data[i:], *m.Key) - - if m.Value == nil { - return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") - } - data[i] = 0x12 - i++ - i = encodeVarintLog(data, i, uint64(len(*m.Value))) - i += copy(data[i:], *m.Value) - if m.XXXUnrecognized != nil { - i += copy(data[i:], m.XXXUnrecognized) - } - return i, nil -} - -// Marshal LogGroup -func (m *LogGroup) Marshal() (data []byte, err error) { - size := m.Size() - data = make([]byte, size) - n, err := m.MarshalTo(data) - if err != nil { - return nil, err - } - return data[:n], nil -} - -// MarshalTo LogGroup to data -func (m *LogGroup) MarshalTo(data []byte) (int, error) { - var i int - _ = i - var l int - _ = l - if len(m.Logs) > 0 { - for _, msg := range m.Logs { - data[i] = 0xa - i++ - i = encodeVarintLog(data, i, uint64(msg.Size())) - n, err := msg.MarshalTo(data[i:]) - if err != nil { - return 0, err - } - i += n - } - } - if m.Reserved != nil { - data[i] = 0x12 - i++ - i = encodeVarintLog(data, i, uint64(len(*m.Reserved))) - i += copy(data[i:], *m.Reserved) - } - if m.Topic != nil { - data[i] = 0x1a - i++ - i = encodeVarintLog(data, i, uint64(len(*m.Topic))) - i += copy(data[i:], *m.Topic) - } - if m.Source != nil { - data[i] = 0x22 - i++ - i = encodeVarintLog(data, i, uint64(len(*m.Source))) - i += copy(data[i:], *m.Source) - } - if m.XXXUnrecognized != nil { - i += copy(data[i:], m.XXXUnrecognized) - } - return i, nil -} - -// Marshal LogGroupList -func (m *LogGroupList) Marshal() (data []byte, err error) { - size := m.Size() - data = make([]byte, size) - n, err := m.MarshalTo(data) - if err != nil { - return nil, err - } - return data[:n], nil -} - -// MarshalTo LogGroupList to data -func (m *LogGroupList) MarshalTo(data []byte) (int, error) { - var i int - _ = i - var l int - _ = l - if len(m.LogGroups) > 0 { - for _, msg := range m.LogGroups { - data[i] = 0xa - i++ - i = encodeVarintLog(data, i, uint64(msg.Size())) - n, err := msg.MarshalTo(data[i:]) - if err != nil { - return 0, err - } - i += n - } - } - if m.XXXUnrecognized != nil { - i += copy(data[i:], m.XXXUnrecognized) - } - return i, nil -} - -func encodeFixed64Log(data []byte, offset int, v uint64) int { - data[offset] = uint8(v) - data[offset+1] = uint8(v >> 8) - data[offset+2] = uint8(v >> 16) - data[offset+3] = uint8(v >> 24) - data[offset+4] = uint8(v >> 32) - data[offset+5] = uint8(v >> 40) - data[offset+6] = uint8(v >> 48) - data[offset+7] = uint8(v >> 56) - return offset + 8 -} -func encodeFixed32Log(data []byte, offset int, v uint32) int { - data[offset] = uint8(v) - data[offset+1] = uint8(v >> 8) - data[offset+2] = uint8(v >> 16) - data[offset+3] = uint8(v >> 24) - return offset + 4 -} -func encodeVarintLog(data []byte, offset int, v uint64) int { - for v >= 1<<7 { - data[offset] = uint8(v&0x7f | 0x80) - v >>= 7 - offset++ - } - data[offset] = uint8(v) - return offset + 1 -} - -// Size return the log's size -func (m *Log) Size() (n int) { - var l int - _ = l - if m.Time != nil { - n += 1 + sovLog(uint64(*m.Time)) - } - if len(m.Contents) > 0 { - for _, e := range m.Contents { - l = e.Size() - n += 1 + l + sovLog(uint64(l)) - } - } - if m.XXXUnrecognized != nil { - n += len(m.XXXUnrecognized) - } - return n -} - -// Size return LogContent size based on Key and Value -func (m *LogContent) Size() (n int) { - var l int - _ = l - if m.Key != nil { - l = len(*m.Key) - n += 1 + l + sovLog(uint64(l)) - } - if m.Value != nil { - l = len(*m.Value) - n += 1 + l + sovLog(uint64(l)) - } - if m.XXXUnrecognized != nil { - n += len(m.XXXUnrecognized) - } - return n -} - -// Size return LogGroup size based on Logs -func (m *LogGroup) Size() (n int) { - var l int - _ = l - if len(m.Logs) > 0 { - for _, e := range m.Logs { - l = e.Size() - n += 1 + l + sovLog(uint64(l)) - } - } - if m.Reserved != nil { - l = len(*m.Reserved) - n += 1 + l + sovLog(uint64(l)) - } - if m.Topic != nil { - l = len(*m.Topic) - n += 1 + l + sovLog(uint64(l)) - } - if m.Source != nil { - l = len(*m.Source) - n += 1 + l + sovLog(uint64(l)) - } - if m.XXXUnrecognized != nil { - n += len(m.XXXUnrecognized) - } - return n -} - -// Size return LogGroupList size -func (m *LogGroupList) Size() (n int) { - var l int - _ = l - if len(m.LogGroups) > 0 { - for _, e := range m.LogGroups { - l = e.Size() - n += 1 + l + sovLog(uint64(l)) - } - } - if m.XXXUnrecognized != nil { - n += len(m.XXXUnrecognized) - } - return n -} - -func sovLog(x uint64) (n int) { - for { - n++ - x >>= 7 - if x == 0 { - break - } - } - return n -} -func sozLog(x uint64) (n int) { - return sovLog((x << 1) ^ (x >> 63)) -} - -// Unmarshal data to log -func (m *Log) Unmarshal(data []byte) error { - var hasFields [1]uint64 - l := len(data) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: Log: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: Log: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType) - } - var v uint32 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - v |= (uint32(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - m.Time = &v - hasFields[0] |= uint64(0x00000001) - case 2: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Contents", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - msglen |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + msglen - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.Contents = append(m.Contents, &LogContent{}) - if err := m.Contents[len(m.Contents)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipLog(data[iNdEx:]) - if err != nil { - return err - } - if skippy < 0 { - return ErrInvalidLengthLog - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) - iNdEx += skippy - } - } - if hasFields[0]&uint64(0x00000001) == 0 { - return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} - -// Unmarshal data to LogContent -func (m *LogContent) Unmarshal(data []byte) error { - var hasFields [1]uint64 - l := len(data) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: Content: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: Content: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Key", wireType) - } - var stringLen uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - stringLen |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - intStringLen := int(stringLen) - if intStringLen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + intStringLen - if postIndex > l { - return io.ErrUnexpectedEOF - } - s := string(data[iNdEx:postIndex]) - m.Key = &s - iNdEx = postIndex - hasFields[0] |= uint64(0x00000001) - case 2: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Value", wireType) - } - var stringLen uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - stringLen |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - intStringLen := int(stringLen) - if intStringLen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + intStringLen - if postIndex > l { - return io.ErrUnexpectedEOF - } - s := string(data[iNdEx:postIndex]) - m.Value = &s - iNdEx = postIndex - hasFields[0] |= uint64(0x00000002) - default: - iNdEx = preIndex - skippy, err := skipLog(data[iNdEx:]) - if err != nil { - return err - } - if skippy < 0 { - return ErrInvalidLengthLog - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) - iNdEx += skippy - } - } - if hasFields[0]&uint64(0x00000001) == 0 { - return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") - } - if hasFields[0]&uint64(0x00000002) == 0 { - return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} - -// Unmarshal data to LogGroup -func (m *LogGroup) Unmarshal(data []byte) error { - l := len(data) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: LogGroup: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: LogGroup: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Logs", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - msglen |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + msglen - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.Logs = append(m.Logs, &Log{}) - if err := m.Logs[len(m.Logs)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - case 2: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Reserved", wireType) - } - var stringLen uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - stringLen |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - intStringLen := int(stringLen) - if intStringLen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + intStringLen - if postIndex > l { - return io.ErrUnexpectedEOF - } - s := string(data[iNdEx:postIndex]) - m.Reserved = &s - iNdEx = postIndex - case 3: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Topic", wireType) - } - var stringLen uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - stringLen |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - intStringLen := int(stringLen) - if intStringLen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + intStringLen - if postIndex > l { - return io.ErrUnexpectedEOF - } - s := string(data[iNdEx:postIndex]) - m.Topic = &s - iNdEx = postIndex - case 4: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Source", wireType) - } - var stringLen uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - stringLen |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - intStringLen := int(stringLen) - if intStringLen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + intStringLen - if postIndex > l { - return io.ErrUnexpectedEOF - } - s := string(data[iNdEx:postIndex]) - m.Source = &s - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipLog(data[iNdEx:]) - if err != nil { - return err - } - if skippy < 0 { - return ErrInvalidLengthLog - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} - -// Unmarshal data to LogGroupList -func (m *LogGroupList) Unmarshal(data []byte) error { - l := len(data) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: LogGroupList: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: LogGroupList: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field LogGroups", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - msglen |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + msglen - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.LogGroups = append(m.LogGroups, &LogGroup{}) - if err := m.LogGroups[len(m.LogGroups)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipLog(data[iNdEx:]) - if err != nil { - return err - } - if skippy < 0 { - return ErrInvalidLengthLog - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} - -func skipLog(data []byte) (n int, err error) { - l := len(data) - iNdEx := 0 - for iNdEx < l { - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowLog - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - wireType := int(wire & 0x7) - switch wireType { - case 0: - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowLog - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - iNdEx++ - if data[iNdEx-1] < 0x80 { - break - } - } - return iNdEx, nil - case 1: - iNdEx += 8 - return iNdEx, nil - case 2: - var length int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowLog - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - length |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - iNdEx += length - if length < 0 { - return 0, ErrInvalidLengthLog - } - return iNdEx, nil - case 3: - for { - var innerWire uint64 - var start = iNdEx - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowLog - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - innerWire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - innerWireType := int(innerWire & 0x7) - if innerWireType == 4 { - break - } - next, err := skipLog(data[start:]) - if err != nil { - return 0, err - } - iNdEx = start + next - } - return iNdEx, nil - case 4: - return iNdEx, nil - case 5: - iNdEx += 4 - return iNdEx, nil - default: - return 0, fmt.Errorf("proto: illegal wireType %d", wireType) - } - } - panic("unreachable") -} diff --git a/logs/alils/log_config.go b/logs/alils/log_config.go deleted file mode 100755 index e8564efb..00000000 --- a/logs/alils/log_config.go +++ /dev/null @@ -1,42 +0,0 @@ -package alils - -// InputDetail define log detail -type InputDetail struct { - LogType string `json:"logType"` - LogPath string `json:"logPath"` - FilePattern string `json:"filePattern"` - LocalStorage bool `json:"localStorage"` - TimeFormat string `json:"timeFormat"` - LogBeginRegex string `json:"logBeginRegex"` - Regex string `json:"regex"` - Keys []string `json:"key"` - FilterKeys []string `json:"filterKey"` - FilterRegex []string `json:"filterRegex"` - TopicFormat string `json:"topicFormat"` -} - -// OutputDetail define the output detail -type OutputDetail struct { - Endpoint string `json:"endpoint"` - LogStoreName string `json:"logstoreName"` -} - -// LogConfig define Log Config -type LogConfig struct { - Name string `json:"configName"` - InputType string `json:"inputType"` - InputDetail InputDetail `json:"inputDetail"` - OutputType string `json:"outputType"` - OutputDetail OutputDetail `json:"outputDetail"` - - CreateTime uint32 - LastModifyTime uint32 - - project *LogProject -} - -// GetAppliedMachineGroup returns applied machine group of this config. -func (c *LogConfig) GetAppliedMachineGroup(confName string) (groupNames []string, err error) { - groupNames, err = c.project.GetAppliedMachineGroups(c.Name) - return -} diff --git a/logs/alils/log_project.go b/logs/alils/log_project.go deleted file mode 100755 index 59db8cbf..00000000 --- a/logs/alils/log_project.go +++ /dev/null @@ -1,819 +0,0 @@ -/* -Package alils implements the SDK(v0.5.0) of Simple Log Service(abbr. SLS). - -For more description about SLS, please read this article: -http://gitlab.alibaba-inc.com/sls/doc. -*/ -package alils - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "net/http/httputil" -) - -// Error message in SLS HTTP response. -type errorMessage struct { - Code string `json:"errorCode"` - Message string `json:"errorMessage"` -} - -// LogProject Define the Ali Project detail -type LogProject struct { - Name string // Project name - Endpoint string // IP or hostname of SLS endpoint - AccessKeyID string - AccessKeySecret string -} - -// NewLogProject creates a new SLS project. -func NewLogProject(name, endpoint, AccessKeyID, accessKeySecret string) (p *LogProject, err error) { - p = &LogProject{ - Name: name, - Endpoint: endpoint, - AccessKeyID: AccessKeyID, - AccessKeySecret: accessKeySecret, - } - return p, nil -} - -// ListLogStore returns all logstore names of project p. -func (p *LogProject) ListLogStore() (storeNames []string, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/logstores") - r, err := request(p, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to list logstore") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - type Body struct { - Count int - LogStores []string - } - body := &Body{} - - err = json.Unmarshal(buf, body) - if err != nil { - return - } - - storeNames = body.LogStores - - return -} - -// GetLogStore returns logstore according by logstore name. -func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - r, err := request(p, "GET", "/logstores/"+name, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to get logstore") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - s = &LogStore{} - err = json.Unmarshal(buf, s) - if err != nil { - return - } - s.project = p - return -} - -// CreateLogStore creates a new logstore in SLS, -// where name is logstore name, -// and ttl is time-to-live(in day) of logs, -// and shardCnt is the number of shards. -func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) { - - type Body struct { - Name string `json:"logstoreName"` - TTL int `json:"ttl"` - ShardCount int `json:"shardCount"` - } - - store := &Body{ - Name: name, - TTL: ttl, - ShardCount: shardCnt, - } - - body, err := json.Marshal(store) - if err != nil { - return - } - - h := map[string]string{ - "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), - "Content-Type": "application/json", - "Accept-Encoding": "deflate", // TODO: support lz4 - } - - r, err := request(p, "POST", "/logstores", h, body) - if err != nil { - return - } - - body, err = ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to create logstore") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - return -} - -// DeleteLogStore deletes a logstore according by logstore name. -func (p *LogProject) DeleteLogStore(name string) (err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - r, err := request(p, "DELETE", "/logstores/"+name, h, nil) - if err != nil { - return - } - - body, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to delete logstore") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - return -} - -// UpdateLogStore updates a logstore according by logstore name, -// obviously we can't modify the logstore name itself. -func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) { - - type Body struct { - Name string `json:"logstoreName"` - TTL int `json:"ttl"` - ShardCount int `json:"shardCount"` - } - - store := &Body{ - Name: name, - TTL: ttl, - ShardCount: shardCnt, - } - - body, err := json.Marshal(store) - if err != nil { - return - } - - h := map[string]string{ - "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), - "Content-Type": "application/json", - "Accept-Encoding": "deflate", // TODO: support lz4 - } - - r, err := request(p, "PUT", "/logstores", h, body) - if err != nil { - return - } - - body, err = ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to update logstore") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - return -} - -// ListMachineGroup returns machine group name list and the total number of machine groups. -// The offset starts from 0 and the size is the max number of machine groups could be returned. -func (p *LogProject) ListMachineGroup(offset, size int) (m []string, total int, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - if size <= 0 { - size = 500 - } - - uri := fmt.Sprintf("/machinegroups?offset=%v&size=%v", offset, size) - r, err := request(p, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to list machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - type Body struct { - MachineGroups []string - Count int - Total int - } - body := &Body{} - - err = json.Unmarshal(buf, body) - if err != nil { - return - } - - m = body.MachineGroups - total = body.Total - - return -} - -// GetMachineGroup retruns machine group according by machine group name. -func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - r, err := request(p, "GET", "/machinegroups/"+name, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to get machine group:%v", name) - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - m = &MachineGroup{} - err = json.Unmarshal(buf, m) - if err != nil { - return - } - m.project = p - return -} - -// CreateMachineGroup creates a new machine group in SLS. -func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) { - - body, err := json.Marshal(m) - if err != nil { - return - } - - h := map[string]string{ - "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), - "Content-Type": "application/json", - "Accept-Encoding": "deflate", // TODO: support lz4 - } - - r, err := request(p, "POST", "/machinegroups", h, body) - if err != nil { - return - } - - body, err = ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to create machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - return -} - -// UpdateMachineGroup updates a machine group. -func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) { - - body, err := json.Marshal(m) - if err != nil { - return - } - - h := map[string]string{ - "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), - "Content-Type": "application/json", - "Accept-Encoding": "deflate", // TODO: support lz4 - } - - r, err := request(p, "PUT", "/machinegroups/"+m.Name, h, body) - if err != nil { - return - } - - body, err = ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to update machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - return -} - -// DeleteMachineGroup deletes machine group according machine group name. -func (p *LogProject) DeleteMachineGroup(name string) (err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - r, err := request(p, "DELETE", "/machinegroups/"+name, h, nil) - if err != nil { - return - } - - body, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to delete machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - return -} - -// ListConfig returns config names list and the total number of configs. -// The offset starts from 0 and the size is the max number of configs could be returned. -func (p *LogProject) ListConfig(offset, size int) (cfgNames []string, total int, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - if size <= 0 { - size = 100 - } - - uri := fmt.Sprintf("/configs?offset=%v&size=%v", offset, size) - r, err := request(p, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to delete machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - type Body struct { - Total int - Configs []string - } - body := &Body{} - - err = json.Unmarshal(buf, body) - if err != nil { - return - } - - cfgNames = body.Configs - total = body.Total - return -} - -// GetConfig returns config according by config name. -func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - r, err := request(p, "GET", "/configs/"+name, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to delete config") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - c = &LogConfig{} - err = json.Unmarshal(buf, c) - if err != nil { - return - } - c.project = p - return -} - -// UpdateConfig updates a config. -func (p *LogProject) UpdateConfig(c *LogConfig) (err error) { - - body, err := json.Marshal(c) - if err != nil { - return - } - - h := map[string]string{ - "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), - "Content-Type": "application/json", - "Accept-Encoding": "deflate", // TODO: support lz4 - } - - r, err := request(p, "PUT", "/configs/"+c.Name, h, body) - if err != nil { - return - } - - body, err = ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to update config") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - return -} - -// CreateConfig creates a new config in SLS. -func (p *LogProject) CreateConfig(c *LogConfig) (err error) { - - body, err := json.Marshal(c) - if err != nil { - return - } - - h := map[string]string{ - "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), - "Content-Type": "application/json", - "Accept-Encoding": "deflate", // TODO: support lz4 - } - - r, err := request(p, "POST", "/configs", h, body) - if err != nil { - return - } - - body, err = ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to update config") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - return -} - -// DeleteConfig deletes a config according by config name. -func (p *LogProject) DeleteConfig(name string) (err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - r, err := request(p, "DELETE", "/configs/"+name, h, nil) - if err != nil { - return - } - - body, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to delete config") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - return -} - -// GetAppliedMachineGroups returns applied machine group names list according config name. -func (p *LogProject) GetAppliedMachineGroups(confName string) (groupNames []string, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/configs/%v/machinegroups", confName) - r, err := request(p, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to get applied machine groups") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - type Body struct { - Count int - Machinegroups []string - } - - body := &Body{} - err = json.Unmarshal(buf, body) - if err != nil { - return - } - - groupNames = body.Machinegroups - return -} - -// GetAppliedConfigs returns applied config names list according machine group name groupName. -func (p *LogProject) GetAppliedConfigs(groupName string) (confNames []string, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/machinegroups/%v/configs", groupName) - r, err := request(p, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to applied configs") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - type Cfg struct { - Count int `json:"count"` - Configs []string `json:"configs"` - } - - body := &Cfg{} - err = json.Unmarshal(buf, body) - if err != nil { - return - } - - confNames = body.Configs - return -} - -// ApplyConfigToMachineGroup applies config to machine group. -func (p *LogProject) ApplyConfigToMachineGroup(confName, groupName string) (err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName) - r, err := request(p, "PUT", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to apply config to machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - return -} - -// RemoveConfigFromMachineGroup removes config from machine group. -func (p *LogProject) RemoveConfigFromMachineGroup(confName, groupName string) (err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName) - r, err := request(p, "DELETE", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to remove config from machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - return -} diff --git a/logs/alils/log_store.go b/logs/alils/log_store.go deleted file mode 100755 index fa502736..00000000 --- a/logs/alils/log_store.go +++ /dev/null @@ -1,271 +0,0 @@ -package alils - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "net/http/httputil" - "strconv" - - lz4 "github.com/cloudflare/golz4" - "github.com/gogo/protobuf/proto" -) - -// LogStore Store the logs -type LogStore struct { - Name string `json:"logstoreName"` - TTL int - ShardCount int - - CreateTime uint32 - LastModifyTime uint32 - - project *LogProject -} - -// Shard define the Log Shard -type Shard struct { - ShardID int `json:"shardID"` -} - -// ListShards returns shard id list of this logstore. -func (s *LogStore) ListShards() (shardIDs []int, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/logstores/%v/shards", s.Name) - r, err := request(s.project, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to list logstore") - dump, _ := httputil.DumpResponse(r, true) - fmt.Println(dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - var shards []*Shard - err = json.Unmarshal(buf, &shards) - if err != nil { - return - } - - for _, v := range shards { - shardIDs = append(shardIDs, v.ShardID) - } - return -} - -// PutLogs put logs into logstore. -// The callers should transform user logs into LogGroup. -func (s *LogStore) PutLogs(lg *LogGroup) (err error) { - body, err := proto.Marshal(lg) - if err != nil { - return - } - - // Compresse body with lz4 - out := make([]byte, lz4.CompressBound(body)) - n, err := lz4.Compress(body, out) - if err != nil { - return - } - - h := map[string]string{ - "x-sls-compresstype": "lz4", - "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), - "Content-Type": "application/x-protobuf", - } - - uri := fmt.Sprintf("/logstores/%v", s.Name) - r, err := request(s.project, "POST", uri, h, out[:n]) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to put logs") - dump, _ := httputil.DumpResponse(r, true) - fmt.Println(dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - return -} - -// GetCursor gets log cursor of one shard specified by shardID. -// The from can be in three form: a) unix timestamp in seccond, b) "begin", c) "end". -// For more detail please read: http://gitlab.alibaba-inc.com/sls/doc/blob/master/api/shard.md#logstore -func (s *LogStore) GetCursor(shardID int, from string) (cursor string, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/logstores/%v/shards/%v?type=cursor&from=%v", - s.Name, shardID, from) - - r, err := request(s.project, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to get cursor") - dump, _ := httputil.DumpResponse(r, true) - fmt.Println(dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - type Body struct { - Cursor string - } - body := &Body{} - - err = json.Unmarshal(buf, body) - if err != nil { - return - } - cursor = body.Cursor - return -} - -// GetLogsBytes gets logs binary data from shard specified by shardID according cursor. -// The logGroupMaxCount is the max number of logGroup could be returned. -// The nextCursor is the next curosr can be used to read logs at next time. -func (s *LogStore) GetLogsBytes(shardID int, cursor string, - logGroupMaxCount int) (out []byte, nextCursor string, err error) { - - h := map[string]string{ - "x-sls-bodyrawsize": "0", - "Accept": "application/x-protobuf", - "Accept-Encoding": "lz4", - } - - uri := fmt.Sprintf("/logstores/%v/shards/%v?type=logs&cursor=%v&count=%v", - s.Name, shardID, cursor, logGroupMaxCount) - - r, err := request(s.project, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to get cursor") - dump, _ := httputil.DumpResponse(r, true) - fmt.Println(dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - v, ok := r.Header["X-Sls-Compresstype"] - if !ok || len(v) == 0 { - err = fmt.Errorf("can't find 'x-sls-compresstype' header") - return - } - if v[0] != "lz4" { - err = fmt.Errorf("unexpected compress type:%v", v[0]) - return - } - - v, ok = r.Header["X-Sls-Cursor"] - if !ok || len(v) == 0 { - err = fmt.Errorf("can't find 'x-sls-cursor' header") - return - } - nextCursor = v[0] - - v, ok = r.Header["X-Sls-Bodyrawsize"] - if !ok || len(v) == 0 { - err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header") - return - } - bodyRawSize, err := strconv.Atoi(v[0]) - if err != nil { - return - } - - out = make([]byte, bodyRawSize) - err = lz4.Uncompress(buf, out) - if err != nil { - return - } - - return -} - -// LogsBytesDecode decodes logs binary data retruned by GetLogsBytes API -func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) { - - gl = &LogGroupList{} - err = proto.Unmarshal(data, gl) - if err != nil { - return - } - - return -} - -// GetLogs gets logs from shard specified by shardID according cursor. -// The logGroupMaxCount is the max number of logGroup could be returned. -// The nextCursor is the next curosr can be used to read logs at next time. -func (s *LogStore) GetLogs(shardID int, cursor string, - logGroupMaxCount int) (gl *LogGroupList, nextCursor string, err error) { - - out, nextCursor, err := s.GetLogsBytes(shardID, cursor, logGroupMaxCount) - if err != nil { - return - } - - gl, err = LogsBytesDecode(out) - if err != nil { - return - } - - return -} diff --git a/logs/alils/machine_group.go b/logs/alils/machine_group.go deleted file mode 100755 index b6c69a14..00000000 --- a/logs/alils/machine_group.go +++ /dev/null @@ -1,91 +0,0 @@ -package alils - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "net/http/httputil" -) - -// MachineGroupAttribute define the Attribute -type MachineGroupAttribute struct { - ExternalName string `json:"externalName"` - TopicName string `json:"groupTopic"` -} - -// MachineGroup define the machine Group -type MachineGroup struct { - Name string `json:"groupName"` - Type string `json:"groupType"` - MachineIDType string `json:"machineIdentifyType"` - MachineIDList []string `json:"machineList"` - - Attribute MachineGroupAttribute `json:"groupAttribute"` - - CreateTime uint32 - LastModifyTime uint32 - - project *LogProject -} - -// Machine define the Machine -type Machine struct { - IP string - UniqueID string `json:"machine-uniqueid"` - UserdefinedID string `json:"userdefined-id"` -} - -// MachineList define the Machine List -type MachineList struct { - Total int - Machines []*Machine -} - -// ListMachines returns machine list of this machine group. -func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/machinegroups/%v/machines", m.Name) - r, err := request(m.project, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to remove config from machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Println(dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - body := &MachineList{} - err = json.Unmarshal(buf, body) - if err != nil { - return - } - - ms = body.Machines - total = body.Total - - return -} - -// GetAppliedConfigs returns applied configs of this machine group. -func (m *MachineGroup) GetAppliedConfigs() (confNames []string, err error) { - confNames, err = m.project.GetAppliedConfigs(m.Name) - return -} diff --git a/logs/alils/request.go b/logs/alils/request.go deleted file mode 100755 index 50d9c43c..00000000 --- a/logs/alils/request.go +++ /dev/null @@ -1,62 +0,0 @@ -package alils - -import ( - "bytes" - "crypto/md5" - "fmt" - "net/http" -) - -// request sends a request to SLS. -func request(project *LogProject, method, uri string, headers map[string]string, - body []byte) (resp *http.Response, err error) { - - // The caller should provide 'x-sls-bodyrawsize' header - if _, ok := headers["x-sls-bodyrawsize"]; !ok { - err = fmt.Errorf("Can't find 'x-sls-bodyrawsize' header") - return - } - - // SLS public request headers - headers["Host"] = project.Name + "." + project.Endpoint - headers["Date"] = nowRFC1123() - headers["x-sls-apiversion"] = version - headers["x-sls-signaturemethod"] = signatureMethod - if body != nil { - bodyMD5 := fmt.Sprintf("%X", md5.Sum(body)) - headers["Content-MD5"] = bodyMD5 - - if _, ok := headers["Content-Type"]; !ok { - err = fmt.Errorf("Can't find 'Content-Type' header") - return - } - } - - // Calc Authorization - // Authorization = "SLS :" - digest, err := signature(project, method, uri, headers) - if err != nil { - return - } - auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyID, digest) - headers["Authorization"] = auth - - // Initialize http request - reader := bytes.NewReader(body) - urlStr := fmt.Sprintf("http://%v.%v%v", project.Name, project.Endpoint, uri) - req, err := http.NewRequest(method, urlStr, reader) - if err != nil { - return - } - for k, v := range headers { - req.Header.Add(k, v) - } - - // Get ready to do request - resp, err = http.DefaultClient.Do(req) - if err != nil { - return - } - - return -} diff --git a/logs/alils/signature.go b/logs/alils/signature.go deleted file mode 100755 index 2d611307..00000000 --- a/logs/alils/signature.go +++ /dev/null @@ -1,111 +0,0 @@ -package alils - -import ( - "crypto/hmac" - "crypto/sha1" - "encoding/base64" - "fmt" - "net/url" - "sort" - "strings" - "time" -) - -// GMT location -var gmtLoc = time.FixedZone("GMT", 0) - -// NowRFC1123 returns now time in RFC1123 format with GMT timezone, -// eg. "Mon, 02 Jan 2006 15:04:05 GMT". -func nowRFC1123() string { - return time.Now().In(gmtLoc).Format(time.RFC1123) -} - -// signature calculates a request's signature digest. -func signature(project *LogProject, method, uri string, - headers map[string]string) (digest string, err error) { - var contentMD5, contentType, date, canoHeaders, canoResource string - var slsHeaderKeys sort.StringSlice - - // SignString = VERB + "\n" - // + CONTENT-MD5 + "\n" - // + CONTENT-TYPE + "\n" - // + DATE + "\n" - // + CanonicalizedSLSHeaders + "\n" - // + CanonicalizedResource - - if val, ok := headers["Content-MD5"]; ok { - contentMD5 = val - } - - if val, ok := headers["Content-Type"]; ok { - contentType = val - } - - date, ok := headers["Date"] - if !ok { - err = fmt.Errorf("Can't find 'Date' header") - return - } - - // Calc CanonicalizedSLSHeaders - slsHeaders := make(map[string]string, len(headers)) - for k, v := range headers { - l := strings.TrimSpace(strings.ToLower(k)) - if strings.HasPrefix(l, "x-sls-") { - slsHeaders[l] = strings.TrimSpace(v) - slsHeaderKeys = append(slsHeaderKeys, l) - } - } - - sort.Sort(slsHeaderKeys) - for i, k := range slsHeaderKeys { - canoHeaders += k + ":" + slsHeaders[k] - if i+1 < len(slsHeaderKeys) { - canoHeaders += "\n" - } - } - - // Calc CanonicalizedResource - u, err := url.Parse(uri) - if err != nil { - return - } - - canoResource += url.QueryEscape(u.Path) - if u.RawQuery != "" { - var keys sort.StringSlice - - vals := u.Query() - for k := range vals { - keys = append(keys, k) - } - - sort.Sort(keys) - canoResource += "?" - for i, k := range keys { - if i > 0 { - canoResource += "&" - } - - for _, v := range vals[k] { - canoResource += k + "=" + v - } - } - } - - signStr := method + "\n" + - contentMD5 + "\n" + - contentType + "\n" + - date + "\n" + - canoHeaders + "\n" + - canoResource - - // Signature = base64(hmac-sha1(UTF8-Encoding-Of(SignString),AccessKeySecret)) - mac := hmac.New(sha1.New, []byte(project.AccessKeySecret)) - _, err = mac.Write([]byte(signStr)) - if err != nil { - return - } - digest = base64.StdEncoding.EncodeToString(mac.Sum(nil)) - return -} diff --git a/logs/conn.go b/logs/conn.go deleted file mode 100644 index 74c458ab..00000000 --- a/logs/conn.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "encoding/json" - "io" - "net" - "time" -) - -// connWriter implements LoggerInterface. -// it writes messages in keep-live tcp connection. -type connWriter struct { - lg *logWriter - innerWriter io.WriteCloser - ReconnectOnMsg bool `json:"reconnectOnMsg"` - Reconnect bool `json:"reconnect"` - Net string `json:"net"` - Addr string `json:"addr"` - Level int `json:"level"` -} - -// NewConn create new ConnWrite returning as LoggerInterface. -func NewConn() Logger { - conn := new(connWriter) - conn.Level = LevelTrace - return conn -} - -// Init init connection writer with json config. -// json config only need key "level". -func (c *connWriter) Init(jsonConfig string) error { - return json.Unmarshal([]byte(jsonConfig), c) -} - -// WriteMsg write message in connection. -// if connection is down, try to re-connect. -func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > c.Level { - return nil - } - if c.needToConnectOnMsg() { - err := c.connect() - if err != nil { - return err - } - } - - if c.ReconnectOnMsg { - defer c.innerWriter.Close() - } - - _, err := c.lg.writeln(when, msg) - if err != nil { - return err - } - return nil -} - -// Flush implementing method. empty. -func (c *connWriter) Flush() { - -} - -// Destroy destroy connection writer and close tcp listener. -func (c *connWriter) Destroy() { - if c.innerWriter != nil { - c.innerWriter.Close() - } -} - -func (c *connWriter) connect() error { - if c.innerWriter != nil { - c.innerWriter.Close() - c.innerWriter = nil - } - - conn, err := net.Dial(c.Net, c.Addr) - if err != nil { - return err - } - - if tcpConn, ok := conn.(*net.TCPConn); ok { - tcpConn.SetKeepAlive(true) - } - - c.innerWriter = conn - c.lg = newLogWriter(conn) - return nil -} - -func (c *connWriter) needToConnectOnMsg() bool { - if c.Reconnect { - return true - } - - if c.innerWriter == nil { - return true - } - - return c.ReconnectOnMsg -} - -func init() { - Register(AdapterConn, NewConn) -} diff --git a/logs/conn_test.go b/logs/conn_test.go deleted file mode 100644 index 7cfb4d2b..00000000 --- a/logs/conn_test.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "net" - "os" - "testing" -) - -// ConnTCPListener takes a TCP listener and accepts n TCP connections -// Returns connections using connChan -func connTCPListener(t *testing.T, n int, ln net.Listener, connChan chan<- net.Conn) { - - // Listen and accept n incoming connections - for i := 0; i < n; i++ { - conn, err := ln.Accept() - if err != nil { - t.Log("Error accepting connection: ", err.Error()) - os.Exit(1) - } - - // Send accepted connection to channel - connChan <- conn - } - ln.Close() - close(connChan) -} - -func TestConn(t *testing.T) { - log := NewLogger(1000) - log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) - log.Informational("informational") -} - -func TestReconnect(t *testing.T) { - // Setup connection listener - newConns := make(chan net.Conn) - connNum := 2 - ln, err := net.Listen("tcp", ":6002") - if err != nil { - t.Log("Error listening:", err.Error()) - os.Exit(1) - } - go connTCPListener(t, connNum, ln, newConns) - - // Setup logger - log := NewLogger(1000) - log.SetPrefix("test") - log.SetLogger(AdapterConn, `{"net":"tcp","reconnect":true,"level":6,"addr":":6002"}`) - log.Informational("informational 1") - - // Refuse first connection - first := <-newConns - first.Close() - - // Send another log after conn closed - log.Informational("informational 2") - - // Check if there was a second connection attempt - // close this because we moved the codes to pkg/logs - // select { - // case second := <-newConns: - // second.Close() - // default: - // t.Error("Did not reconnect") - // } -} diff --git a/logs/console.go b/logs/console.go deleted file mode 100644 index 3dcaee1d..00000000 --- a/logs/console.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "encoding/json" - "os" - "strings" - "time" - - "github.com/shiena/ansicolor" -) - -// brush is a color join function -type brush func(string) string - -// newBrush return a fix color Brush -func newBrush(color string) brush { - pre := "\033[" - reset := "\033[0m" - return func(text string) string { - return pre + color + "m" + text + reset - } -} - -var colors = []brush{ - newBrush("1;37"), // Emergency white - newBrush("1;36"), // Alert cyan - newBrush("1;35"), // Critical magenta - newBrush("1;31"), // Error red - newBrush("1;33"), // Warning yellow - newBrush("1;32"), // Notice green - newBrush("1;34"), // Informational blue - newBrush("1;44"), // Debug Background blue -} - -// consoleWriter implements LoggerInterface and writes messages to terminal. -type consoleWriter struct { - lg *logWriter - Level int `json:"level"` - Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color -} - -// NewConsole create ConsoleWriter returning as LoggerInterface. -func NewConsole() Logger { - cw := &consoleWriter{ - lg: newLogWriter(ansicolor.NewAnsiColorWriter(os.Stdout)), - Level: LevelDebug, - Colorful: true, - } - return cw -} - -// Init init console logger. -// jsonConfig like '{"level":LevelTrace}'. -func (c *consoleWriter) Init(jsonConfig string) error { - if len(jsonConfig) == 0 { - return nil - } - return json.Unmarshal([]byte(jsonConfig), c) -} - -// WriteMsg write message in console. -func (c *consoleWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > c.Level { - return nil - } - if c.Colorful { - msg = strings.Replace(msg, levelPrefix[level], colors[level](levelPrefix[level]), 1) - } - c.lg.writeln(when, msg) - return nil -} - -// Destroy implementing method. empty. -func (c *consoleWriter) Destroy() { - -} - -// Flush implementing method. empty. -func (c *consoleWriter) Flush() { - -} - -func init() { - Register(AdapterConsole, NewConsole) -} diff --git a/logs/console_test.go b/logs/console_test.go deleted file mode 100644 index 4bc45f57..00000000 --- a/logs/console_test.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "testing" - "time" -) - -// Try each log level in decreasing order of priority. -func testConsoleCalls(bl *BeeLogger) { - bl.Emergency("emergency") - bl.Alert("alert") - bl.Critical("critical") - bl.Error("error") - bl.Warning("warning") - bl.Notice("notice") - bl.Informational("informational") - bl.Debug("debug") -} - -// Test console logging by visually comparing the lines being output with and -// without a log level specification. -func TestConsole(t *testing.T) { - log1 := NewLogger(10000) - log1.EnableFuncCallDepth(true) - log1.SetLogger("console", "") - testConsoleCalls(log1) - - log2 := NewLogger(100) - log2.SetLogger("console", `{"level":3}`) - testConsoleCalls(log2) -} - -// Test console without color -func TestConsoleNoColor(t *testing.T) { - log := NewLogger(100) - log.SetLogger("console", `{"color":false}`) - testConsoleCalls(log) -} - -// Test console async -func TestConsoleAsync(t *testing.T) { - log := NewLogger(100) - log.SetLogger("console") - log.Async() - //log.Close() - testConsoleCalls(log) - for len(log.msgChan) != 0 { - time.Sleep(1 * time.Millisecond) - } -} diff --git a/logs/es/es.go b/logs/es/es.go deleted file mode 100644 index 2b7b1710..00000000 --- a/logs/es/es.go +++ /dev/null @@ -1,102 +0,0 @@ -package es - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/url" - "strings" - "time" - - "github.com/elastic/go-elasticsearch/v6" - "github.com/elastic/go-elasticsearch/v6/esapi" - - "github.com/astaxie/beego/logs" -) - -// NewES return a LoggerInterface -func NewES() logs.Logger { - cw := &esLogger{ - Level: logs.LevelDebug, - } - return cw -} - -// esLogger will log msg into ES -// before you using this implementation, -// please import this package -// usually means that you can import this package in your main package -// for example, anonymous: -// import _ "github.com/astaxie/beego/logs/es" -type esLogger struct { - *elasticsearch.Client - DSN string `json:"dsn"` - Level int `json:"level"` -} - -// {"dsn":"http://localhost:9200/","level":1} -func (el *esLogger) Init(jsonconfig string) error { - err := json.Unmarshal([]byte(jsonconfig), el) - if err != nil { - return err - } - if el.DSN == "" { - return errors.New("empty dsn") - } else if u, err := url.Parse(el.DSN); err != nil { - return err - } else if u.Path == "" { - return errors.New("missing prefix") - } else { - conn, err := elasticsearch.NewClient(elasticsearch.Config{ - Addresses: []string{el.DSN}, - }) - if err != nil { - return err - } - el.Client = conn - } - return nil -} - -// WriteMsg will write the msg and level into es -func (el *esLogger) WriteMsg(when time.Time, msg string, level int) error { - if level > el.Level { - return nil - } - - idx := LogDocument{ - Timestamp: when.Format(time.RFC3339), - Msg: msg, - } - - body, err := json.Marshal(idx) - if err != nil { - return err - } - req := esapi.IndexRequest{ - Index: fmt.Sprintf("%04d.%02d.%02d", when.Year(), when.Month(), when.Day()), - DocumentType: "logs", - Body: strings.NewReader(string(body)), - } - _, err = req.Do(context.Background(), el.Client) - return err -} - -// Destroy is a empty method -func (el *esLogger) Destroy() { -} - -// Flush is a empty method -func (el *esLogger) Flush() { - -} - -type LogDocument struct { - Timestamp string `json:"timestamp"` - Msg string `json:"msg"` -} - -func init() { - logs.Register(logs.AdapterEs, NewES) -} diff --git a/logs/file.go b/logs/file.go deleted file mode 100644 index 40a3572a..00000000 --- a/logs/file.go +++ /dev/null @@ -1,409 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io" - "os" - "path" - "path/filepath" - "strconv" - "strings" - "sync" - "time" -) - -// fileLogWriter implements LoggerInterface. -// It writes messages by lines limit, file size limit, or time frequency. -type fileLogWriter struct { - sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize - // The opened file - Filename string `json:"filename"` - fileWriter *os.File - - // Rotate at line - MaxLines int `json:"maxlines"` - maxLinesCurLines int - - MaxFiles int `json:"maxfiles"` - MaxFilesCurFiles int - - // Rotate at size - MaxSize int `json:"maxsize"` - maxSizeCurSize int - - // Rotate daily - Daily bool `json:"daily"` - MaxDays int64 `json:"maxdays"` - dailyOpenDate int - dailyOpenTime time.Time - - // Rotate hourly - Hourly bool `json:"hourly"` - MaxHours int64 `json:"maxhours"` - hourlyOpenDate int - hourlyOpenTime time.Time - - Rotate bool `json:"rotate"` - - Level int `json:"level"` - - Perm string `json:"perm"` - - RotatePerm string `json:"rotateperm"` - - fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix -} - -// newFileWriter create a FileLogWriter returning as LoggerInterface. -func newFileWriter() Logger { - w := &fileLogWriter{ - Daily: true, - MaxDays: 7, - Hourly: false, - MaxHours: 168, - Rotate: true, - RotatePerm: "0440", - Level: LevelTrace, - Perm: "0660", - MaxLines: 10000000, - MaxFiles: 999, - MaxSize: 1 << 28, - } - return w -} - -// Init file logger with json config. -// jsonConfig like: -// { -// "filename":"logs/beego.log", -// "maxLines":10000, -// "maxsize":1024, -// "daily":true, -// "maxDays":15, -// "rotate":true, -// "perm":"0600" -// } -func (w *fileLogWriter) Init(jsonConfig string) error { - err := json.Unmarshal([]byte(jsonConfig), w) - if err != nil { - return err - } - if len(w.Filename) == 0 { - return errors.New("jsonconfig must have filename") - } - w.suffix = filepath.Ext(w.Filename) - w.fileNameOnly = strings.TrimSuffix(w.Filename, w.suffix) - if w.suffix == "" { - w.suffix = ".log" - } - err = w.startLogger() - return err -} - -// start file logger. create log file and set to locker-inside file writer. -func (w *fileLogWriter) startLogger() error { - file, err := w.createLogFile() - if err != nil { - return err - } - if w.fileWriter != nil { - w.fileWriter.Close() - } - w.fileWriter = file - return w.initFd() -} - -func (w *fileLogWriter) needRotateDaily(size int, day int) bool { - return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || - (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || - (w.Daily && day != w.dailyOpenDate) -} - -func (w *fileLogWriter) needRotateHourly(size int, hour int) bool { - return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || - (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || - (w.Hourly && hour != w.hourlyOpenDate) - -} - -// WriteMsg write logger message into file. -func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > w.Level { - return nil - } - hd, d, h := formatTimeHeader(when) - msg = string(hd) + msg + "\n" - if w.Rotate { - w.RLock() - if w.needRotateHourly(len(msg), h) { - w.RUnlock() - w.Lock() - if w.needRotateHourly(len(msg), h) { - if err := w.doRotate(when); err != nil { - fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) - } - } - w.Unlock() - } else if w.needRotateDaily(len(msg), d) { - w.RUnlock() - w.Lock() - if w.needRotateDaily(len(msg), d) { - if err := w.doRotate(when); err != nil { - fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) - } - } - w.Unlock() - } else { - w.RUnlock() - } - } - - w.Lock() - _, err := w.fileWriter.Write([]byte(msg)) - if err == nil { - w.maxLinesCurLines++ - w.maxSizeCurSize += len(msg) - } - w.Unlock() - return err -} - -func (w *fileLogWriter) createLogFile() (*os.File, error) { - // Open the log file - perm, err := strconv.ParseInt(w.Perm, 8, 64) - if err != nil { - return nil, err - } - - filepath := path.Dir(w.Filename) - os.MkdirAll(filepath, os.FileMode(perm)) - - fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm)) - if err == nil { - // Make sure file perm is user set perm cause of `os.OpenFile` will obey umask - os.Chmod(w.Filename, os.FileMode(perm)) - } - return fd, err -} - -func (w *fileLogWriter) initFd() error { - fd := w.fileWriter - fInfo, err := fd.Stat() - if err != nil { - return fmt.Errorf("get stat err: %s", err) - } - w.maxSizeCurSize = int(fInfo.Size()) - w.dailyOpenTime = time.Now() - w.dailyOpenDate = w.dailyOpenTime.Day() - w.hourlyOpenTime = time.Now() - w.hourlyOpenDate = w.hourlyOpenTime.Hour() - w.maxLinesCurLines = 0 - if w.Hourly { - go w.hourlyRotate(w.hourlyOpenTime) - } else if w.Daily { - go w.dailyRotate(w.dailyOpenTime) - } - if fInfo.Size() > 0 && w.MaxLines > 0 { - count, err := w.lines() - if err != nil { - return err - } - w.maxLinesCurLines = count - } - return nil -} - -func (w *fileLogWriter) dailyRotate(openTime time.Time) { - y, m, d := openTime.Add(24 * time.Hour).Date() - nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location()) - tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) - <-tm.C - w.Lock() - if w.needRotateDaily(0, time.Now().Day()) { - if err := w.doRotate(time.Now()); err != nil { - fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) - } - } - w.Unlock() -} - -func (w *fileLogWriter) hourlyRotate(openTime time.Time) { - y, m, d := openTime.Add(1 * time.Hour).Date() - h, _, _ := openTime.Add(1 * time.Hour).Clock() - nextHour := time.Date(y, m, d, h, 0, 0, 0, openTime.Location()) - tm := time.NewTimer(time.Duration(nextHour.UnixNano() - openTime.UnixNano() + 100)) - <-tm.C - w.Lock() - if w.needRotateHourly(0, time.Now().Hour()) { - if err := w.doRotate(time.Now()); err != nil { - fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) - } - } - w.Unlock() -} - -func (w *fileLogWriter) lines() (int, error) { - fd, err := os.Open(w.Filename) - if err != nil { - return 0, err - } - defer fd.Close() - - buf := make([]byte, 32768) // 32k - count := 0 - lineSep := []byte{'\n'} - - for { - c, err := fd.Read(buf) - if err != nil && err != io.EOF { - return count, err - } - - count += bytes.Count(buf[:c], lineSep) - - if err == io.EOF { - break - } - } - - return count, nil -} - -// DoRotate means it need to write file in new file. -// new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size) -func (w *fileLogWriter) doRotate(logTime time.Time) error { - // file exists - // Find the next available number - num := w.MaxFilesCurFiles + 1 - fName := "" - format := "" - var openTime time.Time - rotatePerm, err := strconv.ParseInt(w.RotatePerm, 8, 64) - if err != nil { - return err - } - - _, err = os.Lstat(w.Filename) - if err != nil { - //even if the file is not exist or other ,we should RESTART the logger - goto RESTART_LOGGER - } - - if w.Hourly { - format = "2006010215" - openTime = w.hourlyOpenTime - } else if w.Daily { - format = "2006-01-02" - openTime = w.dailyOpenTime - } - - // only when one of them be setted, then the file would be splited - if w.MaxLines > 0 || w.MaxSize > 0 { - for ; err == nil && num <= w.MaxFiles; num++ { - fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format(format), num, w.suffix) - _, err = os.Lstat(fName) - } - } else { - fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", openTime.Format(format), num, w.suffix) - _, err = os.Lstat(fName) - w.MaxFilesCurFiles = num - } - - // return error if the last file checked still existed - if err == nil { - return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename) - } - - // close fileWriter before rename - w.fileWriter.Close() - - // Rename the file to its new found name - // even if occurs error,we MUST guarantee to restart new logger - err = os.Rename(w.Filename, fName) - if err != nil { - goto RESTART_LOGGER - } - - err = os.Chmod(fName, os.FileMode(rotatePerm)) - -RESTART_LOGGER: - - startLoggerErr := w.startLogger() - go w.deleteOldLog() - - if startLoggerErr != nil { - return fmt.Errorf("Rotate StartLogger: %s", startLoggerErr) - } - if err != nil { - return fmt.Errorf("Rotate: %s", err) - } - return nil -} - -func (w *fileLogWriter) deleteOldLog() { - dir := filepath.Dir(w.Filename) - absolutePath, err := filepath.EvalSymlinks(w.Filename) - if err == nil { - dir = filepath.Dir(absolutePath) - } - filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) { - defer func() { - if r := recover(); r != nil { - fmt.Fprintf(os.Stderr, "Unable to delete old log '%s', error: %v\n", path, r) - } - }() - - if info == nil { - return - } - if w.Hourly { - if !info.IsDir() && info.ModTime().Add(1*time.Hour*time.Duration(w.MaxHours)).Before(time.Now()) { - if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && - strings.HasSuffix(filepath.Base(path), w.suffix) { - os.Remove(path) - } - } - } else if w.Daily { - if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) { - if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && - strings.HasSuffix(filepath.Base(path), w.suffix) { - os.Remove(path) - } - } - } - return - }) -} - -// Destroy close the file description, close file writer. -func (w *fileLogWriter) Destroy() { - w.fileWriter.Close() -} - -// Flush flush file logger. -// there are no buffering messages in file logger in memory. -// flush file means sync file from disk. -func (w *fileLogWriter) Flush() { - w.fileWriter.Sync() -} - -func init() { - Register(AdapterFile, newFileWriter) -} diff --git a/logs/file_test.go b/logs/file_test.go deleted file mode 100644 index 385eac43..00000000 --- a/logs/file_test.go +++ /dev/null @@ -1,420 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "bufio" - "fmt" - "io/ioutil" - "os" - "strconv" - "testing" - "time" -) - -func TestFilePerm(t *testing.T) { - log := NewLogger(10000) - // use 0666 as test perm cause the default umask is 022 - log.SetLogger("file", `{"filename":"test.log", "perm": "0666"}`) - log.Debug("debug") - log.Informational("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - file, err := os.Stat("test.log") - if err != nil { - t.Fatal(err) - } - if file.Mode() != 0666 { - t.Fatal("unexpected log file permission") - } - os.Remove("test.log") -} - -func TestFile1(t *testing.T) { - log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test.log"}`) - log.Debug("debug") - log.Informational("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - f, err := os.Open("test.log") - if err != nil { - t.Fatal(err) - } - b := bufio.NewReader(f) - lineNum := 0 - for { - line, _, err := b.ReadLine() - if err != nil { - break - } - if len(line) > 0 { - lineNum++ - } - } - var expected = LevelDebug + 1 - if lineNum != expected { - t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") - } - os.Remove("test.log") -} - -func TestFile2(t *testing.T) { - log := NewLogger(10000) - log.SetLogger("file", fmt.Sprintf(`{"filename":"test2.log","level":%d}`, LevelError)) - log.Debug("debug") - log.Info("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - f, err := os.Open("test2.log") - if err != nil { - t.Fatal(err) - } - b := bufio.NewReader(f) - lineNum := 0 - for { - line, _, err := b.ReadLine() - if err != nil { - break - } - if len(line) > 0 { - lineNum++ - } - } - var expected = LevelError + 1 - if lineNum != expected { - t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") - } - os.Remove("test2.log") -} - -func TestFileDailyRotate_01(t *testing.T) { - log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) - log.Debug("debug") - log.Info("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" - b, err := exists(rotateName) - if !b || err != nil { - os.Remove("test3.log") - t.Fatal("rotate not generated") - } - os.Remove(rotateName) - os.Remove("test3.log") -} - -func TestFileDailyRotate_02(t *testing.T) { - fn1 := "rotate_day.log" - fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" - testFileRotate(t, fn1, fn2, true, false) -} - -func TestFileDailyRotate_03(t *testing.T) { - fn1 := "rotate_day.log" - fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" - os.Create(fn) - fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" - testFileRotate(t, fn1, fn2, true, false) - os.Remove(fn) -} - -func TestFileDailyRotate_04(t *testing.T) { - fn1 := "rotate_day.log" - fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" - testFileDailyRotate(t, fn1, fn2) -} - -func TestFileDailyRotate_05(t *testing.T) { - fn1 := "rotate_day.log" - fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" - os.Create(fn) - fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" - testFileDailyRotate(t, fn1, fn2) - os.Remove(fn) -} -func TestFileDailyRotate_06(t *testing.T) { //test file mode - log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) - log.Debug("debug") - log.Info("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" - s, _ := os.Lstat(rotateName) - if s.Mode() != 0440 { - os.Remove(rotateName) - os.Remove("test3.log") - t.Fatal("rotate file mode error") - } - os.Remove(rotateName) - os.Remove("test3.log") -} - -func TestFileHourlyRotate_01(t *testing.T) { - log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) - log.Debug("debug") - log.Info("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log" - b, err := exists(rotateName) - if !b || err != nil { - os.Remove("test3.log") - t.Fatal("rotate not generated") - } - os.Remove(rotateName) - os.Remove("test3.log") -} - -func TestFileHourlyRotate_02(t *testing.T) { - fn1 := "rotate_hour.log" - fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" - testFileRotate(t, fn1, fn2, false, true) -} - -func TestFileHourlyRotate_03(t *testing.T) { - fn1 := "rotate_hour.log" - fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log" - os.Create(fn) - fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" - testFileRotate(t, fn1, fn2, false, true) - os.Remove(fn) -} - -func TestFileHourlyRotate_04(t *testing.T) { - fn1 := "rotate_hour.log" - fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" - testFileHourlyRotate(t, fn1, fn2) -} - -func TestFileHourlyRotate_05(t *testing.T) { - fn1 := "rotate_hour.log" - fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log" - os.Create(fn) - fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" - testFileHourlyRotate(t, fn1, fn2) - os.Remove(fn) -} - -func TestFileHourlyRotate_06(t *testing.T) { //test file mode - log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) - log.Debug("debug") - log.Info("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log" - s, _ := os.Lstat(rotateName) - if s.Mode() != 0440 { - os.Remove(rotateName) - os.Remove("test3.log") - t.Fatal("rotate file mode error") - } - os.Remove(rotateName) - os.Remove("test3.log") -} - -func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { - fw := &fileLogWriter{ - Daily: daily, - MaxDays: 7, - Hourly: hourly, - MaxHours: 168, - Rotate: true, - Level: LevelTrace, - Perm: "0660", - RotatePerm: "0440", - } - - if daily { - fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) - fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) - fw.dailyOpenDate = fw.dailyOpenTime.Day() - } - - if hourly { - fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) - fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) - fw.hourlyOpenDate = fw.hourlyOpenTime.Day() - } - - fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) - - for _, file := range []string{fn1, fn2} { - _, err := os.Stat(file) - if err != nil { - t.Log(err) - t.FailNow() - } - os.Remove(file) - } - fw.Destroy() -} - -func testFileDailyRotate(t *testing.T, fn1, fn2 string) { - fw := &fileLogWriter{ - Daily: true, - MaxDays: 7, - Rotate: true, - Level: LevelTrace, - Perm: "0660", - RotatePerm: "0440", - } - fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) - fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) - fw.dailyOpenDate = fw.dailyOpenTime.Day() - today, _ := time.ParseInLocation("2006-01-02", time.Now().Format("2006-01-02"), fw.dailyOpenTime.Location()) - today = today.Add(-1 * time.Second) - fw.dailyRotate(today) - for _, file := range []string{fn1, fn2} { - _, err := os.Stat(file) - if err != nil { - t.FailNow() - } - content, err := ioutil.ReadFile(file) - if err != nil { - t.FailNow() - } - if len(content) > 0 { - t.FailNow() - } - os.Remove(file) - } - fw.Destroy() -} - -func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { - fw := &fileLogWriter{ - Hourly: true, - MaxHours: 168, - Rotate: true, - Level: LevelTrace, - Perm: "0660", - RotatePerm: "0440", - } - fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) - fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) - fw.hourlyOpenDate = fw.hourlyOpenTime.Hour() - hour, _ := time.ParseInLocation("2006010215", time.Now().Format("2006010215"), fw.hourlyOpenTime.Location()) - hour = hour.Add(-1 * time.Second) - fw.hourlyRotate(hour) - for _, file := range []string{fn1, fn2} { - _, err := os.Stat(file) - if err != nil { - t.FailNow() - } - content, err := ioutil.ReadFile(file) - if err != nil { - t.FailNow() - } - if len(content) > 0 { - t.FailNow() - } - os.Remove(file) - } - fw.Destroy() -} -func exists(path string) (bool, error) { - _, err := os.Stat(path) - if err == nil { - return true, nil - } - if os.IsNotExist(err) { - return false, nil - } - return false, err -} - -func BenchmarkFile(b *testing.B) { - log := NewLogger(100000) - log.SetLogger("file", `{"filename":"test4.log"}`) - for i := 0; i < b.N; i++ { - log.Debug("debug") - } - os.Remove("test4.log") -} - -func BenchmarkFileAsynchronous(b *testing.B) { - log := NewLogger(100000) - log.SetLogger("file", `{"filename":"test4.log"}`) - log.Async() - for i := 0; i < b.N; i++ { - log.Debug("debug") - } - os.Remove("test4.log") -} - -func BenchmarkFileCallDepth(b *testing.B) { - log := NewLogger(100000) - log.SetLogger("file", `{"filename":"test4.log"}`) - log.EnableFuncCallDepth(true) - log.SetLogFuncCallDepth(2) - for i := 0; i < b.N; i++ { - log.Debug("debug") - } - os.Remove("test4.log") -} - -func BenchmarkFileAsynchronousCallDepth(b *testing.B) { - log := NewLogger(100000) - log.SetLogger("file", `{"filename":"test4.log"}`) - log.EnableFuncCallDepth(true) - log.SetLogFuncCallDepth(2) - log.Async() - for i := 0; i < b.N; i++ { - log.Debug("debug") - } - os.Remove("test4.log") -} - -func BenchmarkFileOnGoroutine(b *testing.B) { - log := NewLogger(100000) - log.SetLogger("file", `{"filename":"test4.log"}`) - for i := 0; i < b.N; i++ { - go log.Debug("debug") - } - os.Remove("test4.log") -} diff --git a/logs/jianliao.go b/logs/jianliao.go deleted file mode 100644 index 88ba0f9a..00000000 --- a/logs/jianliao.go +++ /dev/null @@ -1,72 +0,0 @@ -package logs - -import ( - "encoding/json" - "fmt" - "net/http" - "net/url" - "time" -) - -// JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook -type JLWriter struct { - AuthorName string `json:"authorname"` - Title string `json:"title"` - WebhookURL string `json:"webhookurl"` - RedirectURL string `json:"redirecturl,omitempty"` - ImageURL string `json:"imageurl,omitempty"` - Level int `json:"level"` -} - -// newJLWriter create jiaoliao writer. -func newJLWriter() Logger { - return &JLWriter{Level: LevelTrace} -} - -// Init JLWriter with json config string -func (s *JLWriter) Init(jsonconfig string) error { - return json.Unmarshal([]byte(jsonconfig), s) -} - -// WriteMsg write message in smtp writer. -// it will send an email with subject and only this message. -func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > s.Level { - return nil - } - - text := fmt.Sprintf("%s %s", when.Format("2006-01-02 15:04:05"), msg) - - form := url.Values{} - form.Add("authorName", s.AuthorName) - form.Add("title", s.Title) - form.Add("text", text) - if s.RedirectURL != "" { - form.Add("redirectUrl", s.RedirectURL) - } - if s.ImageURL != "" { - form.Add("imageUrl", s.ImageURL) - } - - resp, err := http.PostForm(s.WebhookURL, form) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) - } - return nil -} - -// Flush implementing method. empty. -func (s *JLWriter) Flush() { -} - -// Destroy implementing method. empty. -func (s *JLWriter) Destroy() { -} - -func init() { - Register(AdapterJianLiao, newJLWriter) -} diff --git a/logs/log.go b/logs/log.go deleted file mode 100644 index 39c006d2..00000000 --- a/logs/log.go +++ /dev/null @@ -1,669 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package logs provide a general log interface -// Usage: -// -// import "github.com/astaxie/beego/logs" -// -// log := NewLogger(10000) -// log.SetLogger("console", "") -// -// > the first params stand for how many channel -// -// Use it like this: -// -// log.Trace("trace") -// log.Info("info") -// log.Warn("warning") -// log.Debug("debug") -// log.Critical("critical") -// -// more docs http://beego.me/docs/module/logs.md -package logs - -import ( - "fmt" - "log" - "os" - "path" - "runtime" - "strconv" - "strings" - "sync" - "time" -) - -// RFC5424 log message levels. -const ( - LevelEmergency = iota - LevelAlert - LevelCritical - LevelError - LevelWarning - LevelNotice - LevelInformational - LevelDebug -) - -// levelLogLogger is defined to implement log.Logger -// the real log level will be LevelEmergency -const levelLoggerImpl = -1 - -// Name for adapter with beego official support -const ( - AdapterConsole = "console" - AdapterFile = "file" - AdapterMultiFile = "multifile" - AdapterMail = "smtp" - AdapterConn = "conn" - AdapterEs = "es" - AdapterJianLiao = "jianliao" - AdapterSlack = "slack" - AdapterAliLS = "alils" -) - -// Legacy log level constants to ensure backwards compatibility. -const ( - LevelInfo = LevelInformational - LevelTrace = LevelDebug - LevelWarn = LevelWarning -) - -type newLoggerFunc func() Logger - -// Logger defines the behavior of a log provider. -type Logger interface { - Init(config string) error - WriteMsg(when time.Time, msg string, level int) error - Destroy() - Flush() -} - -var adapters = make(map[string]newLoggerFunc) -var levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"} - -// Register makes a log provide available by the provided name. -// If Register is called twice with the same name or if driver is nil, -// it panics. -func Register(name string, log newLoggerFunc) { - if log == nil { - panic("logs: Register provide is nil") - } - if _, dup := adapters[name]; dup { - panic("logs: Register called twice for provider " + name) - } - adapters[name] = log -} - -// BeeLogger is default logger in beego application. -// it can contain several providers and log message into all providers. -type BeeLogger struct { - lock sync.Mutex - level int - init bool - enableFuncCallDepth bool - loggerFuncCallDepth int - asynchronous bool - prefix string - msgChanLen int64 - msgChan chan *logMsg - signalChan chan string - wg sync.WaitGroup - outputs []*nameLogger -} - -const defaultAsyncMsgLen = 1e3 - -type nameLogger struct { - Logger - name string -} - -type logMsg struct { - level int - msg string - when time.Time -} - -var logMsgPool *sync.Pool - -// NewLogger returns a new BeeLogger. -// channelLen means the number of messages in chan(used where asynchronous is true). -// if the buffering chan is full, logger adapters write to file or other way. -func NewLogger(channelLens ...int64) *BeeLogger { - bl := new(BeeLogger) - bl.level = LevelDebug - bl.loggerFuncCallDepth = 2 - bl.msgChanLen = append(channelLens, 0)[0] - if bl.msgChanLen <= 0 { - bl.msgChanLen = defaultAsyncMsgLen - } - bl.signalChan = make(chan string, 1) - bl.setLogger(AdapterConsole) - return bl -} - -// Async set the log to asynchronous and start the goroutine -func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { - bl.lock.Lock() - defer bl.lock.Unlock() - if bl.asynchronous { - return bl - } - bl.asynchronous = true - if len(msgLen) > 0 && msgLen[0] > 0 { - bl.msgChanLen = msgLen[0] - } - bl.msgChan = make(chan *logMsg, bl.msgChanLen) - logMsgPool = &sync.Pool{ - New: func() interface{} { - return &logMsg{} - }, - } - bl.wg.Add(1) - go bl.startLogger() - return bl -} - -// SetLogger provides a given logger adapter into BeeLogger with config string. -// config need to be correct JSON as string: {"interval":360}. -func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { - config := append(configs, "{}")[0] - for _, l := range bl.outputs { - if l.name == adapterName { - return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName) - } - } - - logAdapter, ok := adapters[adapterName] - if !ok { - return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) - } - - lg := logAdapter() - err := lg.Init(config) - if err != nil { - fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) - return err - } - bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg}) - return nil -} - -// SetLogger provides a given logger adapter into BeeLogger with config string. -// config need to be correct JSON as string: {"interval":360}. -func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error { - bl.lock.Lock() - defer bl.lock.Unlock() - if !bl.init { - bl.outputs = []*nameLogger{} - bl.init = true - } - return bl.setLogger(adapterName, configs...) -} - -// DelLogger remove a logger adapter in BeeLogger. -func (bl *BeeLogger) DelLogger(adapterName string) error { - bl.lock.Lock() - defer bl.lock.Unlock() - outputs := []*nameLogger{} - for _, lg := range bl.outputs { - if lg.name == adapterName { - lg.Destroy() - } else { - outputs = append(outputs, lg) - } - } - if len(outputs) == len(bl.outputs) { - return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) - } - bl.outputs = outputs - return nil -} - -func (bl *BeeLogger) writeToLoggers(when time.Time, msg string, level int) { - for _, l := range bl.outputs { - err := l.WriteMsg(when, msg, level) - if err != nil { - fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err) - } - } -} - -func (bl *BeeLogger) Write(p []byte) (n int, err error) { - if len(p) == 0 { - return 0, nil - } - // writeMsg will always add a '\n' character - if p[len(p)-1] == '\n' { - p = p[0 : len(p)-1] - } - // set levelLoggerImpl to ensure all log message will be write out - err = bl.writeMsg(levelLoggerImpl, string(p)) - if err == nil { - return len(p), err - } - return 0, err -} - -func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error { - if !bl.init { - bl.lock.Lock() - bl.setLogger(AdapterConsole) - bl.lock.Unlock() - } - - if len(v) > 0 { - msg = fmt.Sprintf(msg, v...) - } - - msg = bl.prefix + " " + msg - - when := time.Now() - if bl.enableFuncCallDepth { - _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) - if !ok { - file = "???" - line = 0 - } - _, filename := path.Split(file) - msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + msg - } - - //set level info in front of filename info - if logLevel == levelLoggerImpl { - // set to emergency to ensure all log will be print out correctly - logLevel = LevelEmergency - } else { - msg = levelPrefix[logLevel] + " " + msg - } - - if bl.asynchronous { - lm := logMsgPool.Get().(*logMsg) - lm.level = logLevel - lm.msg = msg - lm.when = when - if bl.outputs != nil { - bl.msgChan <- lm - } else { - logMsgPool.Put(lm) - } - } else { - bl.writeToLoggers(when, msg, logLevel) - } - return nil -} - -// SetLevel Set log message level. -// If message level (such as LevelDebug) is higher than logger level (such as LevelWarning), -// log providers will not even be sent the message. -func (bl *BeeLogger) SetLevel(l int) { - bl.level = l -} - -// GetLevel Get Current log message level. -func (bl *BeeLogger) GetLevel() int { - return bl.level -} - -// SetLogFuncCallDepth set log funcCallDepth -func (bl *BeeLogger) SetLogFuncCallDepth(d int) { - bl.loggerFuncCallDepth = d -} - -// GetLogFuncCallDepth return log funcCallDepth for wrapper -func (bl *BeeLogger) GetLogFuncCallDepth() int { - return bl.loggerFuncCallDepth -} - -// EnableFuncCallDepth enable log funcCallDepth -func (bl *BeeLogger) EnableFuncCallDepth(b bool) { - bl.enableFuncCallDepth = b -} - -// set prefix -func (bl *BeeLogger) SetPrefix(s string) { - bl.prefix = s -} - -// start logger chan reading. -// when chan is not empty, write logs. -func (bl *BeeLogger) startLogger() { - gameOver := false - for { - select { - case bm := <-bl.msgChan: - bl.writeToLoggers(bm.when, bm.msg, bm.level) - logMsgPool.Put(bm) - case sg := <-bl.signalChan: - // Now should only send "flush" or "close" to bl.signalChan - bl.flush() - if sg == "close" { - for _, l := range bl.outputs { - l.Destroy() - } - bl.outputs = nil - gameOver = true - } - bl.wg.Done() - } - if gameOver { - break - } - } -} - -// Emergency Log EMERGENCY level message. -func (bl *BeeLogger) Emergency(format string, v ...interface{}) { - if LevelEmergency > bl.level { - return - } - bl.writeMsg(LevelEmergency, format, v...) -} - -// Alert Log ALERT level message. -func (bl *BeeLogger) Alert(format string, v ...interface{}) { - if LevelAlert > bl.level { - return - } - bl.writeMsg(LevelAlert, format, v...) -} - -// Critical Log CRITICAL level message. -func (bl *BeeLogger) Critical(format string, v ...interface{}) { - if LevelCritical > bl.level { - return - } - bl.writeMsg(LevelCritical, format, v...) -} - -// Error Log ERROR level message. -func (bl *BeeLogger) Error(format string, v ...interface{}) { - if LevelError > bl.level { - return - } - bl.writeMsg(LevelError, format, v...) -} - -// Warning Log WARNING level message. -func (bl *BeeLogger) Warning(format string, v ...interface{}) { - if LevelWarn > bl.level { - return - } - bl.writeMsg(LevelWarn, format, v...) -} - -// Notice Log NOTICE level message. -func (bl *BeeLogger) Notice(format string, v ...interface{}) { - if LevelNotice > bl.level { - return - } - bl.writeMsg(LevelNotice, format, v...) -} - -// Informational Log INFORMATIONAL level message. -func (bl *BeeLogger) Informational(format string, v ...interface{}) { - if LevelInfo > bl.level { - return - } - bl.writeMsg(LevelInfo, format, v...) -} - -// Debug Log DEBUG level message. -func (bl *BeeLogger) Debug(format string, v ...interface{}) { - if LevelDebug > bl.level { - return - } - bl.writeMsg(LevelDebug, format, v...) -} - -// Warn Log WARN level message. -// compatibility alias for Warning() -func (bl *BeeLogger) Warn(format string, v ...interface{}) { - if LevelWarn > bl.level { - return - } - bl.writeMsg(LevelWarn, format, v...) -} - -// Info Log INFO level message. -// compatibility alias for Informational() -func (bl *BeeLogger) Info(format string, v ...interface{}) { - if LevelInfo > bl.level { - return - } - bl.writeMsg(LevelInfo, format, v...) -} - -// Trace Log TRACE level message. -// compatibility alias for Debug() -func (bl *BeeLogger) Trace(format string, v ...interface{}) { - if LevelDebug > bl.level { - return - } - bl.writeMsg(LevelDebug, format, v...) -} - -// Flush flush all chan data. -func (bl *BeeLogger) Flush() { - if bl.asynchronous { - bl.signalChan <- "flush" - bl.wg.Wait() - bl.wg.Add(1) - return - } - bl.flush() -} - -// Close close logger, flush all chan data and destroy all adapters in BeeLogger. -func (bl *BeeLogger) Close() { - if bl.asynchronous { - bl.signalChan <- "close" - bl.wg.Wait() - close(bl.msgChan) - } else { - bl.flush() - for _, l := range bl.outputs { - l.Destroy() - } - bl.outputs = nil - } - close(bl.signalChan) -} - -// Reset close all outputs, and set bl.outputs to nil -func (bl *BeeLogger) Reset() { - bl.Flush() - for _, l := range bl.outputs { - l.Destroy() - } - bl.outputs = nil -} - -func (bl *BeeLogger) flush() { - if bl.asynchronous { - for { - if len(bl.msgChan) > 0 { - bm := <-bl.msgChan - bl.writeToLoggers(bm.when, bm.msg, bm.level) - logMsgPool.Put(bm) - continue - } - break - } - } - for _, l := range bl.outputs { - l.Flush() - } -} - -// beeLogger references the used application logger. -var beeLogger = NewLogger() - -// GetBeeLogger returns the default BeeLogger -func GetBeeLogger() *BeeLogger { - return beeLogger -} - -var beeLoggerMap = struct { - sync.RWMutex - logs map[string]*log.Logger -}{ - logs: map[string]*log.Logger{}, -} - -// GetLogger returns the default BeeLogger -func GetLogger(prefixes ...string) *log.Logger { - prefix := append(prefixes, "")[0] - if prefix != "" { - prefix = fmt.Sprintf(`[%s] `, strings.ToUpper(prefix)) - } - beeLoggerMap.RLock() - l, ok := beeLoggerMap.logs[prefix] - if ok { - beeLoggerMap.RUnlock() - return l - } - beeLoggerMap.RUnlock() - beeLoggerMap.Lock() - defer beeLoggerMap.Unlock() - l, ok = beeLoggerMap.logs[prefix] - if !ok { - l = log.New(beeLogger, prefix, 0) - beeLoggerMap.logs[prefix] = l - } - return l -} - -// Reset will remove all the adapter -func Reset() { - beeLogger.Reset() -} - -// Async set the beelogger with Async mode and hold msglen messages -func Async(msgLen ...int64) *BeeLogger { - return beeLogger.Async(msgLen...) -} - -// SetLevel sets the global log level used by the simple logger. -func SetLevel(l int) { - beeLogger.SetLevel(l) -} - -// SetPrefix sets the prefix -func SetPrefix(s string) { - beeLogger.SetPrefix(s) -} - -// EnableFuncCallDepth enable log funcCallDepth -func EnableFuncCallDepth(b bool) { - beeLogger.enableFuncCallDepth = b -} - -// SetLogFuncCall set the CallDepth, default is 4 -func SetLogFuncCall(b bool) { - beeLogger.EnableFuncCallDepth(b) - beeLogger.SetLogFuncCallDepth(4) -} - -// SetLogFuncCallDepth set log funcCallDepth -func SetLogFuncCallDepth(d int) { - beeLogger.loggerFuncCallDepth = d -} - -// SetLogger sets a new logger. -func SetLogger(adapter string, config ...string) error { - return beeLogger.SetLogger(adapter, config...) -} - -// Emergency logs a message at emergency level. -func Emergency(f interface{}, v ...interface{}) { - beeLogger.Emergency(formatLog(f, v...)) -} - -// Alert logs a message at alert level. -func Alert(f interface{}, v ...interface{}) { - beeLogger.Alert(formatLog(f, v...)) -} - -// Critical logs a message at critical level. -func Critical(f interface{}, v ...interface{}) { - beeLogger.Critical(formatLog(f, v...)) -} - -// Error logs a message at error level. -func Error(f interface{}, v ...interface{}) { - beeLogger.Error(formatLog(f, v...)) -} - -// Warning logs a message at warning level. -func Warning(f interface{}, v ...interface{}) { - beeLogger.Warn(formatLog(f, v...)) -} - -// Warn compatibility alias for Warning() -func Warn(f interface{}, v ...interface{}) { - beeLogger.Warn(formatLog(f, v...)) -} - -// Notice logs a message at notice level. -func Notice(f interface{}, v ...interface{}) { - beeLogger.Notice(formatLog(f, v...)) -} - -// Informational logs a message at info level. -func Informational(f interface{}, v ...interface{}) { - beeLogger.Info(formatLog(f, v...)) -} - -// Info compatibility alias for Warning() -func Info(f interface{}, v ...interface{}) { - beeLogger.Info(formatLog(f, v...)) -} - -// Debug logs a message at debug level. -func Debug(f interface{}, v ...interface{}) { - beeLogger.Debug(formatLog(f, v...)) -} - -// Trace logs a message at trace level. -// compatibility alias for Warning() -func Trace(f interface{}, v ...interface{}) { - beeLogger.Trace(formatLog(f, v...)) -} - -func formatLog(f interface{}, v ...interface{}) string { - var msg string - switch f.(type) { - case string: - msg = f.(string) - if len(v) == 0 { - return msg - } - if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") { - //format string - } else { - //do not contain format char - msg += strings.Repeat(" %v", len(v)) - } - default: - msg = fmt.Sprint(f) - if len(v) == 0 { - return msg - } - msg += strings.Repeat(" %v", len(v)) - } - return fmt.Sprintf(msg, v...) -} diff --git a/logs/logger.go b/logs/logger.go deleted file mode 100644 index a28bff6f..00000000 --- a/logs/logger.go +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "io" - "runtime" - "sync" - "time" -) - -type logWriter struct { - sync.Mutex - writer io.Writer -} - -func newLogWriter(wr io.Writer) *logWriter { - return &logWriter{writer: wr} -} - -func (lg *logWriter) writeln(when time.Time, msg string) (int, error) { - lg.Lock() - h, _, _ := formatTimeHeader(when) - n, err := lg.writer.Write(append(append(h, msg...), '\n')) - lg.Unlock() - return n, err -} - -const ( - y1 = `0123456789` - y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789` - y3 = `0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999` - y4 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789` - mo1 = `000000000111` - mo2 = `123456789012` - d1 = `0000000001111111111222222222233` - d2 = `1234567890123456789012345678901` - h1 = `000000000011111111112222` - h2 = `012345678901234567890123` - mi1 = `000000000011111111112222222222333333333344444444445555555555` - mi2 = `012345678901234567890123456789012345678901234567890123456789` - s1 = `000000000011111111112222222222333333333344444444445555555555` - s2 = `012345678901234567890123456789012345678901234567890123456789` - ns1 = `0123456789` -) - -func formatTimeHeader(when time.Time) ([]byte, int, int) { - y, mo, d := when.Date() - h, mi, s := when.Clock() - ns := when.Nanosecond() / 1000000 - //len("2006/01/02 15:04:05.123 ")==24 - var buf [24]byte - - buf[0] = y1[y/1000%10] - buf[1] = y2[y/100] - buf[2] = y3[y-y/100*100] - buf[3] = y4[y-y/100*100] - buf[4] = '/' - buf[5] = mo1[mo-1] - buf[6] = mo2[mo-1] - buf[7] = '/' - buf[8] = d1[d-1] - buf[9] = d2[d-1] - buf[10] = ' ' - buf[11] = h1[h] - buf[12] = h2[h] - buf[13] = ':' - buf[14] = mi1[mi] - buf[15] = mi2[mi] - buf[16] = ':' - buf[17] = s1[s] - buf[18] = s2[s] - buf[19] = '.' - buf[20] = ns1[ns/100] - buf[21] = ns1[ns%100/10] - buf[22] = ns1[ns%10] - - buf[23] = ' ' - - return buf[0:], d, h -} - -var ( - green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109}) - white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109}) - yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109}) - red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109}) - blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109}) - magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109}) - cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109}) - - w32Green = string([]byte{27, 91, 52, 50, 109}) - w32White = string([]byte{27, 91, 52, 55, 109}) - w32Yellow = string([]byte{27, 91, 52, 51, 109}) - w32Red = string([]byte{27, 91, 52, 49, 109}) - w32Blue = string([]byte{27, 91, 52, 52, 109}) - w32Magenta = string([]byte{27, 91, 52, 53, 109}) - w32Cyan = string([]byte{27, 91, 52, 54, 109}) - - reset = string([]byte{27, 91, 48, 109}) -) - -var once sync.Once -var colorMap map[string]string - -func initColor() { - if runtime.GOOS == "windows" { - green = w32Green - white = w32White - yellow = w32Yellow - red = w32Red - blue = w32Blue - magenta = w32Magenta - cyan = w32Cyan - } - colorMap = map[string]string{ - //by color - "green": green, - "white": white, - "yellow": yellow, - "red": red, - //by method - "GET": blue, - "POST": cyan, - "PUT": yellow, - "DELETE": red, - "PATCH": green, - "HEAD": magenta, - "OPTIONS": white, - } -} - -// ColorByStatus return color by http code -// 2xx return Green -// 3xx return White -// 4xx return Yellow -// 5xx return Red -func ColorByStatus(code int) string { - once.Do(initColor) - switch { - case code >= 200 && code < 300: - return colorMap["green"] - case code >= 300 && code < 400: - return colorMap["white"] - case code >= 400 && code < 500: - return colorMap["yellow"] - default: - return colorMap["red"] - } -} - -// ColorByMethod return color by http code -func ColorByMethod(method string) string { - once.Do(initColor) - if c := colorMap[method]; c != "" { - return c - } - return reset -} - -// ResetColor return reset color -func ResetColor() string { - return reset -} diff --git a/logs/logger_test.go b/logs/logger_test.go deleted file mode 100644 index 15be500d..00000000 --- a/logs/logger_test.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2016 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "testing" - "time" -) - -func TestFormatHeader_0(t *testing.T) { - tm := time.Now() - if tm.Year() >= 2100 { - t.FailNow() - } - dur := time.Second - for { - if tm.Year() >= 2100 { - break - } - h, _, _ := formatTimeHeader(tm) - if tm.Format("2006/01/02 15:04:05.000 ") != string(h) { - t.Log(tm) - t.FailNow() - } - tm = tm.Add(dur) - dur *= 2 - } -} - -func TestFormatHeader_1(t *testing.T) { - tm := time.Now() - year := tm.Year() - dur := time.Second - for { - if tm.Year() >= year+1 { - break - } - h, _, _ := formatTimeHeader(tm) - if tm.Format("2006/01/02 15:04:05.000 ") != string(h) { - t.Log(tm) - t.FailNow() - } - tm = tm.Add(dur) - } -} diff --git a/logs/multifile.go b/logs/multifile.go deleted file mode 100644 index 90168274..00000000 --- a/logs/multifile.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "encoding/json" - "time" -) - -// A filesLogWriter manages several fileLogWriter -// filesLogWriter will write logs to the file in json configuration and write the same level log to correspond file -// means if the file name in configuration is project.log filesLogWriter will create project.error.log/project.debug.log -// and write the error-level logs to project.error.log and write the debug-level logs to project.debug.log -// the rotate attribute also acts like fileLogWriter -type multiFileLogWriter struct { - writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter - fullLogWriter *fileLogWriter - Separate []string `json:"separate"` -} - -var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"} - -// Init file logger with json config. -// jsonConfig like: -// { -// "filename":"logs/beego.log", -// "maxLines":0, -// "maxsize":0, -// "daily":true, -// "maxDays":15, -// "rotate":true, -// "perm":0600, -// "separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"], -// } - -func (f *multiFileLogWriter) Init(config string) error { - writer := newFileWriter().(*fileLogWriter) - err := writer.Init(config) - if err != nil { - return err - } - f.fullLogWriter = writer - f.writers[LevelDebug+1] = writer - - //unmarshal "separate" field to f.Separate - json.Unmarshal([]byte(config), f) - - jsonMap := map[string]interface{}{} - json.Unmarshal([]byte(config), &jsonMap) - - for i := LevelEmergency; i < LevelDebug+1; i++ { - for _, v := range f.Separate { - if v == levelNames[i] { - jsonMap["filename"] = f.fullLogWriter.fileNameOnly + "." + levelNames[i] + f.fullLogWriter.suffix - jsonMap["level"] = i - bs, _ := json.Marshal(jsonMap) - writer = newFileWriter().(*fileLogWriter) - err := writer.Init(string(bs)) - if err != nil { - return err - } - f.writers[i] = writer - } - } - } - - return nil -} - -func (f *multiFileLogWriter) Destroy() { - for i := 0; i < len(f.writers); i++ { - if f.writers[i] != nil { - f.writers[i].Destroy() - } - } -} - -func (f *multiFileLogWriter) WriteMsg(when time.Time, msg string, level int) error { - if f.fullLogWriter != nil { - f.fullLogWriter.WriteMsg(when, msg, level) - } - for i := 0; i < len(f.writers)-1; i++ { - if f.writers[i] != nil { - if level == f.writers[i].Level { - f.writers[i].WriteMsg(when, msg, level) - } - } - } - return nil -} - -func (f *multiFileLogWriter) Flush() { - for i := 0; i < len(f.writers); i++ { - if f.writers[i] != nil { - f.writers[i].Flush() - } - } -} - -// newFilesWriter create a FileLogWriter returning as LoggerInterface. -func newFilesWriter() Logger { - return &multiFileLogWriter{} -} - -func init() { - Register(AdapterMultiFile, newFilesWriter) -} diff --git a/logs/multifile_test.go b/logs/multifile_test.go deleted file mode 100644 index 57b96094..00000000 --- a/logs/multifile_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "bufio" - "os" - "strconv" - "strings" - "testing" -) - -func TestFiles_1(t *testing.T) { - log := NewLogger(10000) - log.SetLogger("multifile", `{"filename":"test.log","separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"]}`) - log.Debug("debug") - log.Informational("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - fns := []string{""} - fns = append(fns, levelNames[0:]...) - name := "test" - suffix := ".log" - for _, fn := range fns { - - file := name + suffix - if fn != "" { - file = name + "." + fn + suffix - } - f, err := os.Open(file) - if err != nil { - t.Fatal(err) - } - b := bufio.NewReader(f) - lineNum := 0 - lastLine := "" - for { - line, _, err := b.ReadLine() - if err != nil { - break - } - if len(line) > 0 { - lastLine = string(line) - lineNum++ - } - } - var expected = 1 - if fn == "" { - expected = LevelDebug + 1 - } - if lineNum != expected { - t.Fatal(file, "has", lineNum, "lines not "+strconv.Itoa(expected)+" lines") - } - if lineNum == 1 { - if !strings.Contains(lastLine, fn) { - t.Fatal(file + " " + lastLine + " not contains the log msg " + fn) - } - } - os.Remove(file) - } - -} diff --git a/logs/slack.go b/logs/slack.go deleted file mode 100644 index 1cd2e5ae..00000000 --- a/logs/slack.go +++ /dev/null @@ -1,60 +0,0 @@ -package logs - -import ( - "encoding/json" - "fmt" - "net/http" - "net/url" - "time" -) - -// SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook -type SLACKWriter struct { - WebhookURL string `json:"webhookurl"` - Level int `json:"level"` -} - -// newSLACKWriter create jiaoliao writer. -func newSLACKWriter() Logger { - return &SLACKWriter{Level: LevelTrace} -} - -// Init SLACKWriter with json config string -func (s *SLACKWriter) Init(jsonconfig string) error { - return json.Unmarshal([]byte(jsonconfig), s) -} - -// WriteMsg write message in smtp writer. -// it will send an email with subject and only this message. -func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > s.Level { - return nil - } - - text := fmt.Sprintf("{\"text\": \"%s %s\"}", when.Format("2006-01-02 15:04:05"), msg) - - form := url.Values{} - form.Add("payload", text) - - resp, err := http.PostForm(s.WebhookURL, form) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) - } - return nil -} - -// Flush implementing method. empty. -func (s *SLACKWriter) Flush() { -} - -// Destroy implementing method. empty. -func (s *SLACKWriter) Destroy() { -} - -func init() { - Register(AdapterSlack, newSLACKWriter) -} diff --git a/logs/smtp.go b/logs/smtp.go deleted file mode 100644 index 6208d7b8..00000000 --- a/logs/smtp.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "crypto/tls" - "encoding/json" - "fmt" - "net" - "net/smtp" - "strings" - "time" -) - -// SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server. -type SMTPWriter struct { - Username string `json:"username"` - Password string `json:"password"` - Host string `json:"host"` - Subject string `json:"subject"` - FromAddress string `json:"fromAddress"` - RecipientAddresses []string `json:"sendTos"` - Level int `json:"level"` -} - -// NewSMTPWriter create smtp writer. -func newSMTPWriter() Logger { - return &SMTPWriter{Level: LevelTrace} -} - -// Init smtp writer with json config. -// config like: -// { -// "username":"example@gmail.com", -// "password:"password", -// "host":"smtp.gmail.com:465", -// "subject":"email title", -// "fromAddress":"from@example.com", -// "sendTos":["email1","email2"], -// "level":LevelError -// } -func (s *SMTPWriter) Init(jsonconfig string) error { - return json.Unmarshal([]byte(jsonconfig), s) -} - -func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { - if len(strings.Trim(s.Username, " ")) == 0 && len(strings.Trim(s.Password, " ")) == 0 { - return nil - } - return smtp.PlainAuth( - "", - s.Username, - s.Password, - host, - ) -} - -func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error { - client, err := smtp.Dial(hostAddressWithPort) - if err != nil { - return err - } - - host, _, _ := net.SplitHostPort(hostAddressWithPort) - tlsConn := &tls.Config{ - InsecureSkipVerify: true, - ServerName: host, - } - if err = client.StartTLS(tlsConn); err != nil { - return err - } - - if auth != nil { - if err = client.Auth(auth); err != nil { - return err - } - } - - if err = client.Mail(fromAddress); err != nil { - return err - } - - for _, rec := range recipients { - if err = client.Rcpt(rec); err != nil { - return err - } - } - - w, err := client.Data() - if err != nil { - return err - } - _, err = w.Write(msgContent) - if err != nil { - return err - } - - err = w.Close() - if err != nil { - return err - } - - return client.Quit() -} - -// WriteMsg write message in smtp writer. -// it will send an email with subject and only this message. -func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > s.Level { - return nil - } - - hp := strings.Split(s.Host, ":") - - // Set up authentication information. - auth := s.getSMTPAuth(hp[0]) - - // Connect to the server, authenticate, set the sender and recipient, - // and send the email all in one step. - contentType := "Content-Type: text/plain" + "; charset=UTF-8" - mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress + - ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", when.Format("2006-01-02 15:04:05")) + msg) - - return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg) -} - -// Flush implementing method. empty. -func (s *SMTPWriter) Flush() { -} - -// Destroy implementing method. empty. -func (s *SMTPWriter) Destroy() { -} - -func init() { - Register(AdapterMail, newSMTPWriter) -} diff --git a/logs/smtp_test.go b/logs/smtp_test.go deleted file mode 100644 index ebc8a952..00000000 --- a/logs/smtp_test.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -// it often failed. And we moved this to pkg/logs, -// so we ignore it -// func TestSmtp(t *testing.T) { -// log := NewLogger(10000) -// log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`) -// log.Critical("sendmail critical") -// time.Sleep(time.Second * 30) -// } diff --git a/metric/prometheus.go b/metric/prometheus.go deleted file mode 100644 index 215896bd..00000000 --- a/metric/prometheus.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2020 astaxie -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package metric - -import ( - "net/http" - "reflect" - "strconv" - "strings" - "time" - - "github.com/prometheus/client_golang/prometheus" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/logs" -) - -// Deprecated: we will removed this function in 2.1.0 -// please use pkg/web/filter/prometheus#FilterChain -func PrometheusMiddleWare(next http.Handler) http.Handler { - summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ - Name: "beego", - Subsystem: "http_request", - ConstLabels: map[string]string{ - "server": beego.BConfig.ServerName, - "env": beego.BConfig.RunMode, - "appname": beego.BConfig.AppName, - }, - Help: "The statics info for http request", - }, []string{"pattern", "method", "status", "duration"}) - - prometheus.MustRegister(summaryVec) - - registerBuildInfo() - - return http.HandlerFunc(func(writer http.ResponseWriter, q *http.Request) { - start := time.Now() - next.ServeHTTP(writer, q) - end := time.Now() - go report(end.Sub(start), writer, q, summaryVec) - }) -} - -func registerBuildInfo() { - buildInfo := prometheus.NewGaugeVec(prometheus.GaugeOpts{ - Name: "beego", - Subsystem: "build_info", - Help: "The building information", - ConstLabels: map[string]string{ - "appname": beego.BConfig.AppName, - "build_version": beego.BuildVersion, - "build_revision": beego.BuildGitRevision, - "build_status": beego.BuildStatus, - "build_tag": beego.BuildTag, - "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), - "go_version": beego.GoVersion, - "git_branch": beego.GitBranch, - "start_time": time.Now().Format("2006-01-02 15:04:05"), - }, - }, []string{}) - - prometheus.MustRegister(buildInfo) - buildInfo.WithLabelValues().Set(1) -} - -func report(dur time.Duration, writer http.ResponseWriter, q *http.Request, vec *prometheus.SummaryVec) { - ctrl := beego.BeeApp.Handlers - ctx := ctrl.GetContext() - ctx.Reset(writer, q) - defer ctrl.GiveBackContext(ctx) - - // We cannot read the status code from q.Response.StatusCode - // since the http server does not set q.Response. So q.Response is nil - // Thus, we use reflection to read the status from writer whose concrete type is http.response - responseVal := reflect.ValueOf(writer).Elem() - field := responseVal.FieldByName("status") - status := -1 - if field.IsValid() && field.Kind() == reflect.Int { - status = int(field.Int()) - } - ptn := "UNKNOWN" - if rt, found := ctrl.FindRouter(ctx); found { - ptn = rt.GetPattern() - } else { - logs.Warn("we can not find the router info for this request, so request will be recorded as UNKNOWN: " + q.URL.String()) - } - ms := dur / time.Millisecond - vec.WithLabelValues(ptn, q.Method, strconv.Itoa(status), strconv.Itoa(int(ms))).Observe(float64(ms)) -} diff --git a/metric/prometheus_test.go b/metric/prometheus_test.go deleted file mode 100644 index d82a6dec..00000000 --- a/metric/prometheus_test.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2020 astaxie -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package metric - -import ( - "net/http" - "net/url" - "testing" - "time" - - "github.com/prometheus/client_golang/prometheus" - - "github.com/astaxie/beego/context" -) - -func TestPrometheusMiddleWare(t *testing.T) { - middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) - writer := &context.Response{} - request := &http.Request{ - URL: &url.URL{ - Host: "localhost", - RawPath: "/a/b/c", - }, - Method: "POST", - } - vec := prometheus.NewSummaryVec(prometheus.SummaryOpts{}, []string{"pattern", "method", "status", "duration"}) - - report(time.Second, writer, request, vec) - middleware.ServeHTTP(writer, request) -} diff --git a/migration/ddl.go b/migration/ddl.go deleted file mode 100644 index cd2c1c49..00000000 --- a/migration/ddl.go +++ /dev/null @@ -1,395 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migration - -import ( - "fmt" - - "github.com/astaxie/beego/logs" -) - -// Index struct defines the structure of Index Columns -type Index struct { - Name string -} - -// Unique struct defines a single unique key combination -type Unique struct { - Definition string - Columns []*Column -} - -//Column struct defines a single column of a table -type Column struct { - Name string - Inc string - Null string - Default string - Unsign string - DataType string - remove bool - Modify bool -} - -// Foreign struct defines a single foreign relationship -type Foreign struct { - ForeignTable string - ForeignColumn string - OnDelete string - OnUpdate string - Column -} - -// RenameColumn struct allows renaming of columns -type RenameColumn struct { - OldName string - OldNull string - OldDefault string - OldUnsign string - OldDataType string - NewName string - Column -} - -// CreateTable creates the table on system -func (m *Migration) CreateTable(tablename, engine, charset string, p ...func()) { - m.TableName = tablename - m.Engine = engine - m.Charset = charset - m.ModifyType = "create" -} - -// AlterTable set the ModifyType to alter -func (m *Migration) AlterTable(tablename string) { - m.TableName = tablename - m.ModifyType = "alter" -} - -// NewCol creates a new standard column and attaches it to m struct -func (m *Migration) NewCol(name string) *Column { - col := &Column{Name: name} - m.AddColumns(col) - return col -} - -//PriCol creates a new primary column and attaches it to m struct -func (m *Migration) PriCol(name string) *Column { - col := &Column{Name: name} - m.AddColumns(col) - m.AddPrimary(col) - return col -} - -//UniCol creates / appends columns to specified unique key and attaches it to m struct -func (m *Migration) UniCol(uni, name string) *Column { - col := &Column{Name: name} - m.AddColumns(col) - - uniqueOriginal := &Unique{} - - for _, unique := range m.Uniques { - if unique.Definition == uni { - unique.AddColumnsToUnique(col) - uniqueOriginal = unique - } - } - if uniqueOriginal.Definition == "" { - unique := &Unique{Definition: uni} - unique.AddColumnsToUnique(col) - m.AddUnique(unique) - } - - return col -} - -//ForeignCol creates a new foreign column and returns the instance of column -func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { - - foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable} - foreign.Name = colname - m.AddForeign(foreign) - return foreign -} - -//SetOnDelete sets the on delete of foreign -func (foreign *Foreign) SetOnDelete(del string) *Foreign { - foreign.OnDelete = "ON DELETE" + del - return foreign -} - -//SetOnUpdate sets the on update of foreign -func (foreign *Foreign) SetOnUpdate(update string) *Foreign { - foreign.OnUpdate = "ON UPDATE" + update - return foreign -} - -//Remove marks the columns to be removed. -//it allows reverse m to create the column. -func (c *Column) Remove() { - c.remove = true -} - -//SetAuto enables auto_increment of column (can be used once) -func (c *Column) SetAuto(inc bool) *Column { - if inc { - c.Inc = "auto_increment" - } - return c -} - -//SetNullable sets the column to be null -func (c *Column) SetNullable(null bool) *Column { - if null { - c.Null = "" - - } else { - c.Null = "NOT NULL" - } - return c -} - -//SetDefault sets the default value, prepend with "DEFAULT " -func (c *Column) SetDefault(def string) *Column { - c.Default = "DEFAULT " + def - return c -} - -//SetUnsigned sets the column to be unsigned int -func (c *Column) SetUnsigned(unsign bool) *Column { - if unsign { - c.Unsign = "UNSIGNED" - } - return c -} - -//SetDataType sets the dataType of the column -func (c *Column) SetDataType(dataType string) *Column { - c.DataType = dataType - return c -} - -//SetOldNullable allows reverting to previous nullable on reverse ms -func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { - if null { - c.OldNull = "" - - } else { - c.OldNull = "NOT NULL" - } - return c -} - -//SetOldDefault allows reverting to previous default on reverse ms -func (c *RenameColumn) SetOldDefault(def string) *RenameColumn { - c.OldDefault = def - return c -} - -//SetOldUnsigned allows reverting to previous unsgined on reverse ms -func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { - if unsign { - c.OldUnsign = "UNSIGNED" - } - return c -} - -//SetOldDataType allows reverting to previous datatype on reverse ms -func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn { - c.OldDataType = dataType - return c -} - -//SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) -func (c *Column) SetPrimary(m *Migration) *Column { - m.Primary = append(m.Primary, c) - return c -} - -//AddColumnsToUnique adds the columns to Unique Struct -func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { - - unique.Columns = append(unique.Columns, columns...) - - return unique -} - -//AddColumns adds columns to m struct -func (m *Migration) AddColumns(columns ...*Column) *Migration { - - m.Columns = append(m.Columns, columns...) - - return m -} - -//AddPrimary adds the column to primary in m struct -func (m *Migration) AddPrimary(primary *Column) *Migration { - m.Primary = append(m.Primary, primary) - return m -} - -//AddUnique adds the column to unique in m struct -func (m *Migration) AddUnique(unique *Unique) *Migration { - m.Uniques = append(m.Uniques, unique) - return m -} - -//AddForeign adds the column to foreign in m struct -func (m *Migration) AddForeign(foreign *Foreign) *Migration { - m.Foreigns = append(m.Foreigns, foreign) - return m -} - -//AddIndex adds the column to index in m struct -func (m *Migration) AddIndex(index *Index) *Migration { - m.Indexes = append(m.Indexes, index) - return m -} - -//RenameColumn allows renaming of columns -func (m *Migration) RenameColumn(from, to string) *RenameColumn { - rename := &RenameColumn{OldName: from, NewName: to} - m.Renames = append(m.Renames, rename) - return rename -} - -//GetSQL returns the generated sql depending on ModifyType -func (m *Migration) GetSQL() (sql string) { - sql = "" - switch m.ModifyType { - case "create": - { - sql += fmt.Sprintf("CREATE TABLE `%s` (", m.TableName) - for index, column := range m.Columns { - sql += fmt.Sprintf("\n `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) - if len(m.Columns) > index+1 { - sql += "," - } - } - - if len(m.Primary) > 0 { - sql += fmt.Sprintf(",\n PRIMARY KEY( ") - } - for index, column := range m.Primary { - sql += fmt.Sprintf(" `%s`", column.Name) - if len(m.Primary) > index+1 { - sql += "," - } - - } - if len(m.Primary) > 0 { - sql += fmt.Sprintf(")") - } - - for _, unique := range m.Uniques { - sql += fmt.Sprintf(",\n UNIQUE KEY `%s`( ", unique.Definition) - for index, column := range unique.Columns { - sql += fmt.Sprintf(" `%s`", column.Name) - if len(unique.Columns) > index+1 { - sql += "," - } - } - sql += fmt.Sprintf(")") - } - for _, foreign := range m.Foreigns { - sql += fmt.Sprintf(",\n `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default) - sql += fmt.Sprintf(",\n KEY `%s_%s_foreign`(`%s`),", m.TableName, foreign.Column.Name, foreign.Column.Name) - sql += fmt.Sprintf("\n CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate) - - } - sql += fmt.Sprintf(")ENGINE=%s DEFAULT CHARSET=%s;", m.Engine, m.Charset) - break - } - case "alter": - { - sql += fmt.Sprintf("ALTER TABLE `%s` ", m.TableName) - for index, column := range m.Columns { - if !column.remove { - logs.Info("col") - sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) - } else { - sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) - } - - if len(m.Columns) > index+1 { - sql += "," - } - } - for index, column := range m.Renames { - sql += fmt.Sprintf("CHANGE COLUMN `%s` `%s` %s %s %s %s %s", column.OldName, column.NewName, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) - if len(m.Renames) > index+1 { - sql += "," - } - } - - for index, foreign := range m.Foreigns { - sql += fmt.Sprintf("ADD `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default) - sql += fmt.Sprintf(",\n ADD KEY `%s_%s_foreign`(`%s`)", m.TableName, foreign.Column.Name, foreign.Column.Name) - sql += fmt.Sprintf(",\n ADD CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate) - if len(m.Foreigns) > index+1 { - sql += "," - } - } - sql += ";" - - break - } - case "reverse": - { - - sql += fmt.Sprintf("ALTER TABLE `%s`", m.TableName) - for index, column := range m.Columns { - if column.remove { - sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) - } else { - sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) - } - if len(m.Columns) > index+1 { - sql += "," - } - } - - if len(m.Primary) > 0 { - sql += fmt.Sprintf("\n DROP PRIMARY KEY,") - } - - for index, unique := range m.Uniques { - sql += fmt.Sprintf("\n DROP KEY `%s`", unique.Definition) - if len(m.Uniques) > index+1 { - sql += "," - } - - } - for index, column := range m.Renames { - sql += fmt.Sprintf("\n CHANGE COLUMN `%s` `%s` %s %s %s %s", column.NewName, column.OldName, column.OldDataType, column.OldUnsign, column.OldNull, column.OldDefault) - if len(m.Renames) > index+1 { - sql += "," - } - } - - for _, foreign := range m.Foreigns { - sql += fmt.Sprintf("\n DROP KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name) - sql += fmt.Sprintf(",\n DROP FOREIGN KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name) - sql += fmt.Sprintf(",\n DROP COLUMN `%s`", foreign.Name) - } - sql += ";" - } - case "delete": - { - sql += fmt.Sprintf("DROP TABLE IF EXISTS `%s`;", m.TableName) - } - } - - return -} diff --git a/migration/doc.go b/migration/doc.go deleted file mode 100644 index 0c6564d4..00000000 --- a/migration/doc.go +++ /dev/null @@ -1,32 +0,0 @@ -// Package migration enables you to generate migrations back and forth. It generates both migrations. -// -// //Creates a table -// m.CreateTable("tablename","InnoDB","utf8"); -// -// //Alter a table -// m.AlterTable("tablename") -// -// Standard Column Methods -// * SetDataType -// * SetNullable -// * SetDefault -// * SetUnsigned (use only on integer types unless produces error) -// -// //Sets a primary column, multiple calls allowed, standard column methods available -// m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true) -// -// //UniCol Can be used multiple times, allows standard Column methods. Use same "index" string to add to same index -// m.UniCol("index","column") -// -// //Standard Column Initialisation, can call .Remove() after NewCol("") on alter to remove -// m.NewCol("name").SetDataType("VARCHAR(255) COLLATE utf8_unicode_ci").SetNullable(false) -// m.NewCol("value").SetDataType("DOUBLE(8,2)").SetNullable(false) -// -// //Rename Columns , only use with Alter table, doesn't works with Create, prefix standard column methods with "Old" to -// //create a true reversible migration eg: SetOldDataType("DOUBLE(12,3)") -// m.RenameColumn("from","to")... -// -// //Foreign Columns, single columns are only supported, SetOnDelete & SetOnUpdate are available, call appropriately. -// //Supports standard column methods, automatic reverse. -// m.ForeignCol("local_col","foreign_col","foreign_table") -package migration diff --git a/migration/migration.go b/migration/migration.go deleted file mode 100644 index 5ddfd972..00000000 --- a/migration/migration.go +++ /dev/null @@ -1,330 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package migration is used for migration -// -// The table structure is as follow: -// -// CREATE TABLE `migrations` ( -// `id_migration` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'surrogate key', -// `name` varchar(255) DEFAULT NULL COMMENT 'migration name, unique', -// `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'date migrated or rolled back', -// `statements` longtext COMMENT 'SQL statements for this migration', -// `rollback_statements` longtext, -// `status` enum('update','rollback') DEFAULT NULL COMMENT 'update indicates it is a normal migration while rollback means this migration is rolled back', -// PRIMARY KEY (`id_migration`) -// ) ENGINE=InnoDB DEFAULT CHARSET=utf8; -package migration - -import ( - "errors" - "sort" - "strings" - "time" - - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/orm" -) - -// const the data format for the bee generate migration datatype -const ( - DateFormat = "20060102_150405" - DBDateFormat = "2006-01-02 15:04:05" -) - -// Migrationer is an interface for all Migration struct -type Migrationer interface { - Up() - Down() - Reset() - Exec(name, status string) error - GetCreated() int64 -} - -//Migration defines the migrations by either SQL or DDL -type Migration struct { - sqls []string - Created string - TableName string - Engine string - Charset string - ModifyType string - Columns []*Column - Indexes []*Index - Primary []*Column - Uniques []*Unique - Foreigns []*Foreign - Renames []*RenameColumn - RemoveColumns []*Column - RemoveIndexes []*Index - RemoveUniques []*Unique - RemoveForeigns []*Foreign -} - -var ( - migrationMap map[string]Migrationer -) - -func init() { - migrationMap = make(map[string]Migrationer) -} - -// Up implement in the Inheritance struct for upgrade -func (m *Migration) Up() { - - switch m.ModifyType { - case "reverse": - m.ModifyType = "alter" - case "delete": - m.ModifyType = "create" - } - m.sqls = append(m.sqls, m.GetSQL()) -} - -// Down implement in the Inheritance struct for down -func (m *Migration) Down() { - - switch m.ModifyType { - case "alter": - m.ModifyType = "reverse" - case "create": - m.ModifyType = "delete" - } - m.sqls = append(m.sqls, m.GetSQL()) -} - -//Migrate adds the SQL to the execution list -func (m *Migration) Migrate(migrationType string) { - m.ModifyType = migrationType - m.sqls = append(m.sqls, m.GetSQL()) -} - -// SQL add sql want to execute -func (m *Migration) SQL(sql string) { - m.sqls = append(m.sqls, sql) -} - -// Reset the sqls -func (m *Migration) Reset() { - m.sqls = make([]string, 0) -} - -// Exec execute the sql already add in the sql -func (m *Migration) Exec(name, status string) error { - o := orm.NewOrm() - for _, s := range m.sqls { - logs.Info("exec sql:", s) - r := o.Raw(s) - _, err := r.Exec() - if err != nil { - return err - } - } - return m.addOrUpdateRecord(name, status) -} - -func (m *Migration) addOrUpdateRecord(name, status string) error { - o := orm.NewOrm() - if status == "down" { - status = "rollback" - p, err := o.Raw("update migrations set status = ?, rollback_statements = ?, created_at = ? where name = ?").Prepare() - if err != nil { - return nil - } - _, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(DBDateFormat), name) - return err - } - status = "update" - p, err := o.Raw("insert into migrations(name, created_at, statements, status) values(?,?,?,?)").Prepare() - if err != nil { - return err - } - _, err = p.Exec(name, time.Now().Format(DBDateFormat), strings.Join(m.sqls, "; "), status) - return err -} - -// GetCreated get the unixtime from the Created -func (m *Migration) GetCreated() int64 { - t, err := time.Parse(DateFormat, m.Created) - if err != nil { - return 0 - } - return t.Unix() -} - -// Register register the Migration in the map -func Register(name string, m Migrationer) error { - if _, ok := migrationMap[name]; ok { - return errors.New("already exist name:" + name) - } - migrationMap[name] = m - return nil -} - -// Upgrade upgrade the migration from lasttime -func Upgrade(lasttime int64) error { - sm := sortMap(migrationMap) - i := 0 - migs, _ := getAllMigrations() - for _, v := range sm { - if _, ok := migs[v.name]; !ok { - logs.Info("start upgrade", v.name) - v.m.Reset() - v.m.Up() - err := v.m.Exec(v.name, "up") - if err != nil { - logs.Error("execute error:", err) - time.Sleep(2 * time.Second) - return err - } - logs.Info("end upgrade:", v.name) - i++ - } - } - logs.Info("total success upgrade:", i, " migration") - time.Sleep(2 * time.Second) - return nil -} - -// Rollback rollback the migration by the name -func Rollback(name string) error { - if v, ok := migrationMap[name]; ok { - logs.Info("start rollback") - v.Reset() - v.Down() - err := v.Exec(name, "down") - if err != nil { - logs.Error("execute error:", err) - time.Sleep(2 * time.Second) - return err - } - logs.Info("end rollback") - time.Sleep(2 * time.Second) - return nil - } - logs.Error("not exist the migrationMap name:" + name) - time.Sleep(2 * time.Second) - return errors.New("not exist the migrationMap name:" + name) -} - -// Reset reset all migration -// run all migration's down function -func Reset() error { - sm := sortMap(migrationMap) - i := 0 - for j := len(sm) - 1; j >= 0; j-- { - v := sm[j] - if isRollBack(v.name) { - logs.Info("skip the", v.name) - time.Sleep(1 * time.Second) - continue - } - logs.Info("start reset:", v.name) - v.m.Reset() - v.m.Down() - err := v.m.Exec(v.name, "down") - if err != nil { - logs.Error("execute error:", err) - time.Sleep(2 * time.Second) - return err - } - i++ - logs.Info("end reset:", v.name) - } - logs.Info("total success reset:", i, " migration") - time.Sleep(2 * time.Second) - return nil -} - -// Refresh first Reset, then Upgrade -func Refresh() error { - err := Reset() - if err != nil { - logs.Error("execute error:", err) - time.Sleep(2 * time.Second) - return err - } - err = Upgrade(0) - return err -} - -type dataSlice []data - -type data struct { - created int64 - name string - m Migrationer -} - -// Len is part of sort.Interface. -func (d dataSlice) Len() int { - return len(d) -} - -// Swap is part of sort.Interface. -func (d dataSlice) Swap(i, j int) { - d[i], d[j] = d[j], d[i] -} - -// Less is part of sort.Interface. We use count as the value to sort by -func (d dataSlice) Less(i, j int) bool { - return d[i].created < d[j].created -} - -func sortMap(m map[string]Migrationer) dataSlice { - s := make(dataSlice, 0, len(m)) - for k, v := range m { - d := data{} - d.created = v.GetCreated() - d.name = k - d.m = v - s = append(s, d) - } - sort.Sort(s) - return s -} - -func isRollBack(name string) bool { - o := orm.NewOrm() - var maps []orm.Params - num, err := o.Raw("select * from migrations where `name` = ? order by id_migration desc", name).Values(&maps) - if err != nil { - logs.Info("get name has error", err) - return false - } - if num <= 0 { - return false - } - if maps[0]["status"] == "rollback" { - return true - } - return false -} -func getAllMigrations() (map[string]string, error) { - o := orm.NewOrm() - var maps []orm.Params - migs := make(map[string]string) - num, err := o.Raw("select * from migrations order by id_migration desc").Values(&maps) - if err != nil { - logs.Info("get name has error", err) - return migs, err - } - if num > 0 { - for _, v := range maps { - name := v["name"].(string) - migs[name] = v["status"].(string) - } - } - return migs, nil -} diff --git a/mime.go b/mime.go deleted file mode 100644 index ca2878ab..00000000 --- a/mime.go +++ /dev/null @@ -1,556 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -var mimemaps = map[string]string{ - ".3dm": "x-world/x-3dmf", - ".3dmf": "x-world/x-3dmf", - ".7z": "application/x-7z-compressed", - ".a": "application/octet-stream", - ".aab": "application/x-authorware-bin", - ".aam": "application/x-authorware-map", - ".aas": "application/x-authorware-seg", - ".abc": "text/vndabc", - ".ace": "application/x-ace-compressed", - ".acgi": "text/html", - ".afl": "video/animaflex", - ".ai": "application/postscript", - ".aif": "audio/aiff", - ".aifc": "audio/aiff", - ".aiff": "audio/aiff", - ".aim": "application/x-aim", - ".aip": "text/x-audiosoft-intra", - ".alz": "application/x-alz-compressed", - ".ani": "application/x-navi-animation", - ".aos": "application/x-nokia-9000-communicator-add-on-software", - ".aps": "application/mime", - ".apk": "application/vnd.android.package-archive", - ".arc": "application/x-arc-compressed", - ".arj": "application/arj", - ".art": "image/x-jg", - ".asf": "video/x-ms-asf", - ".asm": "text/x-asm", - ".asp": "text/asp", - ".asx": "application/x-mplayer2", - ".au": "audio/basic", - ".avi": "video/x-msvideo", - ".avs": "video/avs-video", - ".bcpio": "application/x-bcpio", - ".bin": "application/mac-binary", - ".bmp": "image/bmp", - ".boo": "application/book", - ".book": "application/book", - ".boz": "application/x-bzip2", - ".bsh": "application/x-bsh", - ".bz2": "application/x-bzip2", - ".bz": "application/x-bzip", - ".c++": "text/plain", - ".c": "text/x-c", - ".cab": "application/vnd.ms-cab-compressed", - ".cat": "application/vndms-pkiseccat", - ".cc": "text/x-c", - ".ccad": "application/clariscad", - ".cco": "application/x-cocoa", - ".cdf": "application/cdf", - ".cer": "application/pkix-cert", - ".cha": "application/x-chat", - ".chat": "application/x-chat", - ".chrt": "application/vnd.kde.kchart", - ".class": "application/java", - ".com": "text/plain", - ".conf": "text/plain", - ".cpio": "application/x-cpio", - ".cpp": "text/x-c", - ".cpt": "application/mac-compactpro", - ".crl": "application/pkcs-crl", - ".crt": "application/pkix-cert", - ".crx": "application/x-chrome-extension", - ".csh": "text/x-scriptcsh", - ".css": "text/css", - ".csv": "text/csv", - ".cxx": "text/plain", - ".dar": "application/x-dar", - ".dcr": "application/x-director", - ".deb": "application/x-debian-package", - ".deepv": "application/x-deepv", - ".def": "text/plain", - ".der": "application/x-x509-ca-cert", - ".dif": "video/x-dv", - ".dir": "application/x-director", - ".divx": "video/divx", - ".dl": "video/dl", - ".dmg": "application/x-apple-diskimage", - ".doc": "application/msword", - ".dot": "application/msword", - ".dp": "application/commonground", - ".drw": "application/drafting", - ".dump": "application/octet-stream", - ".dv": "video/x-dv", - ".dvi": "application/x-dvi", - ".dwf": "drawing/x-dwf=(old)", - ".dwg": "application/acad", - ".dxf": "application/dxf", - ".dxr": "application/x-director", - ".el": "text/x-scriptelisp", - ".elc": "application/x-bytecodeelisp=(compiled=elisp)", - ".eml": "message/rfc822", - ".env": "application/x-envoy", - ".eps": "application/postscript", - ".es": "application/x-esrehber", - ".etx": "text/x-setext", - ".evy": "application/envoy", - ".exe": "application/octet-stream", - ".f77": "text/x-fortran", - ".f90": "text/x-fortran", - ".f": "text/x-fortran", - ".fdf": "application/vndfdf", - ".fif": "application/fractals", - ".fli": "video/fli", - ".flo": "image/florian", - ".flv": "video/x-flv", - ".flx": "text/vndfmiflexstor", - ".fmf": "video/x-atomic3d-feature", - ".for": "text/x-fortran", - ".fpx": "image/vndfpx", - ".frl": "application/freeloader", - ".funk": "audio/make", - ".g3": "image/g3fax", - ".g": "text/plain", - ".gif": "image/gif", - ".gl": "video/gl", - ".gsd": "audio/x-gsm", - ".gsm": "audio/x-gsm", - ".gsp": "application/x-gsp", - ".gss": "application/x-gss", - ".gtar": "application/x-gtar", - ".gz": "application/x-compressed", - ".gzip": "application/x-gzip", - ".h": "text/x-h", - ".hdf": "application/x-hdf", - ".help": "application/x-helpfile", - ".hgl": "application/vndhp-hpgl", - ".hh": "text/x-h", - ".hlb": "text/x-script", - ".hlp": "application/hlp", - ".hpg": "application/vndhp-hpgl", - ".hpgl": "application/vndhp-hpgl", - ".hqx": "application/binhex", - ".hta": "application/hta", - ".htc": "text/x-component", - ".htm": "text/html", - ".html": "text/html", - ".htmls": "text/html", - ".htt": "text/webviewhtml", - ".htx": "text/html", - ".ice": "x-conference/x-cooltalk", - ".ico": "image/x-icon", - ".ics": "text/calendar", - ".icz": "text/calendar", - ".idc": "text/plain", - ".ief": "image/ief", - ".iefs": "image/ief", - ".iges": "application/iges", - ".igs": "application/iges", - ".ima": "application/x-ima", - ".imap": "application/x-httpd-imap", - ".inf": "application/inf", - ".ins": "application/x-internett-signup", - ".ip": "application/x-ip2", - ".isu": "video/x-isvideo", - ".it": "audio/it", - ".iv": "application/x-inventor", - ".ivr": "i-world/i-vrml", - ".ivy": "application/x-livescreen", - ".jam": "audio/x-jam", - ".jav": "text/x-java-source", - ".java": "text/x-java-source", - ".jcm": "application/x-java-commerce", - ".jfif-tbnl": "image/jpeg", - ".jfif": "image/jpeg", - ".jnlp": "application/x-java-jnlp-file", - ".jpe": "image/jpeg", - ".jpeg": "image/jpeg", - ".jpg": "image/jpeg", - ".jps": "image/x-jps", - ".js": "application/javascript", - ".json": "application/json", - ".jut": "image/jutvision", - ".kar": "audio/midi", - ".karbon": "application/vnd.kde.karbon", - ".kfo": "application/vnd.kde.kformula", - ".flw": "application/vnd.kde.kivio", - ".kml": "application/vnd.google-earth.kml+xml", - ".kmz": "application/vnd.google-earth.kmz", - ".kon": "application/vnd.kde.kontour", - ".kpr": "application/vnd.kde.kpresenter", - ".kpt": "application/vnd.kde.kpresenter", - ".ksp": "application/vnd.kde.kspread", - ".kwd": "application/vnd.kde.kword", - ".kwt": "application/vnd.kde.kword", - ".ksh": "text/x-scriptksh", - ".la": "audio/nspaudio", - ".lam": "audio/x-liveaudio", - ".latex": "application/x-latex", - ".lha": "application/lha", - ".lhx": "application/octet-stream", - ".list": "text/plain", - ".lma": "audio/nspaudio", - ".log": "text/plain", - ".lsp": "text/x-scriptlisp", - ".lst": "text/plain", - ".lsx": "text/x-la-asf", - ".ltx": "application/x-latex", - ".lzh": "application/octet-stream", - ".lzx": "application/lzx", - ".m1v": "video/mpeg", - ".m2a": "audio/mpeg", - ".m2v": "video/mpeg", - ".m3u": "audio/x-mpegurl", - ".m": "text/x-m", - ".man": "application/x-troff-man", - ".manifest": "text/cache-manifest", - ".map": "application/x-navimap", - ".mar": "text/plain", - ".mbd": "application/mbedlet", - ".mc$": "application/x-magic-cap-package-10", - ".mcd": "application/mcad", - ".mcf": "text/mcf", - ".mcp": "application/netmc", - ".me": "application/x-troff-me", - ".mht": "message/rfc822", - ".mhtml": "message/rfc822", - ".mid": "application/x-midi", - ".midi": "application/x-midi", - ".mif": "application/x-frame", - ".mime": "message/rfc822", - ".mjf": "audio/x-vndaudioexplosionmjuicemediafile", - ".mjpg": "video/x-motion-jpeg", - ".mm": "application/base64", - ".mme": "application/base64", - ".mod": "audio/mod", - ".moov": "video/quicktime", - ".mov": "video/quicktime", - ".movie": "video/x-sgi-movie", - ".mp2": "audio/mpeg", - ".mp3": "audio/mpeg3", - ".mp4": "video/mp4", - ".mpa": "audio/mpeg", - ".mpc": "application/x-project", - ".mpe": "video/mpeg", - ".mpeg": "video/mpeg", - ".mpg": "video/mpeg", - ".mpga": "audio/mpeg", - ".mpp": "application/vndms-project", - ".mpt": "application/x-project", - ".mpv": "application/x-project", - ".mpx": "application/x-project", - ".mrc": "application/marc", - ".ms": "application/x-troff-ms", - ".mv": "video/x-sgi-movie", - ".my": "audio/make", - ".mzz": "application/x-vndaudioexplosionmzz", - ".nap": "image/naplps", - ".naplps": "image/naplps", - ".nc": "application/x-netcdf", - ".ncm": "application/vndnokiaconfiguration-message", - ".nif": "image/x-niff", - ".niff": "image/x-niff", - ".nix": "application/x-mix-transfer", - ".nsc": "application/x-conference", - ".nvd": "application/x-navidoc", - ".o": "application/octet-stream", - ".oda": "application/oda", - ".odb": "application/vnd.oasis.opendocument.database", - ".odc": "application/vnd.oasis.opendocument.chart", - ".odf": "application/vnd.oasis.opendocument.formula", - ".odg": "application/vnd.oasis.opendocument.graphics", - ".odi": "application/vnd.oasis.opendocument.image", - ".odm": "application/vnd.oasis.opendocument.text-master", - ".odp": "application/vnd.oasis.opendocument.presentation", - ".ods": "application/vnd.oasis.opendocument.spreadsheet", - ".odt": "application/vnd.oasis.opendocument.text", - ".oga": "audio/ogg", - ".ogg": "audio/ogg", - ".ogv": "video/ogg", - ".omc": "application/x-omc", - ".omcd": "application/x-omcdatamaker", - ".omcr": "application/x-omcregerator", - ".otc": "application/vnd.oasis.opendocument.chart-template", - ".otf": "application/vnd.oasis.opendocument.formula-template", - ".otg": "application/vnd.oasis.opendocument.graphics-template", - ".oth": "application/vnd.oasis.opendocument.text-web", - ".oti": "application/vnd.oasis.opendocument.image-template", - ".otm": "application/vnd.oasis.opendocument.text-master", - ".otp": "application/vnd.oasis.opendocument.presentation-template", - ".ots": "application/vnd.oasis.opendocument.spreadsheet-template", - ".ott": "application/vnd.oasis.opendocument.text-template", - ".p10": "application/pkcs10", - ".p12": "application/pkcs-12", - ".p7a": "application/x-pkcs7-signature", - ".p7c": "application/pkcs7-mime", - ".p7m": "application/pkcs7-mime", - ".p7r": "application/x-pkcs7-certreqresp", - ".p7s": "application/pkcs7-signature", - ".p": "text/x-pascal", - ".part": "application/pro_eng", - ".pas": "text/pascal", - ".pbm": "image/x-portable-bitmap", - ".pcl": "application/vndhp-pcl", - ".pct": "image/x-pict", - ".pcx": "image/x-pcx", - ".pdb": "chemical/x-pdb", - ".pdf": "application/pdf", - ".pfunk": "audio/make", - ".pgm": "image/x-portable-graymap", - ".pic": "image/pict", - ".pict": "image/pict", - ".pkg": "application/x-newton-compatible-pkg", - ".pko": "application/vndms-pkipko", - ".pl": "text/x-scriptperl", - ".plx": "application/x-pixclscript", - ".pm4": "application/x-pagemaker", - ".pm5": "application/x-pagemaker", - ".pm": "text/x-scriptperl-module", - ".png": "image/png", - ".pnm": "application/x-portable-anymap", - ".pot": "application/mspowerpoint", - ".pov": "model/x-pov", - ".ppa": "application/vndms-powerpoint", - ".ppm": "image/x-portable-pixmap", - ".pps": "application/mspowerpoint", - ".ppt": "application/mspowerpoint", - ".ppz": "application/mspowerpoint", - ".pre": "application/x-freelance", - ".prt": "application/pro_eng", - ".ps": "application/postscript", - ".psd": "application/octet-stream", - ".pvu": "paleovu/x-pv", - ".pwz": "application/vndms-powerpoint", - ".py": "text/x-scriptphyton", - ".pyc": "application/x-bytecodepython", - ".qcp": "audio/vndqcelp", - ".qd3": "x-world/x-3dmf", - ".qd3d": "x-world/x-3dmf", - ".qif": "image/x-quicktime", - ".qt": "video/quicktime", - ".qtc": "video/x-qtc", - ".qti": "image/x-quicktime", - ".qtif": "image/x-quicktime", - ".ra": "audio/x-pn-realaudio", - ".ram": "audio/x-pn-realaudio", - ".rar": "application/x-rar-compressed", - ".ras": "application/x-cmu-raster", - ".rast": "image/cmu-raster", - ".rexx": "text/x-scriptrexx", - ".rf": "image/vndrn-realflash", - ".rgb": "image/x-rgb", - ".rm": "application/vndrn-realmedia", - ".rmi": "audio/mid", - ".rmm": "audio/x-pn-realaudio", - ".rmp": "audio/x-pn-realaudio", - ".rng": "application/ringing-tones", - ".rnx": "application/vndrn-realplayer", - ".roff": "application/x-troff", - ".rp": "image/vndrn-realpix", - ".rpm": "audio/x-pn-realaudio-plugin", - ".rt": "text/vndrn-realtext", - ".rtf": "text/richtext", - ".rtx": "text/richtext", - ".rv": "video/vndrn-realvideo", - ".s": "text/x-asm", - ".s3m": "audio/s3m", - ".s7z": "application/x-7z-compressed", - ".saveme": "application/octet-stream", - ".sbk": "application/x-tbook", - ".scm": "text/x-scriptscheme", - ".sdml": "text/plain", - ".sdp": "application/sdp", - ".sdr": "application/sounder", - ".sea": "application/sea", - ".set": "application/set", - ".sgm": "text/x-sgml", - ".sgml": "text/x-sgml", - ".sh": "text/x-scriptsh", - ".shar": "application/x-bsh", - ".shtml": "text/x-server-parsed-html", - ".sid": "audio/x-psid", - ".skd": "application/x-koan", - ".skm": "application/x-koan", - ".skp": "application/x-koan", - ".skt": "application/x-koan", - ".sit": "application/x-stuffit", - ".sitx": "application/x-stuffitx", - ".sl": "application/x-seelogo", - ".smi": "application/smil", - ".smil": "application/smil", - ".snd": "audio/basic", - ".sol": "application/solids", - ".spc": "text/x-speech", - ".spl": "application/futuresplash", - ".spr": "application/x-sprite", - ".sprite": "application/x-sprite", - ".spx": "audio/ogg", - ".src": "application/x-wais-source", - ".ssi": "text/x-server-parsed-html", - ".ssm": "application/streamingmedia", - ".sst": "application/vndms-pkicertstore", - ".step": "application/step", - ".stl": "application/sla", - ".stp": "application/step", - ".sv4cpio": "application/x-sv4cpio", - ".sv4crc": "application/x-sv4crc", - ".svf": "image/vnddwg", - ".svg": "image/svg+xml", - ".svr": "application/x-world", - ".swf": "application/x-shockwave-flash", - ".t": "application/x-troff", - ".talk": "text/x-speech", - ".tar": "application/x-tar", - ".tbk": "application/toolbook", - ".tcl": "text/x-scripttcl", - ".tcsh": "text/x-scripttcsh", - ".tex": "application/x-tex", - ".texi": "application/x-texinfo", - ".texinfo": "application/x-texinfo", - ".text": "text/plain", - ".tgz": "application/gnutar", - ".tif": "image/tiff", - ".tiff": "image/tiff", - ".tr": "application/x-troff", - ".tsi": "audio/tsp-audio", - ".tsp": "application/dsptype", - ".tsv": "text/tab-separated-values", - ".turbot": "image/florian", - ".txt": "text/plain", - ".uil": "text/x-uil", - ".uni": "text/uri-list", - ".unis": "text/uri-list", - ".unv": "application/i-deas", - ".uri": "text/uri-list", - ".uris": "text/uri-list", - ".ustar": "application/x-ustar", - ".uu": "text/x-uuencode", - ".uue": "text/x-uuencode", - ".vcd": "application/x-cdlink", - ".vcf": "text/x-vcard", - ".vcard": "text/x-vcard", - ".vcs": "text/x-vcalendar", - ".vda": "application/vda", - ".vdo": "video/vdo", - ".vew": "application/groupwise", - ".viv": "video/vivo", - ".vivo": "video/vivo", - ".vmd": "application/vocaltec-media-desc", - ".vmf": "application/vocaltec-media-file", - ".voc": "audio/voc", - ".vos": "video/vosaic", - ".vox": "audio/voxware", - ".vqe": "audio/x-twinvq-plugin", - ".vqf": "audio/x-twinvq", - ".vql": "audio/x-twinvq-plugin", - ".vrml": "application/x-vrml", - ".vrt": "x-world/x-vrt", - ".vsd": "application/x-visio", - ".vst": "application/x-visio", - ".vsw": "application/x-visio", - ".w60": "application/wordperfect60", - ".w61": "application/wordperfect61", - ".w6w": "application/msword", - ".wav": "audio/wav", - ".wb1": "application/x-qpro", - ".wbmp": "image/vnd.wap.wbmp", - ".web": "application/vndxara", - ".wiz": "application/msword", - ".wk1": "application/x-123", - ".wmf": "windows/metafile", - ".wml": "text/vnd.wap.wml", - ".wmlc": "application/vnd.wap.wmlc", - ".wmls": "text/vnd.wap.wmlscript", - ".wmlsc": "application/vnd.wap.wmlscriptc", - ".word": "application/msword", - ".wp5": "application/wordperfect", - ".wp6": "application/wordperfect", - ".wp": "application/wordperfect", - ".wpd": "application/wordperfect", - ".wq1": "application/x-lotus", - ".wri": "application/mswrite", - ".wrl": "application/x-world", - ".wrz": "model/vrml", - ".wsc": "text/scriplet", - ".wsrc": "application/x-wais-source", - ".wtk": "application/x-wintalk", - ".x-png": "image/png", - ".xbm": "image/x-xbitmap", - ".xdr": "video/x-amt-demorun", - ".xgz": "xgl/drawing", - ".xif": "image/vndxiff", - ".xl": "application/excel", - ".xla": "application/excel", - ".xlb": "application/excel", - ".xlc": "application/excel", - ".xld": "application/excel", - ".xlk": "application/excel", - ".xll": "application/excel", - ".xlm": "application/excel", - ".xls": "application/excel", - ".xlt": "application/excel", - ".xlv": "application/excel", - ".xlw": "application/excel", - ".xm": "audio/xm", - ".xml": "text/xml", - ".xmz": "xgl/movie", - ".xpix": "application/x-vndls-xpix", - ".xpm": "image/x-xpixmap", - ".xsr": "video/x-amt-showrun", - ".xwd": "image/x-xwd", - ".xyz": "chemical/x-pdb", - ".z": "application/x-compress", - ".zip": "application/zip", - ".zoo": "application/octet-stream", - ".zsh": "text/x-scriptzsh", - ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - ".docm": "application/vnd.ms-word.document.macroEnabled.12", - ".dotx": "application/vnd.openxmlformats-officedocument.wordprocessingml.template", - ".dotm": "application/vnd.ms-word.template.macroEnabled.12", - ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - ".xlsm": "application/vnd.ms-excel.sheet.macroEnabled.12", - ".xltx": "application/vnd.openxmlformats-officedocument.spreadsheetml.template", - ".xltm": "application/vnd.ms-excel.template.macroEnabled.12", - ".xlsb": "application/vnd.ms-excel.sheet.binary.macroEnabled.12", - ".xlam": "application/vnd.ms-excel.addin.macroEnabled.12", - ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", - ".pptm": "application/vnd.ms-powerpoint.presentation.macroEnabled.12", - ".ppsx": "application/vnd.openxmlformats-officedocument.presentationml.slideshow", - ".ppsm": "application/vnd.ms-powerpoint.slideshow.macroEnabled.12", - ".potx": "application/vnd.openxmlformats-officedocument.presentationml.template", - ".potm": "application/vnd.ms-powerpoint.template.macroEnabled.12", - ".ppam": "application/vnd.ms-powerpoint.addin.macroEnabled.12", - ".sldx": "application/vnd.openxmlformats-officedocument.presentationml.slide", - ".sldm": "application/vnd.ms-powerpoint.slide.macroEnabled.12", - ".thmx": "application/vnd.ms-officetheme", - ".onetoc": "application/onenote", - ".onetoc2": "application/onenote", - ".onetmp": "application/onenote", - ".onepkg": "application/onenote", - ".key": "application/x-iwork-keynote-sffkey", - ".kth": "application/x-iwork-keynote-sffkth", - ".nmbtemplate": "application/x-iwork-numbers-sfftemplate", - ".numbers": "application/x-iwork-numbers-sffnumbers", - ".pages": "application/x-iwork-pages-sffpages", - ".template": "application/x-iwork-pages-sfftemplate", - ".xpi": "application/x-xpinstall", - ".oex": "application/x-opera-extension", - ".mustache": "text/html", -} diff --git a/namespace.go b/namespace.go deleted file mode 100644 index a6962994..00000000 --- a/namespace.go +++ /dev/null @@ -1,433 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "net/http" - "strings" - - beecontext "github.com/astaxie/beego/context" -) - -type namespaceCond func(*beecontext.Context) bool - -// LinkNamespace used as link action -// Deprecated: using pkg/, we will delete this in v2.1.0 -type LinkNamespace func(*Namespace) - -// Namespace is store all the info -// Deprecated: using pkg/, we will delete this in v2.1.0 -type Namespace struct { - prefix string - handlers *ControllerRegister -} - -// NewNamespace get new Namespace -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { - ns := &Namespace{ - prefix: prefix, - handlers: NewControllerRegister(), - } - for _, p := range params { - p(ns) - } - return ns -} - -// Cond set condition function -// if cond return true can run this namespace, else can't -// usage: -// ns.Cond(func (ctx *context.Context) bool{ -// if ctx.Input.Domain() == "api.beego.me" { -// return true -// } -// return false -// }) -// Cond as the first filter -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Cond(cond namespaceCond) *Namespace { - fn := func(ctx *beecontext.Context) { - if !cond(ctx) { - exception("405", ctx) - } - } - if v := n.handlers.filters[BeforeRouter]; len(v) > 0 { - mr := new(FilterRouter) - mr.tree = NewTree() - mr.pattern = "*" - mr.filterFunc = fn - mr.tree.AddRouter("*", true) - n.handlers.filters[BeforeRouter] = append([]*FilterRouter{mr}, v...) - } else { - n.handlers.InsertFilter("*", BeforeRouter, fn) - } - return n -} - -// Filter add filter in the Namespace -// action has before & after -// FilterFunc -// usage: -// Filter("before", func (ctx *context.Context){ -// _, ok := ctx.Input.Session("uid").(int) -// if !ok && ctx.Request.RequestURI != "/login" { -// ctx.Redirect(302, "/login") -// } -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { - var a int - if action == "before" { - a = BeforeRouter - } else if action == "after" { - a = FinishRouter - } - for _, f := range filter { - n.handlers.InsertFilter("*", a, f) - } - return n -} - -// Router same as beego.Rourer -// refer: https://godoc.org/github.com/astaxie/beego#Router -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { - n.handlers.Add(rootpath, c, mappingMethods...) - return n -} - -// AutoRouter same as beego.AutoRouter -// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { - n.handlers.AddAuto(c) - return n -} - -// AutoPrefix same as beego.AutoPrefix -// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { - n.handlers.AddAutoPrefix(prefix, c) - return n -} - -// Get same as beego.Get -// refer: https://godoc.org/github.com/astaxie/beego#Get -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { - n.handlers.Get(rootpath, f) - return n -} - -// Post same as beego.Post -// refer: https://godoc.org/github.com/astaxie/beego#Post -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { - n.handlers.Post(rootpath, f) - return n -} - -// Delete same as beego.Delete -// refer: https://godoc.org/github.com/astaxie/beego#Delete -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { - n.handlers.Delete(rootpath, f) - return n -} - -// Put same as beego.Put -// refer: https://godoc.org/github.com/astaxie/beego#Put -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { - n.handlers.Put(rootpath, f) - return n -} - -// Head same as beego.Head -// refer: https://godoc.org/github.com/astaxie/beego#Head -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { - n.handlers.Head(rootpath, f) - return n -} - -// Options same as beego.Options -// refer: https://godoc.org/github.com/astaxie/beego#Options -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { - n.handlers.Options(rootpath, f) - return n -} - -// Patch same as beego.Patch -// refer: https://godoc.org/github.com/astaxie/beego#Patch -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { - n.handlers.Patch(rootpath, f) - return n -} - -// Any same as beego.Any -// refer: https://godoc.org/github.com/astaxie/beego#Any -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { - n.handlers.Any(rootpath, f) - return n -} - -// Handler same as beego.Handler -// refer: https://godoc.org/github.com/astaxie/beego#Handler -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { - n.handlers.Handler(rootpath, h) - return n -} - -// Include add include class -// refer: https://godoc.org/github.com/astaxie/beego#Include -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { - n.handlers.Include(cList...) - return n -} - -// Namespace add nest Namespace -// usage: -//ns := beego.NewNamespace(“/v1”). -//Namespace( -// beego.NewNamespace("/shop"). -// Get("/:id", func(ctx *context.Context) { -// ctx.Output.Body([]byte("shopinfo")) -// }), -// beego.NewNamespace("/order"). -// Get("/:id", func(ctx *context.Context) { -// ctx.Output.Body([]byte("orderinfo")) -// }), -// beego.NewNamespace("/crm"). -// Get("/:id", func(ctx *context.Context) { -// ctx.Output.Body([]byte("crminfo")) -// }), -//) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { - for _, ni := range ns { - for k, v := range ni.handlers.routers { - if _, ok := n.handlers.routers[k]; ok { - addPrefix(v, ni.prefix) - n.handlers.routers[k].AddTree(ni.prefix, v) - } else { - t := NewTree() - t.AddTree(ni.prefix, v) - addPrefix(t, ni.prefix) - n.handlers.routers[k] = t - } - } - if ni.handlers.enableFilter { - for pos, filterList := range ni.handlers.filters { - for _, mr := range filterList { - t := NewTree() - t.AddTree(ni.prefix, mr.tree) - mr.tree = t - n.handlers.insertFilterRouter(pos, mr) - } - } - } - } - return n -} - -// AddNamespace register Namespace into beego.Handler -// support multi Namespace -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AddNamespace(nl ...*Namespace) { - for _, n := range nl { - for k, v := range n.handlers.routers { - if _, ok := BeeApp.Handlers.routers[k]; ok { - addPrefix(v, n.prefix) - BeeApp.Handlers.routers[k].AddTree(n.prefix, v) - } else { - t := NewTree() - t.AddTree(n.prefix, v) - addPrefix(t, n.prefix) - BeeApp.Handlers.routers[k] = t - } - } - if n.handlers.enableFilter { - for pos, filterList := range n.handlers.filters { - for _, mr := range filterList { - t := NewTree() - t.AddTree(n.prefix, mr.tree) - mr.tree = t - BeeApp.Handlers.insertFilterRouter(pos, mr) - } - } - } - } -} - -func addPrefix(t *Tree, prefix string) { - for _, v := range t.fixrouters { - addPrefix(v, prefix) - } - if t.wildcard != nil { - addPrefix(t.wildcard, prefix) - } - for _, l := range t.leaves { - if c, ok := l.runObject.(*ControllerInfo); ok { - if !strings.HasPrefix(c.pattern, prefix) { - c.pattern = prefix + c.pattern - } - } - } -} - -// NSCond is Namespace Condition -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSCond(cond namespaceCond) LinkNamespace { - return func(ns *Namespace) { - ns.Cond(cond) - } -} - -// NSBefore Namespace BeforeRouter filter -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSBefore(filterList ...FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Filter("before", filterList...) - } -} - -// NSAfter add Namespace FinishRouter filter -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSAfter(filterList ...FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Filter("after", filterList...) - } -} - -// NSInclude Namespace Include ControllerInterface -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSInclude(cList ...ControllerInterface) LinkNamespace { - return func(ns *Namespace) { - ns.Include(cList...) - } -} - -// NSRouter call Namespace Router -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { - return func(ns *Namespace) { - ns.Router(rootpath, c, mappingMethods...) - } -} - -// NSGet call Namespace Get -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSGet(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Get(rootpath, f) - } -} - -// NSPost call Namespace Post -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSPost(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Post(rootpath, f) - } -} - -// NSHead call Namespace Head -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSHead(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Head(rootpath, f) - } -} - -// NSPut call Namespace Put -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSPut(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Put(rootpath, f) - } -} - -// NSDelete call Namespace Delete -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSDelete(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Delete(rootpath, f) - } -} - -// NSAny call Namespace Any -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSAny(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Any(rootpath, f) - } -} - -// NSOptions call Namespace Options -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSOptions(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Options(rootpath, f) - } -} - -// NSPatch call Namespace Patch -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSPatch(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Patch(rootpath, f) - } -} - -// NSAutoRouter call Namespace AutoRouter -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSAutoRouter(c ControllerInterface) LinkNamespace { - return func(ns *Namespace) { - ns.AutoRouter(c) - } -} - -// NSAutoPrefix call Namespace AutoPrefix -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace { - return func(ns *Namespace) { - ns.AutoPrefix(prefix, c) - } -} - -// NSNamespace add sub Namespace -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { - return func(ns *Namespace) { - n := NewNamespace(prefix, params...) - ns.Namespace(n) - } -} - -// NSHandler add handler -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSHandler(rootpath string, h http.Handler) LinkNamespace { - return func(ns *Namespace) { - ns.Handler(rootpath, h) - } -} diff --git a/namespace_test.go b/namespace_test.go deleted file mode 100644 index b3f20dff..00000000 --- a/namespace_test.go +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "net/http" - "net/http/httptest" - "strconv" - "testing" - - "github.com/astaxie/beego/context" -) - -func TestNamespaceGet(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/user", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Get("/user", func(ctx *context.Context) { - ctx.Output.Body([]byte("v1_user")) - }) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "v1_user" { - t.Errorf("TestNamespaceGet can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespacePost(t *testing.T) { - r, _ := http.NewRequest("POST", "/v1/user/123", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Post("/user/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "123" { - t.Errorf("TestNamespacePost can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceNest(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/admin/order", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Namespace( - NewNamespace("/admin"). - Get("/order", func(ctx *context.Context) { - ctx.Output.Body([]byte("order")) - }), - ) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "order" { - t.Errorf("TestNamespaceNest can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceNestParam(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/admin/order/123", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Namespace( - NewNamespace("/admin"). - Get("/order/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }), - ) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "123" { - t.Errorf("TestNamespaceNestParam can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceRouter(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/api/list", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Router("/api/list", &TestController{}, "*:List") - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "i am list" { - t.Errorf("TestNamespaceRouter can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceAutoFunc(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/test/list", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.AutoRouter(&TestController{}) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "i am list" { - t.Errorf("user define func can't run") - } -} - -func TestNamespaceFilter(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/user/123", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Filter("before", func(ctx *context.Context) { - ctx.Output.Body([]byte("this is Filter")) - }). - Get("/user/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "this is Filter" { - t.Errorf("TestNamespaceFilter can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceCond(t *testing.T) { - r, _ := http.NewRequest("GET", "/v2/test/list", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v2") - ns.Cond(func(ctx *context.Context) bool { - return ctx.Input.Domain() == "beego.me" - }). - AutoRouter(&TestController{}) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Code != 405 { - t.Errorf("TestNamespaceCond can't run get the result " + strconv.Itoa(w.Code)) - } -} - -func TestNamespaceInside(t *testing.T) { - r, _ := http.NewRequest("GET", "/v3/shop/order/123", nil) - w := httptest.NewRecorder() - ns := NewNamespace("/v3", - NSAutoRouter(&TestController{}), - NSNamespace("/shop", - NSGet("/order/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }), - ), - ) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "123" { - t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String()) - } -} diff --git a/orm/README.md b/orm/README.md deleted file mode 100644 index 6e808d2a..00000000 --- a/orm/README.md +++ /dev/null @@ -1,159 +0,0 @@ -# beego orm - -[![Build Status](https://drone.io/github.com/astaxie/beego/status.png)](https://drone.io/github.com/astaxie/beego/latest) - -A powerful orm framework for go. - -It is heavily influenced by Django ORM, SQLAlchemy. - -**Support Database:** - -* MySQL: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) -* PostgreSQL: [github.com/lib/pq](https://github.com/lib/pq) -* Sqlite3: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) - -Passed all test, but need more feedback. - -**Features:** - -* full go type support -* easy for usage, simple CRUD operation -* auto join with relation table -* cross DataBase compatible query -* Raw SQL query / mapper without orm model -* full test keep stable and strong - -more features please read the docs - -**Install:** - - go get github.com/astaxie/beego/orm - -## Changelog - -* 2013-08-19: support table auto create -* 2013-08-13: update test for database types -* 2013-08-13: go type support, such as int8, uint8, byte, rune -* 2013-08-13: date / datetime timezone support very well - -## Quick Start - -#### Simple Usage - -```go -package main - -import ( - "fmt" - "github.com/astaxie/beego/orm" - _ "github.com/go-sql-driver/mysql" // import your used driver -) - -// Model Struct -type User struct { - Id int `orm:"auto"` - Name string `orm:"size(100)"` -} - -func init() { - // register model - orm.RegisterModel(new(User)) - - // set default database - orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) - - // create table - orm.RunSyncdb("default", false, true) -} - -func main() { - o := orm.NewOrm() - - user := User{Name: "slene"} - - // insert - id, err := o.Insert(&user) - - // update - user.Name = "astaxie" - num, err := o.Update(&user) - - // read one - u := User{Id: user.Id} - err = o.Read(&u) - - // delete - num, err = o.Delete(&u) -} -``` - -#### Next with relation - -```go -type Post struct { - Id int `orm:"auto"` - Title string `orm:"size(100)"` - User *User `orm:"rel(fk)"` -} - -var posts []*Post -qs := o.QueryTable("post") -num, err := qs.Filter("User__Name", "slene").All(&posts) -``` - -#### Use Raw sql - -If you don't like ORM,use Raw SQL to query / mapping without ORM setting - -```go -var maps []Params -num, err := o.Raw("SELECT id FROM user WHERE name = ?", "slene").Values(&maps) -if num > 0 { - fmt.Println(maps[0]["id"]) -} -``` - -#### Transaction - -```go -o.Begin() -... -user := User{Name: "slene"} -id, err := o.Insert(&user) -if err == nil { - o.Commit() -} else { - o.Rollback() -} - -``` - -#### Debug Log Queries - -In development env, you can simple use - -```go -func main() { - orm.Debug = true -... -``` - -enable log queries. - -output include all queries, such as exec / prepare / transaction. - -like this: - -```go -[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Exec / 0.4ms] - [INSERT INTO `user` (`name`) VALUES (?)] - `slene` -... -``` - -note: not recommend use this in product env. - -## Docs - -more details and examples in docs and test - -[documents](http://beego.me/docs/mvc/model/overview.md) - diff --git a/orm/cmd.go b/orm/cmd.go deleted file mode 100644 index 0ff4dc40..00000000 --- a/orm/cmd.go +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "flag" - "fmt" - "os" - "strings" -) - -type commander interface { - Parse([]string) - Run() error -} - -var ( - commands = make(map[string]commander) -) - -// print help. -func printHelp(errs ...string) { - content := `orm command usage: - - syncdb - auto create tables - sqlall - print sql of create tables - help - print this help -` - - if len(errs) > 0 { - fmt.Println(errs[0]) - } - fmt.Println(content) - os.Exit(2) -} - -// RunCommand listen for orm command and then run it if command arguments passed. -func RunCommand() { - if len(os.Args) < 2 || os.Args[1] != "orm" { - return - } - - BootStrap() - - args := argString(os.Args[2:]) - name := args.Get(0) - - if name == "help" { - printHelp() - } - - if cmd, ok := commands[name]; ok { - cmd.Parse(os.Args[3:]) - cmd.Run() - os.Exit(0) - } else { - if name == "" { - printHelp() - } else { - printHelp(fmt.Sprintf("unknown command %s", name)) - } - } -} - -// sync database struct command interface. -type commandSyncDb struct { - al *alias - force bool - verbose bool - noInfo bool - rtOnError bool -} - -// parse orm command line arguments. -func (d *commandSyncDb) Parse(args []string) { - var name string - - flagSet := flag.NewFlagSet("orm command: syncdb", flag.ExitOnError) - flagSet.StringVar(&name, "db", "default", "DataBase alias name") - flagSet.BoolVar(&d.force, "force", false, "drop tables before create") - flagSet.BoolVar(&d.verbose, "v", false, "verbose info") - flagSet.Parse(args) - - d.al = getDbAlias(name) -} - -// run orm line command. -func (d *commandSyncDb) Run() error { - var drops []string - if d.force { - drops = getDbDropSQL(d.al) - } - - db := d.al.DB - - if d.force { - for i, mi := range modelCache.allOrdered() { - query := drops[i] - if !d.noInfo { - fmt.Printf("drop table `%s`\n", mi.table) - } - _, err := db.Exec(query) - if d.verbose { - fmt.Printf(" %s\n\n", query) - } - if err != nil { - if d.rtOnError { - return err - } - fmt.Printf(" %s\n", err.Error()) - } - } - } - - sqls, indexes := getDbCreateSQL(d.al) - - tables, err := d.al.DbBaser.GetTables(db) - if err != nil { - if d.rtOnError { - return err - } - fmt.Printf(" %s\n", err.Error()) - } - - for i, mi := range modelCache.allOrdered() { - if tables[mi.table] { - if !d.noInfo { - fmt.Printf("table `%s` already exists, skip\n", mi.table) - } - - var fields []*fieldInfo - columns, err := d.al.DbBaser.GetColumns(db, mi.table) - if err != nil { - if d.rtOnError { - return err - } - fmt.Printf(" %s\n", err.Error()) - } - - for _, fi := range mi.fields.fieldsDB { - if _, ok := columns[fi.column]; !ok { - fields = append(fields, fi) - } - } - - for _, fi := range fields { - query := getColumnAddQuery(d.al, fi) - - if !d.noInfo { - fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table) - } - - _, err := db.Exec(query) - if d.verbose { - fmt.Printf(" %s\n", query) - } - if err != nil { - if d.rtOnError { - return err - } - fmt.Printf(" %s\n", err.Error()) - } - } - - for _, idx := range indexes[mi.table] { - if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { - if !d.noInfo { - fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) - } - - query := idx.SQL - _, err := db.Exec(query) - if d.verbose { - fmt.Printf(" %s\n", query) - } - if err != nil { - if d.rtOnError { - return err - } - fmt.Printf(" %s\n", err.Error()) - } - } - } - - continue - } - - if !d.noInfo { - fmt.Printf("create table `%s` \n", mi.table) - } - - queries := []string{sqls[i]} - for _, idx := range indexes[mi.table] { - queries = append(queries, idx.SQL) - } - - for _, query := range queries { - _, err := db.Exec(query) - if d.verbose { - query = " " + strings.Join(strings.Split(query, "\n"), "\n ") - fmt.Println(query) - } - if err != nil { - if d.rtOnError { - return err - } - fmt.Printf(" %s\n", err.Error()) - } - } - if d.verbose { - fmt.Println("") - } - } - - return nil -} - -// database creation commander interface implement. -type commandSQLAll struct { - al *alias -} - -// parse orm command line arguments. -func (d *commandSQLAll) Parse(args []string) { - var name string - - flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError) - flagSet.StringVar(&name, "db", "default", "DataBase alias name") - flagSet.Parse(args) - - d.al = getDbAlias(name) -} - -// run orm line command. -func (d *commandSQLAll) Run() error { - sqls, indexes := getDbCreateSQL(d.al) - var all []string - for i, mi := range modelCache.allOrdered() { - queries := []string{sqls[i]} - for _, idx := range indexes[mi.table] { - queries = append(queries, idx.SQL) - } - sql := strings.Join(queries, "\n") - all = append(all, sql) - } - fmt.Println(strings.Join(all, "\n\n")) - - return nil -} - -func init() { - commands["syncdb"] = new(commandSyncDb) - commands["sqlall"] = new(commandSQLAll) -} - -// RunSyncdb run syncdb command line. -// name means table's alias name. default is "default". -// force means run next sql if the current is error. -// verbose means show all info when running command or not. -func RunSyncdb(name string, force bool, verbose bool) error { - BootStrap() - - al := getDbAlias(name) - cmd := new(commandSyncDb) - cmd.al = al - cmd.force = force - cmd.noInfo = !verbose - cmd.verbose = verbose - cmd.rtOnError = true - return cmd.Run() -} diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go deleted file mode 100644 index 692a079f..00000000 --- a/orm/cmd_utils.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "os" - "strings" -) - -type dbIndex struct { - Table string - Name string - SQL string -} - -// create database drop sql. -func getDbDropSQL(al *alias) (sqls []string) { - if len(modelCache.cache) == 0 { - fmt.Println("no Model found, need register your model") - os.Exit(2) - } - - Q := al.DbBaser.TableQuote() - - for _, mi := range modelCache.allOrdered() { - sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) - } - return sqls -} - -// get database column type string. -func getColumnTyp(al *alias, fi *fieldInfo) (col string) { - T := al.DbBaser.DbTypes() - fieldType := fi.fieldType - fieldSize := fi.size - -checkColumn: - switch fieldType { - case TypeBooleanField: - col = T["bool"] - case TypeVarCharField: - if al.Driver == DRPostgres && fi.toText { - col = T["string-text"] - } else { - col = fmt.Sprintf(T["string"], fieldSize) - } - case TypeCharField: - col = fmt.Sprintf(T["string-char"], fieldSize) - case TypeTextField: - col = T["string-text"] - case TypeTimeField: - col = T["time.Time-clock"] - case TypeDateField: - col = T["time.Time-date"] - case TypeDateTimeField: - col = T["time.Time"] - case TypeBitField: - col = T["int8"] - case TypeSmallIntegerField: - col = T["int16"] - case TypeIntegerField: - col = T["int32"] - case TypeBigIntegerField: - if al.Driver == DRSqlite { - fieldType = TypeIntegerField - goto checkColumn - } - col = T["int64"] - case TypePositiveBitField: - col = T["uint8"] - case TypePositiveSmallIntegerField: - col = T["uint16"] - case TypePositiveIntegerField: - col = T["uint32"] - case TypePositiveBigIntegerField: - col = T["uint64"] - case TypeFloatField: - col = T["float64"] - case TypeDecimalField: - s := T["float64-decimal"] - if !strings.Contains(s, "%d") { - col = s - } else { - col = fmt.Sprintf(s, fi.digits, fi.decimals) - } - case TypeJSONField: - if al.Driver != DRPostgres { - fieldType = TypeVarCharField - goto checkColumn - } - col = T["json"] - case TypeJsonbField: - if al.Driver != DRPostgres { - fieldType = TypeVarCharField - goto checkColumn - } - col = T["jsonb"] - case RelForeignKey, RelOneToOne: - fieldType = fi.relModelInfo.fields.pk.fieldType - fieldSize = fi.relModelInfo.fields.pk.size - goto checkColumn - } - - return -} - -// create alter sql string. -func getColumnAddQuery(al *alias, fi *fieldInfo) string { - Q := al.DbBaser.TableQuote() - typ := getColumnTyp(al, fi) - - if !fi.null { - typ += " " + "NOT NULL" - } - - return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", - Q, fi.mi.table, Q, - Q, fi.column, Q, - typ, getColumnDefault(fi), - ) -} - -// create database creation string. -func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { - if len(modelCache.cache) == 0 { - fmt.Println("no Model found, need register your model") - os.Exit(2) - } - - Q := al.DbBaser.TableQuote() - T := al.DbBaser.DbTypes() - sep := fmt.Sprintf("%s, %s", Q, Q) - - tableIndexes = make(map[string][]dbIndex) - - for _, mi := range modelCache.allOrdered() { - sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) - sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) - sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) - - sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) - - columns := make([]string, 0, len(mi.fields.fieldsDB)) - - sqlIndexes := [][]string{} - - for _, fi := range mi.fields.fieldsDB { - - column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) - col := getColumnTyp(al, fi) - - if fi.auto { - switch al.Driver { - case DRSqlite, DRPostgres: - column += T["auto"] - default: - column += col + " " + T["auto"] - } - } else if fi.pk { - column += col + " " + T["pk"] - } else { - column += col - - if !fi.null { - column += " " + "NOT NULL" - } - - //if fi.initial.String() != "" { - // column += " DEFAULT " + fi.initial.String() - //} - - // Append attribute DEFAULT - column += getColumnDefault(fi) - - if fi.unique { - column += " " + "UNIQUE" - } - - if fi.index { - sqlIndexes = append(sqlIndexes, []string{fi.column}) - } - } - - if strings.Contains(column, "%COL%") { - column = strings.Replace(column, "%COL%", fi.column, -1) - } - - if fi.description != "" && al.Driver != DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) - } - - columns = append(columns, column) - } - - if mi.model != nil { - allnames := getTableUnique(mi.addrField) - if !mi.manual && len(mi.uniques) > 0 { - allnames = append(allnames, mi.uniques) - } - for _, names := range allnames { - cols := make([]string, 0, len(names)) - for _, name := range names { - if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { - cols = append(cols, fi.column) - } else { - panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) - } - } - column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) - columns = append(columns, column) - } - } - - sql += strings.Join(columns, ",\n") - sql += "\n)" - - if al.Driver == DRMySQL { - var engine string - if mi.model != nil { - engine = getTableEngine(mi.addrField) - } - if engine == "" { - engine = al.Engine - } - sql += " ENGINE=" + engine - } - - sql += ";" - sqls = append(sqls, sql) - - if mi.model != nil { - for _, names := range getTableIndex(mi.addrField) { - cols := make([]string, 0, len(names)) - for _, name := range names { - if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { - cols = append(cols, fi.column) - } else { - panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) - } - } - sqlIndexes = append(sqlIndexes, cols) - } - } - - for _, names := range sqlIndexes { - name := mi.table + "_" + strings.Join(names, "_") - cols := strings.Join(names, sep) - sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) - - index := dbIndex{} - index.Table = mi.table - index.Name = name - index.SQL = sql - - tableIndexes[mi.table] = append(tableIndexes[mi.table], index) - } - - } - - return -} - -// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands -func getColumnDefault(fi *fieldInfo) string { - var ( - v, t, d string - ) - - // Skip default attribute if field is in relations - if fi.rel || fi.reverse { - return v - } - - t = " DEFAULT '%s' " - - // These defaults will be useful if there no config value orm:"default" and NOT NULL is on - switch fi.fieldType { - case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField: - return v - - case TypeBitField, TypeSmallIntegerField, TypeIntegerField, - TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, - TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField, - TypeDecimalField: - t = " DEFAULT %s " - d = "0" - case TypeBooleanField: - t = " DEFAULT %s " - d = "FALSE" - case TypeJSONField, TypeJsonbField: - d = "{}" - } - - if fi.colDefault { - if !fi.initial.Exist() { - v = fmt.Sprintf(t, "") - } else { - v = fmt.Sprintf(t, fi.initial.String()) - } - } else { - if !fi.null { - v = fmt.Sprintf(t, d) - } - } - - return v -} diff --git a/orm/db.go b/orm/db.go deleted file mode 100644 index 5d175bf1..00000000 --- a/orm/db.go +++ /dev/null @@ -1,1908 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "database/sql" - "errors" - "fmt" - "reflect" - "strings" - "time" -) - -const ( - formatTime = "15:04:05" - formatDate = "2006-01-02" - formatDateTime = "2006-01-02 15:04:05" -) - -var ( - // ErrMissPK missing pk error - ErrMissPK = errors.New("missed pk value") -) - -var ( - operators = map[string]bool{ - "exact": true, - "iexact": true, - "contains": true, - "icontains": true, - // "regex": true, - // "iregex": true, - "gt": true, - "gte": true, - "lt": true, - "lte": true, - "eq": true, - "nq": true, - "ne": true, - ">": true, - ">=": true, - "<": true, - "<=": true, - "=": true, - "!=": true, - "startswith": true, - "endswith": true, - "istartswith": true, - "iendswith": true, - "in": true, - "between": true, - // "year": true, - // "month": true, - // "day": true, - // "week_day": true, - "isnull": true, - // "search": true, - } -) - -// an instance of dbBaser interface/ -type dbBase struct { - ins dbBaser -} - -// check dbBase implements dbBaser interface. -var _ dbBaser = new(dbBase) - -// get struct columns values as interface slice. -func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { - if names == nil { - ns := make([]string, 0, len(cols)) - names = &ns - } - values = make([]interface{}, 0, len(cols)) - - for _, column := range cols { - var fi *fieldInfo - if fi, _ = mi.fields.GetByAny(column); fi != nil { - column = fi.column - } else { - panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName)) - } - if !fi.dbcol || fi.auto && skipAuto { - continue - } - value, err := d.collectFieldValue(mi, fi, ind, insert, tz) - if err != nil { - return nil, nil, err - } - - // ignore empty value auto field - if insert && fi.auto { - if fi.fieldType&IsPositiveIntegerField > 0 { - if vu, ok := value.(uint64); !ok || vu == 0 { - continue - } - } else { - if vu, ok := value.(int64); !ok || vu == 0 { - continue - } - } - autoFields = append(autoFields, fi.column) - } - - *names, values = append(*names, column), append(values, value) - } - - return -} - -// get one field value in struct column as interface. -func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { - var value interface{} - if fi.pk { - _, value, _ = getExistPk(mi, ind) - } else { - field := ind.FieldByIndex(fi.fieldIndex) - if fi.isFielder { - f := field.Addr().Interface().(Fielder) - value = f.RawValue() - } else { - switch fi.fieldType { - case TypeBooleanField: - if nb, ok := field.Interface().(sql.NullBool); ok { - value = nil - if nb.Valid { - value = nb.Bool - } - } else if field.Kind() == reflect.Ptr { - if field.IsNil() { - value = nil - } else { - value = field.Elem().Bool() - } - } else { - value = field.Bool() - } - case TypeVarCharField, TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField: - if ns, ok := field.Interface().(sql.NullString); ok { - value = nil - if ns.Valid { - value = ns.String - } - } else if field.Kind() == reflect.Ptr { - if field.IsNil() { - value = nil - } else { - value = field.Elem().String() - } - } else { - value = field.String() - } - case TypeFloatField, TypeDecimalField: - if nf, ok := field.Interface().(sql.NullFloat64); ok { - value = nil - if nf.Valid { - value = nf.Float64 - } - } else if field.Kind() == reflect.Ptr { - if field.IsNil() { - value = nil - } else { - value = field.Elem().Float() - } - } else { - vu := field.Interface() - if _, ok := vu.(float32); ok { - value, _ = StrTo(ToStr(vu)).Float64() - } else { - value = field.Float() - } - } - case TypeTimeField, TypeDateField, TypeDateTimeField: - value = field.Interface() - if t, ok := value.(time.Time); ok { - d.ins.TimeToDB(&t, tz) - if t.IsZero() { - value = nil - } else { - value = t - } - } - default: - switch { - case fi.fieldType&IsPositiveIntegerField > 0: - if field.Kind() == reflect.Ptr { - if field.IsNil() { - value = nil - } else { - value = field.Elem().Uint() - } - } else { - value = field.Uint() - } - case fi.fieldType&IsIntegerField > 0: - if ni, ok := field.Interface().(sql.NullInt64); ok { - value = nil - if ni.Valid { - value = ni.Int64 - } - } else if field.Kind() == reflect.Ptr { - if field.IsNil() { - value = nil - } else { - value = field.Elem().Int() - } - } else { - value = field.Int() - } - case fi.fieldType&IsRelField > 0: - if field.IsNil() { - value = nil - } else { - if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok { - value = vu - } else { - value = nil - } - } - if !fi.null && value == nil { - return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName) - } - } - } - } - switch fi.fieldType { - case TypeTimeField, TypeDateField, TypeDateTimeField: - if fi.autoNow || fi.autoNowAdd && insert { - if insert { - if t, ok := value.(time.Time); ok && !t.IsZero() { - break - } - } - tnow := time.Now() - d.ins.TimeToDB(&tnow, tz) - value = tnow - if fi.isFielder { - f := field.Addr().Interface().(Fielder) - f.SetRaw(tnow.In(DefaultTimeLoc)) - } else if field.Kind() == reflect.Ptr { - v := tnow.In(DefaultTimeLoc) - field.Set(reflect.ValueOf(&v)) - } else { - field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc))) - } - } - case TypeJSONField, TypeJsonbField: - if s, ok := value.(string); (ok && len(s) == 0) || value == nil { - if fi.colDefault && fi.initial.Exist() { - value = fi.initial.String() - } else { - value = nil - } - } - } - } - return value, nil -} - -// create insert sql preparation statement object. -func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { - Q := d.ins.TableQuote() - - dbcols := make([]string, 0, len(mi.fields.dbcols)) - marks := make([]string, 0, len(mi.fields.dbcols)) - for _, fi := range mi.fields.fieldsDB { - if !fi.auto { - dbcols = append(dbcols, fi.column) - marks = append(marks, "?") - } - } - qmarks := strings.Join(marks, ", ") - sep := fmt.Sprintf("%s, %s", Q, Q) - columns := strings.Join(dbcols, sep) - - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) - - d.ins.ReplaceMarks(&query) - - d.ins.HasReturningID(mi, &query) - - stmt, err := q.Prepare(query) - return stmt, query, err -} - -// insert struct with prepared statement and given struct reflect value. -func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) - if err != nil { - return 0, err - } - - if d.ins.HasReturningID(mi, nil) { - row := stmt.QueryRow(values...) - var id int64 - err := row.Scan(&id) - return id, err - } - res, err := stmt.Exec(values...) - if err == nil { - return res.LastInsertId() - } - return 0, err -} - -// query sql ,read records and persist in dbBaser. -func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { - var whereCols []string - var args []interface{} - - // if specify cols length > 0, then use it for where condition. - if len(cols) > 0 { - var err error - whereCols = make([]string, 0, len(cols)) - args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) - if err != nil { - return err - } - } else { - // default use pk value as where condtion. - pkColumn, pkValue, ok := getExistPk(mi, ind) - if !ok { - return ErrMissPK - } - whereCols = []string{pkColumn} - args = append(args, pkValue) - } - - Q := d.ins.TableQuote() - - sep := fmt.Sprintf("%s, %s", Q, Q) - sels := strings.Join(mi.fields.dbcols, sep) - colsNum := len(mi.fields.dbcols) - - sep = fmt.Sprintf("%s = ? AND %s", Q, Q) - wheres := strings.Join(whereCols, sep) - - forUpdate := "" - if isForUpdate { - forUpdate = "FOR UPDATE" - } - - query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q, forUpdate) - - refs := make([]interface{}, colsNum) - for i := range refs { - var ref interface{} - refs[i] = &ref - } - - d.ins.ReplaceMarks(&query) - - row := q.QueryRow(query, args...) - if err := row.Scan(refs...); err != nil { - if err == sql.ErrNoRows { - return ErrNoRows - } - return err - } - elm := reflect.New(mi.addrField.Elem().Type()) - mind := reflect.Indirect(elm) - d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz) - ind.Set(mind) - return nil -} - -// execute insert sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - names := make([]string, 0, len(mi.fields.dbcols)) - values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) - if err != nil { - return 0, err - } - - id, err := d.InsertValue(q, mi, false, names, values) - if err != nil { - return 0, err - } - - if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) - } - return id, err -} - -// multi-insert sql with given slice struct reflect.Value. -func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { - var ( - cnt int64 - nums int - values []interface{} - names []string - ) - - // typ := reflect.Indirect(mi.addrField).Type() - - length, autoFields := sind.Len(), make([]string, 0, 1) - - for i := 1; i <= length; i++ { - - ind := reflect.Indirect(sind.Index(i - 1)) - - // Is this needed ? - // if !ind.Type().AssignableTo(typ) { - // return cnt, ErrArgs - // } - - if i == 1 { - var ( - vus []interface{} - err error - ) - vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) - if err != nil { - return cnt, err - } - values = make([]interface{}, bulk*len(vus)) - nums += copy(values, vus) - } else { - vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) - if err != nil { - return cnt, err - } - - if len(vus) != len(names) { - return cnt, ErrArgs - } - - nums += copy(values[nums:], vus) - } - - if i > 1 && i%bulk == 0 || length == i { - num, err := d.InsertValue(q, mi, true, names, values[:nums]) - if err != nil { - return cnt, err - } - cnt += num - nums = 0 - } - } - - var err error - if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) - } - - return cnt, err -} - -// execute insert sql with given struct and given values. -// insert the given values, not the field values in struct. -func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { - Q := d.ins.TableQuote() - - marks := make([]string, len(names)) - for i := range marks { - marks[i] = "?" - } - - sep := fmt.Sprintf("%s, %s", Q, Q) - qmarks := strings.Join(marks, ", ") - columns := strings.Join(names, sep) - - multi := len(values) / len(names) - - if isMulti && multi > 1 { - qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks - } - - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) - - d.ins.ReplaceMarks(&query) - - if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) - if err == nil { - if isMulti { - return res.RowsAffected() - } - return res.LastInsertId() - } - return 0, err - } - row := q.QueryRow(query, values...) - var id int64 - err := row.Scan(&id) - return id, err -} - -// InsertOrUpdate a row -// If your primary key or unique column conflict will update -// If no will insert -func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { - args0 := "" - iouStr := "" - argsMap := map[string]string{} - switch a.Driver { - case DRMySQL: - iouStr = "ON DUPLICATE KEY UPDATE" - case DRPostgres: - if len(args) == 0 { - return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName) - } - args0 = strings.ToLower(args[0]) - iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0) - default: - return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName) - } - - //Get on the key-value pairs - for _, v := range args { - kv := strings.Split(v, "=") - if len(kv) == 2 { - argsMap[strings.ToLower(kv[0])] = kv[1] - } - } - - isMulti := false - names := make([]string, 0, len(mi.fields.dbcols)-1) - Q := d.ins.TableQuote() - values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) - - if err != nil { - return 0, err - } - - marks := make([]string, len(names)) - updateValues := make([]interface{}, 0) - updates := make([]string, len(names)) - var conflitValue interface{} - for i, v := range names { - // identifier in database may not be case-sensitive, so quote it - v = fmt.Sprintf("%s%s%s", Q, v, Q) - marks[i] = "?" - valueStr := argsMap[strings.ToLower(v)] - if v == args0 { - conflitValue = values[i] - } - if valueStr != "" { - switch a.Driver { - case DRMySQL: - updates[i] = v + "=" + valueStr - case DRPostgres: - if conflitValue != nil { - //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values - updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0) - updateValues = append(updateValues, conflitValue) - } else { - return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) - } - } - } else { - updates[i] = v + "=?" - updateValues = append(updateValues, values[i]) - } - } - - values = append(values, updateValues...) - - sep := fmt.Sprintf("%s, %s", Q, Q) - qmarks := strings.Join(marks, ", ") - qupdates := strings.Join(updates, ", ") - columns := strings.Join(names, sep) - - multi := len(values) / len(names) - - if isMulti { - qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks - } - //conflitValue maybe is a int,can`t use fmt.Sprintf - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) - - d.ins.ReplaceMarks(&query) - - if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) - if err == nil { - if isMulti { - return res.RowsAffected() - } - return res.LastInsertId() - } - return 0, err - } - - row := q.QueryRow(query, values...) - var id int64 - err = row.Scan(&id) - if err != nil && err.Error() == `pq: syntax error at or near "ON"` { - err = fmt.Errorf("postgres version must 9.5 or higher") - } - return id, err -} - -// execute update sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { - pkName, pkValue, ok := getExistPk(mi, ind) - if !ok { - return 0, ErrMissPK - } - - var setNames []string - - // if specify cols length is zero, then commit all columns. - if len(cols) == 0 { - cols = mi.fields.dbcols - setNames = make([]string, 0, len(mi.fields.dbcols)-1) - } else { - setNames = make([]string, 0, len(cols)) - } - - setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) - if err != nil { - return 0, err - } - - var findAutoNowAdd, findAutoNow bool - var index int - for i, col := range setNames { - if mi.fields.GetByColumn(col).autoNowAdd { - index = i - findAutoNowAdd = true - } - if mi.fields.GetByColumn(col).autoNow { - findAutoNow = true - } - } - if findAutoNowAdd { - setNames = append(setNames[0:index], setNames[index+1:]...) - setValues = append(setValues[0:index], setValues[index+1:]...) - } - - if !findAutoNow { - for col, info := range mi.fields.columns { - if info.autoNow { - setNames = append(setNames, col) - setValues = append(setValues, time.Now()) - } - } - } - - setValues = append(setValues, pkValue) - - Q := d.ins.TableQuote() - - sep := fmt.Sprintf("%s = ?, %s", Q, Q) - setColumns := strings.Join(setNames, sep) - - query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.table, Q, Q, setColumns, Q, Q, pkName, Q) - - d.ins.ReplaceMarks(&query) - - res, err := q.Exec(query, setValues...) - if err == nil { - return res.RowsAffected() - } - return 0, err -} - -// execute delete sql dbQuerier with given struct reflect.Value. -// delete index is pk. -func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { - var whereCols []string - var args []interface{} - // if specify cols length > 0, then use it for where condition. - if len(cols) > 0 { - var err error - whereCols = make([]string, 0, len(cols)) - args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) - if err != nil { - return 0, err - } - } else { - // default use pk value as where condtion. - pkColumn, pkValue, ok := getExistPk(mi, ind) - if !ok { - return 0, ErrMissPK - } - whereCols = []string{pkColumn} - args = append(args, pkValue) - } - - Q := d.ins.TableQuote() - - sep := fmt.Sprintf("%s = ? AND %s", Q, Q) - wheres := strings.Join(whereCols, sep) - - query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) - - d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, args...) - if err == nil { - num, err := res.RowsAffected() - if err != nil { - return 0, err - } - if num > 0 { - if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0) - } else { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) - } - } - err := d.deleteRels(q, mi, args, tz) - if err != nil { - return num, err - } - } - return num, err - } - return 0, err -} - -// update table-related record by querySet. -// need querySet not struct reflect.Value to update related records. -func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { - columns := make([]string, 0, len(params)) - values := make([]interface{}, 0, len(params)) - for col, val := range params { - if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol { - panic(fmt.Errorf("wrong field/column name `%s`", col)) - } else { - columns = append(columns, fi.column) - values = append(values, val) - } - } - - if len(columns) == 0 { - panic(fmt.Errorf("update params cannot empty")) - } - - tables := newDbTables(mi, d.ins) - if qs != nil { - tables.parseRelated(qs.related, qs.relDepth) - } - - where, args := tables.getCondSQL(cond, false, tz) - - values = append(values, args...) - - join := tables.getJoinSQL() - - var query, T string - - Q := d.ins.TableQuote() - - if d.ins.SupportUpdateJoin() { - T = "T0." - } - - cols := make([]string, 0, len(columns)) - - for i, v := range columns { - col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q) - if c, ok := values[i].(colValue); ok { - switch c.opt { - case ColAdd: - cols = append(cols, col+" = "+col+" + ?") - case ColMinus: - cols = append(cols, col+" = "+col+" - ?") - case ColMultiply: - cols = append(cols, col+" = "+col+" * ?") - case ColExcept: - cols = append(cols, col+" = "+col+" / ?") - case ColBitAnd: - cols = append(cols, col+" = "+col+" & ?") - case ColBitRShift: - cols = append(cols, col+" = "+col+" >> ?") - case ColBitLShift: - cols = append(cols, col+" = "+col+" << ?") - case ColBitXOR: - cols = append(cols, col+" = "+col+" ^ ?") - case ColBitOr: - cols = append(cols, col+" = "+col+" | ?") - } - values[i] = c.value - } else { - cols = append(cols, col+" = ?") - } - } - - sets := strings.Join(cols, ", ") + " " - - if d.ins.SupportUpdateJoin() { - query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET %s%s", Q, mi.table, Q, join, sets, where) - } else { - supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where) - query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery) - } - - d.ins.ReplaceMarks(&query) - var err error - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, values...) - } else { - res, err = q.Exec(query, values...) - } - if err == nil { - return res.RowsAffected() - } - return 0, err -} - -// delete related records. -// do UpdateBanch or DeleteBanch by condition of tables' relationship. -func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { - for _, fi := range mi.fields.fieldsReverse { - fi = fi.reverseFieldInfo - switch fi.onDelete { - case odCascade: - cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) - _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) - if err != nil { - return err - } - case odSetDefault, odSetNULL: - cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) - params := Params{fi.column: nil} - if fi.onDelete == odSetDefault { - params[fi.column] = fi.initial.String() - } - _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) - if err != nil { - return err - } - case odDoNothing: - } - } - return nil -} - -// delete table-related records. -func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { - tables := newDbTables(mi, d.ins) - tables.skipEnd = true - - if qs != nil { - tables.parseRelated(qs.related, qs.relDepth) - } - - if cond == nil || cond.IsEmpty() { - panic(fmt.Errorf("delete operation cannot execute without condition")) - } - - Q := d.ins.TableQuote() - - where, args := tables.getCondSQL(cond, false, tz) - join := tables.getJoinSQL() - - cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) - query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) - - d.ins.ReplaceMarks(&query) - - var rs *sql.Rows - r, err := q.Query(query, args...) - if err != nil { - return 0, err - } - rs = r - defer rs.Close() - - var ref interface{} - args = make([]interface{}, 0) - cnt := 0 - for rs.Next() { - if err := rs.Scan(&ref); err != nil { - return 0, err - } - pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz) - if err != nil { - return 0, err - } - args = append(args, pkValue) - cnt++ - } - - if cnt == 0 { - return 0, nil - } - - marks := make([]string, len(args)) - for i := range marks { - marks[i] = "?" - } - sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) - query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) - - d.ins.ReplaceMarks(&query) - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, args...) - } else { - res, err = q.Exec(query, args...) - } - if err == nil { - num, err := res.RowsAffected() - if err != nil { - return 0, err - } - if num > 0 { - err := d.deleteRels(q, mi, args, tz) - if err != nil { - return num, err - } - } - return num, nil - } - return 0, err -} - -// read related records. -func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { - - val := reflect.ValueOf(container) - ind := reflect.Indirect(val) - - errTyp := true - one := true - isPtr := true - - if val.Kind() == reflect.Ptr { - fn := "" - if ind.Kind() == reflect.Slice { - one = false - typ := ind.Type().Elem() - switch typ.Kind() { - case reflect.Ptr: - fn = getFullName(typ.Elem()) - case reflect.Struct: - isPtr = false - fn = getFullName(typ) - } - } else { - fn = getFullName(ind.Type()) - } - errTyp = fn != mi.fullName - } - - if errTyp { - if one { - panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName)) - } else { - panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName)) - } - } - - rlimit := qs.limit - offset := qs.offset - - Q := d.ins.TableQuote() - - var tCols []string - if len(cols) > 0 { - hasRel := len(qs.related) > 0 || qs.relDepth > 0 - tCols = make([]string, 0, len(cols)) - var maps map[string]bool - if hasRel { - maps = make(map[string]bool) - } - for _, col := range cols { - if fi, ok := mi.fields.GetByAny(col); ok { - tCols = append(tCols, fi.column) - if hasRel { - maps[fi.column] = true - } - } else { - return 0, fmt.Errorf("wrong field/column name `%s`", col) - } - } - if hasRel { - for _, fi := range mi.fields.fieldsDB { - if fi.fieldType&IsRelField > 0 { - if !maps[fi.column] { - tCols = append(tCols, fi.column) - } - } - } - } - } else { - tCols = mi.fields.dbcols - } - - colsNum := len(tCols) - sep := fmt.Sprintf("%s, T0.%s", Q, Q) - sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q) - - tables := newDbTables(mi, d.ins) - tables.parseRelated(qs.related, qs.relDepth) - - where, args := tables.getCondSQL(cond, false, tz) - groupBy := tables.getGroupSQL(qs.groups) - orderBy := tables.getOrderSQL(qs.orders) - limit := tables.getLimitSQL(mi, offset, rlimit) - join := tables.getJoinSQL() - - for _, tbl := range tables.tables { - if tbl.sel { - colsNum += len(tbl.mi.fields.dbcols) - sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q) - sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q) - } - } - - sqlSelect := "SELECT" - if qs.distinct { - sqlSelect += " DISTINCT" - } - query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) - - if qs.forupdate { - query += " FOR UPDATE" - } - - d.ins.ReplaceMarks(&query) - - var rs *sql.Rows - var err error - if qs != nil && qs.forContext { - rs, err = q.QueryContext(qs.ctx, query, args...) - if err != nil { - return 0, err - } - } else { - rs, err = q.Query(query, args...) - if err != nil { - return 0, err - } - } - - refs := make([]interface{}, colsNum) - for i := range refs { - var ref interface{} - refs[i] = &ref - } - - defer rs.Close() - - slice := ind - - var cnt int64 - for rs.Next() { - if one && cnt == 0 || !one { - if err := rs.Scan(refs...); err != nil { - return 0, err - } - - elm := reflect.New(mi.addrField.Elem().Type()) - mind := reflect.Indirect(elm) - - cacheV := make(map[string]*reflect.Value) - cacheM := make(map[string]*modelInfo) - trefs := refs - - d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz) - trefs = refs[len(tCols):] - - for _, tbl := range tables.tables { - // loop selected tables - if tbl.sel { - last := mind - names := "" - mmi := mi - // loop cascade models - for _, name := range tbl.names { - names += name - if val, ok := cacheV[names]; ok { - last = *val - mmi = cacheM[names] - } else { - fi := mmi.fields.GetByName(name) - lastm := mmi - mmi = fi.relModelInfo - field := last - if last.Kind() != reflect.Invalid { - field = reflect.Indirect(last.FieldByIndex(fi.fieldIndex)) - if field.IsValid() { - d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz) - for _, fi := range mmi.fields.fieldsReverse { - if fi.inModel && fi.reverseFieldInfo.mi == lastm { - if fi.reverseFieldInfo != nil { - f := field.FieldByIndex(fi.fieldIndex) - if f.Kind() == reflect.Ptr { - f.Set(last.Addr()) - } - } - } - } - last = field - } - } - cacheV[names] = &field - cacheM[names] = mmi - } - } - trefs = trefs[len(mmi.fields.dbcols):] - } - } - - if one { - ind.Set(mind) - } else { - if cnt == 0 { - // you can use a empty & caped container list - // orm will not replace it - if ind.Len() != 0 { - // if container is not empty - // create a new one - slice = reflect.New(ind.Type()).Elem() - } - } - - if isPtr { - slice = reflect.Append(slice, mind.Addr()) - } else { - slice = reflect.Append(slice, mind) - } - } - } - cnt++ - } - - if !one { - if cnt > 0 { - ind.Set(slice) - } else { - // when a result is empty and container is nil - // to set a empty container - if ind.IsNil() { - ind.Set(reflect.MakeSlice(ind.Type(), 0, 0)) - } - } - } - - return cnt, nil -} - -// excute count sql and return count result int64. -func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { - tables := newDbTables(mi, d.ins) - tables.parseRelated(qs.related, qs.relDepth) - - where, args := tables.getCondSQL(cond, false, tz) - groupBy := tables.getGroupSQL(qs.groups) - tables.getOrderSQL(qs.orders) - join := tables.getJoinSQL() - - Q := d.ins.TableQuote() - - query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s", Q, mi.table, Q, join, where, groupBy) - - if groupBy != "" { - query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query) - } - - d.ins.ReplaceMarks(&query) - - var row *sql.Row - if qs != nil && qs.forContext { - row = q.QueryRowContext(qs.ctx, query, args...) - } else { - row = q.QueryRow(query, args...) - } - err = row.Scan(&cnt) - return -} - -// generate sql with replacing operator string placeholders and replaced values. -func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { - var sql string - params := getFlatParams(fi, args, tz) - - if len(params) == 0 { - panic(fmt.Errorf("operator `%s` need at least one args", operator)) - } - arg := params[0] - - switch operator { - case "in": - marks := make([]string, len(params)) - for i := range marks { - marks[i] = "?" - } - sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) - case "between": - if len(params) != 2 { - panic(fmt.Errorf("operator `%s` need 2 args not %d", operator, len(params))) - } - sql = "BETWEEN ? AND ?" - default: - if len(params) > 1 { - panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params))) - } - sql = d.ins.OperatorSQL(operator) - switch operator { - case "exact": - if arg == nil { - params[0] = "IS NULL" - } - case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith": - param := strings.Replace(ToStr(arg), `%`, `\%`, -1) - switch operator { - case "iexact": - case "contains", "icontains": - param = fmt.Sprintf("%%%s%%", param) - case "startswith", "istartswith": - param = fmt.Sprintf("%s%%", param) - case "endswith", "iendswith": - param = fmt.Sprintf("%%%s", param) - } - params[0] = param - case "isnull": - if b, ok := arg.(bool); ok { - if b { - sql = "IS NULL" - } else { - sql = "IS NOT NULL" - } - params = nil - } else { - panic(fmt.Errorf("operator `%s` need a bool value not `%T`", operator, arg)) - } - } - } - return sql, params -} - -// gernerate sql string with inner function, such as UPPER(text). -func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { - // default not use -} - -// set values to struct column. -func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { - for i, column := range cols { - val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() - - fi := mi.fields.GetByColumn(column) - - field := ind.FieldByIndex(fi.fieldIndex) - - value, err := d.convertValueFromDB(fi, val, tz) - if err != nil { - panic(fmt.Errorf("Raw value: `%v` %s", val, err.Error())) - } - - _, err = d.setFieldValue(fi, value, field) - - if err != nil { - panic(fmt.Errorf("Raw value: `%v` %s", val, err.Error())) - } - } -} - -// convert value from database result to value following in field type. -func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { - if val == nil { - return nil, nil - } - - var value interface{} - var tErr error - - var str *StrTo - switch v := val.(type) { - case []byte: - s := StrTo(string(v)) - str = &s - case string: - s := StrTo(v) - str = &s - } - - fieldType := fi.fieldType - -setValue: - switch { - case fieldType == TypeBooleanField: - if str == nil { - switch v := val.(type) { - case int64: - b := v == 1 - value = b - default: - s := StrTo(ToStr(v)) - str = &s - } - } - if str != nil { - b, err := str.Bool() - if err != nil { - tErr = err - goto end - } - value = b - } - case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: - if str == nil { - value = ToStr(val) - } else { - value = str.String() - } - case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: - if str == nil { - switch t := val.(type) { - case time.Time: - d.ins.TimeFromDB(&t, tz) - value = t - default: - s := StrTo(ToStr(t)) - str = &s - } - } - if str != nil { - s := str.String() - var ( - t time.Time - err error - ) - if len(s) >= 19 { - s = s[:19] - t, err = time.ParseInLocation(formatDateTime, s, tz) - } else if len(s) >= 10 { - if len(s) > 10 { - s = s[:10] - } - t, err = time.ParseInLocation(formatDate, s, tz) - } else if len(s) >= 8 { - if len(s) > 8 { - s = s[:8] - } - t, err = time.ParseInLocation(formatTime, s, tz) - } - t = t.In(DefaultTimeLoc) - - if err != nil && s != "00:00:00" && s != "0000-00-00" && s != "0000-00-00 00:00:00" { - tErr = err - goto end - } - value = t - } - case fieldType&IsIntegerField > 0: - if str == nil { - s := StrTo(ToStr(val)) - str = &s - } - if str != nil { - var err error - switch fieldType { - case TypeBitField: - _, err = str.Int8() - case TypeSmallIntegerField: - _, err = str.Int16() - case TypeIntegerField: - _, err = str.Int32() - case TypeBigIntegerField: - _, err = str.Int64() - case TypePositiveBitField: - _, err = str.Uint8() - case TypePositiveSmallIntegerField: - _, err = str.Uint16() - case TypePositiveIntegerField: - _, err = str.Uint32() - case TypePositiveBigIntegerField: - _, err = str.Uint64() - } - if err != nil { - tErr = err - goto end - } - if fieldType&IsPositiveIntegerField > 0 { - v, _ := str.Uint64() - value = v - } else { - v, _ := str.Int64() - value = v - } - } - case fieldType == TypeFloatField || fieldType == TypeDecimalField: - if str == nil { - switch v := val.(type) { - case float64: - value = v - default: - s := StrTo(ToStr(v)) - str = &s - } - } - if str != nil { - v, err := str.Float64() - if err != nil { - tErr = err - goto end - } - value = v - } - case fieldType&IsRelField > 0: - fi = fi.relModelInfo.fields.pk - fieldType = fi.fieldType - goto setValue - } - -end: - if tErr != nil { - err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.addrValue.Type(), fi.fullName, tErr) - return nil, err - } - - return value, nil - -} - -// set one value to struct column field. -func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { - - fieldType := fi.fieldType - isNative := !fi.isFielder - -setValue: - switch { - case fieldType == TypeBooleanField: - if isNative { - if nb, ok := field.Interface().(sql.NullBool); ok { - if value == nil { - nb.Valid = false - } else { - nb.Bool = value.(bool) - nb.Valid = true - } - field.Set(reflect.ValueOf(nb)) - } else if field.Kind() == reflect.Ptr { - if value != nil { - v := value.(bool) - field.Set(reflect.ValueOf(&v)) - } - } else { - if value == nil { - value = false - } - field.SetBool(value.(bool)) - } - } - case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: - if isNative { - if ns, ok := field.Interface().(sql.NullString); ok { - if value == nil { - ns.Valid = false - } else { - ns.String = value.(string) - ns.Valid = true - } - field.Set(reflect.ValueOf(ns)) - } else if field.Kind() == reflect.Ptr { - if value != nil { - v := value.(string) - field.Set(reflect.ValueOf(&v)) - } - } else { - if value == nil { - value = "" - } - field.SetString(value.(string)) - } - } - case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: - if isNative { - if value == nil { - value = time.Time{} - } else if field.Kind() == reflect.Ptr { - if value != nil { - v := value.(time.Time) - field.Set(reflect.ValueOf(&v)) - } - } else { - field.Set(reflect.ValueOf(value)) - } - } - case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr: - if value != nil { - v := uint8(value.(uint64)) - field.Set(reflect.ValueOf(&v)) - } - case fieldType == TypePositiveSmallIntegerField && field.Kind() == reflect.Ptr: - if value != nil { - v := uint16(value.(uint64)) - field.Set(reflect.ValueOf(&v)) - } - case fieldType == TypePositiveIntegerField && field.Kind() == reflect.Ptr: - if value != nil { - if field.Type() == reflect.TypeOf(new(uint)) { - v := uint(value.(uint64)) - field.Set(reflect.ValueOf(&v)) - } else { - v := uint32(value.(uint64)) - field.Set(reflect.ValueOf(&v)) - } - } - case fieldType == TypePositiveBigIntegerField && field.Kind() == reflect.Ptr: - if value != nil { - v := value.(uint64) - field.Set(reflect.ValueOf(&v)) - } - case fieldType == TypeBitField && field.Kind() == reflect.Ptr: - if value != nil { - v := int8(value.(int64)) - field.Set(reflect.ValueOf(&v)) - } - case fieldType == TypeSmallIntegerField && field.Kind() == reflect.Ptr: - if value != nil { - v := int16(value.(int64)) - field.Set(reflect.ValueOf(&v)) - } - case fieldType == TypeIntegerField && field.Kind() == reflect.Ptr: - if value != nil { - if field.Type() == reflect.TypeOf(new(int)) { - v := int(value.(int64)) - field.Set(reflect.ValueOf(&v)) - } else { - v := int32(value.(int64)) - field.Set(reflect.ValueOf(&v)) - } - } - case fieldType == TypeBigIntegerField && field.Kind() == reflect.Ptr: - if value != nil { - v := value.(int64) - field.Set(reflect.ValueOf(&v)) - } - case fieldType&IsIntegerField > 0: - if fieldType&IsPositiveIntegerField > 0 { - if isNative { - if value == nil { - value = uint64(0) - } - field.SetUint(value.(uint64)) - } - } else { - if isNative { - if ni, ok := field.Interface().(sql.NullInt64); ok { - if value == nil { - ni.Valid = false - } else { - ni.Int64 = value.(int64) - ni.Valid = true - } - field.Set(reflect.ValueOf(ni)) - } else { - if value == nil { - value = int64(0) - } - field.SetInt(value.(int64)) - } - } - } - case fieldType == TypeFloatField || fieldType == TypeDecimalField: - if isNative { - if nf, ok := field.Interface().(sql.NullFloat64); ok { - if value == nil { - nf.Valid = false - } else { - nf.Float64 = value.(float64) - nf.Valid = true - } - field.Set(reflect.ValueOf(nf)) - } else if field.Kind() == reflect.Ptr { - if value != nil { - if field.Type() == reflect.TypeOf(new(float32)) { - v := float32(value.(float64)) - field.Set(reflect.ValueOf(&v)) - } else { - v := value.(float64) - field.Set(reflect.ValueOf(&v)) - } - } - } else { - - if value == nil { - value = float64(0) - } - field.SetFloat(value.(float64)) - } - } - case fieldType&IsRelField > 0: - if value != nil { - fieldType = fi.relModelInfo.fields.pk.fieldType - mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) - field.Set(mf) - f := mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) - field = f - goto setValue - } - } - - if !isNative { - fd := field.Addr().Interface().(Fielder) - err := fd.SetRaw(value) - if err != nil { - err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.fullName, err) - return nil, err - } - } - - return value, nil -} - -// query sql, read values , save to *[]ParamList. -func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { - - var ( - maps []Params - lists []ParamsList - list ParamsList - ) - - typ := 0 - switch v := container.(type) { - case *[]Params: - d := *v - if len(d) == 0 { - maps = d - } - typ = 1 - case *[]ParamsList: - d := *v - if len(d) == 0 { - lists = d - } - typ = 2 - case *ParamsList: - d := *v - if len(d) == 0 { - list = d - } - typ = 3 - default: - panic(fmt.Errorf("unsupport read values type `%T`", container)) - } - - tables := newDbTables(mi, d.ins) - - var ( - cols []string - infos []*fieldInfo - ) - - hasExprs := len(exprs) > 0 - - Q := d.ins.TableQuote() - - if hasExprs { - cols = make([]string, 0, len(exprs)) - infos = make([]*fieldInfo, 0, len(exprs)) - for _, ex := range exprs { - index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) - if !suc { - panic(fmt.Errorf("unknown field/column name `%s`", ex)) - } - cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) - infos = append(infos, fi) - } - } else { - cols = make([]string, 0, len(mi.fields.dbcols)) - infos = make([]*fieldInfo, 0, len(exprs)) - for _, fi := range mi.fields.fieldsDB { - cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.column, Q, Q, fi.name, Q)) - infos = append(infos, fi) - } - } - - where, args := tables.getCondSQL(cond, false, tz) - groupBy := tables.getGroupSQL(qs.groups) - orderBy := tables.getOrderSQL(qs.orders) - limit := tables.getLimitSQL(mi, qs.offset, qs.limit) - join := tables.getJoinSQL() - - sels := strings.Join(cols, ", ") - - sqlSelect := "SELECT" - if qs.distinct { - sqlSelect += " DISTINCT" - } - query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) - - d.ins.ReplaceMarks(&query) - - rs, err := q.Query(query, args...) - if err != nil { - return 0, err - } - refs := make([]interface{}, len(cols)) - for i := range refs { - var ref interface{} - refs[i] = &ref - } - - defer rs.Close() - - var ( - cnt int64 - columns []string - ) - for rs.Next() { - if cnt == 0 { - cols, err := rs.Columns() - if err != nil { - return 0, err - } - columns = cols - } - - if err := rs.Scan(refs...); err != nil { - return 0, err - } - - switch typ { - case 1: - params := make(Params, len(cols)) - for i, ref := range refs { - fi := infos[i] - - val := reflect.Indirect(reflect.ValueOf(ref)).Interface() - - value, err := d.convertValueFromDB(fi, val, tz) - if err != nil { - panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) - } - - params[columns[i]] = value - } - maps = append(maps, params) - case 2: - params := make(ParamsList, 0, len(cols)) - for i, ref := range refs { - fi := infos[i] - - val := reflect.Indirect(reflect.ValueOf(ref)).Interface() - - value, err := d.convertValueFromDB(fi, val, tz) - if err != nil { - panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) - } - - params = append(params, value) - } - lists = append(lists, params) - case 3: - for i, ref := range refs { - fi := infos[i] - - val := reflect.Indirect(reflect.ValueOf(ref)).Interface() - - value, err := d.convertValueFromDB(fi, val, tz) - if err != nil { - panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) - } - - list = append(list, value) - } - } - - cnt++ - } - - switch v := container.(type) { - case *[]Params: - *v = maps - case *[]ParamsList: - *v = lists - case *ParamsList: - *v = list - } - - return cnt, nil -} - -func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) { - return 0, nil -} - -// flag of update joined record. -func (d *dbBase) SupportUpdateJoin() bool { - return true -} - -func (d *dbBase) MaxLimit() uint64 { - return 18446744073709551615 -} - -// return quote. -func (d *dbBase) TableQuote() string { - return "`" -} - -// replace value placeholder in parametered sql string. -func (d *dbBase) ReplaceMarks(query *string) { - // default use `?` as mark, do nothing -} - -// flag of RETURNING sql. -func (d *dbBase) HasReturningID(*modelInfo, *string) bool { - return false -} - -// sync auto key -func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { - return nil -} - -// convert time from db. -func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { - *t = t.In(tz) -} - -// convert time to db. -func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { - *t = t.In(tz) -} - -// get database types. -func (d *dbBase) DbTypes() map[string]string { - return nil -} - -// gt all tables. -func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { - tables := make(map[string]bool) - query := d.ins.ShowTablesQuery() - rows, err := db.Query(query) - if err != nil { - return tables, err - } - - defer rows.Close() - - for rows.Next() { - var table string - err := rows.Scan(&table) - if err != nil { - return tables, err - } - if table != "" { - tables[table] = true - } - } - - return tables, nil -} - -// get all cloumns in table. -func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { - columns := make(map[string][3]string) - query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) - if err != nil { - return columns, err - } - - defer rows.Close() - - for rows.Next() { - var ( - name string - typ string - null string - ) - err := rows.Scan(&name, &typ, &null) - if err != nil { - return columns, err - } - columns[name] = [3]string{name, typ, null} - } - - return columns, nil -} - -// not implement. -func (d *dbBase) OperatorSQL(operator string) string { - panic(ErrNotImplement) -} - -// not implement. -func (d *dbBase) ShowTablesQuery() string { - panic(ErrNotImplement) -} - -// not implement. -func (d *dbBase) ShowColumnsQuery(table string) string { - panic(ErrNotImplement) -} - -// not implement. -func (d *dbBase) IndexExists(dbQuerier, string, string) bool { - panic(ErrNotImplement) -} diff --git a/orm/db_alias.go b/orm/db_alias.go deleted file mode 100644 index d3dbc595..00000000 --- a/orm/db_alias.go +++ /dev/null @@ -1,487 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "context" - "database/sql" - "fmt" - lru "github.com/hashicorp/golang-lru" - "reflect" - "sync" - "time" -) - -// DriverType database driver constant int. -type DriverType int - -// Enum the Database driver -const ( - _ DriverType = iota // int enum type - DRMySQL // mysql - DRSqlite // sqlite - DROracle // oracle - DRPostgres // pgsql - DRTiDB // TiDB -) - -// database driver string. -type driver string - -// get type constant int of current driver.. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d driver) Type() DriverType { - a, _ := dataBaseCache.get(string(d)) - return a.Driver -} - -// get name of current driver -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d driver) Name() string { - return string(d) -} - -// check driver iis implemented Driver interface or not. -var _ Driver = new(driver) - -var ( - dataBaseCache = &_dbCache{cache: make(map[string]*alias)} - drivers = map[string]DriverType{ - "mysql": DRMySQL, - "postgres": DRPostgres, - "sqlite3": DRSqlite, - "tidb": DRTiDB, - "oracle": DROracle, - "oci8": DROracle, // github.com/mattn/go-oci8 - "ora": DROracle, //https://github.com/rana/ora - } - dbBasers = map[DriverType]dbBaser{ - DRMySQL: newdbBaseMysql(), - DRSqlite: newdbBaseSqlite(), - DROracle: newdbBaseOracle(), - DRPostgres: newdbBasePostgres(), - DRTiDB: newdbBaseTidb(), - } -) - -// database alias cacher. -type _dbCache struct { - mux sync.RWMutex - cache map[string]*alias -} - -// add database alias with original name. -func (ac *_dbCache) add(name string, al *alias) (added bool) { - ac.mux.Lock() - defer ac.mux.Unlock() - if _, ok := ac.cache[name]; !ok { - ac.cache[name] = al - added = true - } - return -} - -// get database alias if cached. -func (ac *_dbCache) get(name string) (al *alias, ok bool) { - ac.mux.RLock() - defer ac.mux.RUnlock() - al, ok = ac.cache[name] - return -} - -// get default alias. -func (ac *_dbCache) getDefault() (al *alias) { - al, _ = ac.get("default") - return -} - -type DB struct { - *sync.RWMutex - DB *sql.DB - stmtDecorators *lru.Cache -} - -// Begin start a transaction -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) Begin() (*sql.Tx, error) { - return d.DB.Begin() -} - -// BeginTx start a transaction with context and those options -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { - return d.DB.BeginTx(ctx, opts) -} - -// su must call release to release *sql.Stmt after using -func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { - d.RLock() - c, ok := d.stmtDecorators.Get(query) - if ok { - c.(*stmtDecorator).acquire() - d.RUnlock() - return c.(*stmtDecorator), nil - } - d.RUnlock() - - d.Lock() - c, ok = d.stmtDecorators.Get(query) - if ok { - c.(*stmtDecorator).acquire() - d.Unlock() - return c.(*stmtDecorator), nil - } - - stmt, err := d.Prepare(query) - if err != nil { - d.Unlock() - return nil, err - } - sd := newStmtDecorator(stmt) - sd.acquire() - d.stmtDecorators.Add(query, sd) - d.Unlock() - - return sd, nil -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) Prepare(query string) (*sql.Stmt, error) { - return d.DB.Prepare(query) -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { - return d.DB.PrepareContext(ctx, query) -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.Exec(args...) -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.ExecContext(ctx, args...) -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.Query(args...) -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.QueryContext(ctx, args...) -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { - sd, err := d.getStmtDecorator(query) - if err != nil { - panic(err) - } - stmt := sd.getStmt() - defer sd.release() - return stmt.QueryRow(args...) - -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - sd, err := d.getStmtDecorator(query) - if err != nil { - panic(err) - } - stmt := sd.getStmt() - defer sd.release() - return stmt.QueryRowContext(ctx, args) -} - -type alias struct { - Name string - Driver DriverType - DriverName string - DataSource string - MaxIdleConns int - MaxOpenConns int - DB *DB - DbBaser dbBaser - TZ *time.Location - Engine string -} - -func detectTZ(al *alias) { - // orm timezone system match database - // default use Local - al.TZ = DefaultTimeLoc - - if al.DriverName == "sphinx" { - return - } - - switch al.Driver { - case DRMySQL: - row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") - var tz string - row.Scan(&tz) - if len(tz) >= 8 { - if tz[0] != '-' { - tz = "+" + tz - } - t, err := time.Parse("-07:00:00", tz) - if err == nil { - if t.Location().String() != "" { - al.TZ = t.Location() - } - } else { - DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) - } - } - - // get default engine from current database - row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'") - var engine string - var tx bool - row.Scan(&engine, &tx) - - if engine != "" { - al.Engine = engine - } else { - al.Engine = "INNODB" - } - - case DRSqlite, DROracle: - al.TZ = time.UTC - - case DRPostgres: - row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") - var tz string - row.Scan(&tz) - loc, err := time.LoadLocation(tz) - if err == nil { - al.TZ = loc - } else { - DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) - } - } -} - -func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { - al := new(alias) - al.Name = aliasName - al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), - } - - if dr, ok := drivers[driverName]; ok { - al.DbBaser = dbBasers[dr] - al.Driver = dr - } else { - return nil, fmt.Errorf("driver name `%s` have not registered", driverName) - } - - err := db.Ping() - if err != nil { - return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) - } - - if !dataBaseCache.add(aliasName, al) { - return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) - } - - return al, nil -} - -// AddAliasWthDB add a aliasName for the drivename -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { - _, err := addAliasWthDB(aliasName, driverName, db) - return err -} - -// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { - var ( - err error - db *sql.DB - al *alias - ) - - db, err = sql.Open(driverName, dataSource) - if err != nil { - err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) - goto end - } - - al, err = addAliasWthDB(aliasName, driverName, db) - if err != nil { - goto end - } - - al.DataSource = dataSource - - detectTZ(al) - - for i, v := range params { - switch i { - case 0: - SetMaxIdleConns(al.Name, v) - case 1: - SetMaxOpenConns(al.Name, v) - } - } - -end: - if err != nil { - if db != nil { - db.Close() - } - DebugLog.Println(err.Error()) - } - - return err -} - -// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func RegisterDriver(driverName string, typ DriverType) error { - if t, ok := drivers[driverName]; !ok { - drivers[driverName] = typ - } else { - if t != typ { - return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName) - } - } - return nil -} - -// SetDataBaseTZ Change the database default used timezone -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func SetDataBaseTZ(aliasName string, tz *time.Location) error { - if al, ok := dataBaseCache.get(aliasName); ok { - al.TZ = tz - } else { - return fmt.Errorf("DataBase alias name `%s` not registered", aliasName) - } - return nil -} - -// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func SetMaxIdleConns(aliasName string, maxIdleConns int) { - al := getDbAlias(aliasName) - al.MaxIdleConns = maxIdleConns - al.DB.DB.SetMaxIdleConns(maxIdleConns) -} - -// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func SetMaxOpenConns(aliasName string, maxOpenConns int) { - al := getDbAlias(aliasName) - al.MaxOpenConns = maxOpenConns - al.DB.DB.SetMaxOpenConns(maxOpenConns) - // for tip go 1.2 - if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() { - fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) - } -} - -// GetDB Get *sql.DB from registered database by db alias name. -// Use "default" as alias name if you not set. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func GetDB(aliasNames ...string) (*sql.DB, error) { - var name string - if len(aliasNames) > 0 { - name = aliasNames[0] - } else { - name = "default" - } - al, ok := dataBaseCache.get(name) - if ok { - return al.DB.DB, nil - } - return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) -} - -type stmtDecorator struct { - wg sync.WaitGroup - stmt *sql.Stmt -} - -func (s *stmtDecorator) getStmt() *sql.Stmt { - return s.stmt -} - -// acquire will add one -// since this method will be used inside read lock scope, -// so we can not do more things here -// we should think about refactor this -func (s *stmtDecorator) acquire() { - s.wg.Add(1) -} - -func (s *stmtDecorator) release() { - s.wg.Done() -} - -//garbage recycle for stmt -func (s *stmtDecorator) destroy() { - go func() { - s.wg.Wait() - _ = s.stmt.Close() - }() -} - -func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { - return &stmtDecorator{ - stmt: sqlStmt, - } -} - -func newStmtDecoratorLruWithEvict() *lru.Cache { - cache, _ := lru.NewWithEvict(1000, func(key interface{}, value interface{}) { - value.(*stmtDecorator).destroy() - }) - return cache -} diff --git a/orm/db_mysql.go b/orm/db_mysql.go deleted file mode 100644 index 36f6f566..00000000 --- a/orm/db_mysql.go +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "reflect" - "strings" -) - -// mysql operators. -var mysqlOperators = map[string]string{ - "exact": "= ?", - "iexact": "LIKE ?", - "contains": "LIKE BINARY ?", - "icontains": "LIKE ?", - // "regex": "REGEXP BINARY ?", - // "iregex": "REGEXP ?", - "gt": "> ?", - ">": "> ?", - "gte": ">= ?", - ">=": ">= ?", - "lt": "< ?", - "<": "< ?", - "lte": "<= ?", - "<=": "<= ?", - "eq": "= ?", - "=": "= ?", - "ne": "!= ?", - "!=": "!= ?", - "startswith": "LIKE BINARY ?", - "endswith": "LIKE BINARY ?", - "istartswith": "LIKE ?", - "iendswith": "LIKE ?", -} - -// mysql column field types. -var mysqlTypes = map[string]string{ - "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "char(%d)", - "string-text": "longtext", - "time.Time-date": "date", - "time.Time": "datetime", - "int8": "tinyint", - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": "tinyint unsigned", - "uint16": "smallint unsigned", - "uint32": "integer unsigned", - "uint64": "bigint unsigned", - "float64": "double precision", - "float64-decimal": "numeric(%d, %d)", -} - -// mysql dbBaser implementation. -type dbBaseMysql struct { - dbBase -} - -var _ dbBaser = new(dbBaseMysql) - -// get mysql operator. -func (d *dbBaseMysql) OperatorSQL(operator string) string { - return mysqlOperators[operator] -} - -// get mysql table field types. -func (d *dbBaseMysql) DbTypes() map[string]string { - return mysqlTypes -} - -// show table sql for mysql. -func (d *dbBaseMysql) ShowTablesQuery() string { - return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" -} - -// show columns sql of table for mysql. -func (d *dbBaseMysql) ShowColumnsQuery(table string) string { - return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ - "WHERE table_schema = DATABASE() AND table_name = '%s'", table) -} - -// execute sql to check index exist. -func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ - "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) - var cnt int - row.Scan(&cnt) - return cnt > 0 -} - -// InsertOrUpdate a row -// If your primary key or unique column conflict will update -// If no will insert -// Add "`" for mysql sql building -func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { - var iouStr string - argsMap := map[string]string{} - - iouStr = "ON DUPLICATE KEY UPDATE" - - //Get on the key-value pairs - for _, v := range args { - kv := strings.Split(v, "=") - if len(kv) == 2 { - argsMap[strings.ToLower(kv[0])] = kv[1] - } - } - - isMulti := false - names := make([]string, 0, len(mi.fields.dbcols)-1) - Q := d.ins.TableQuote() - values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) - - if err != nil { - return 0, err - } - - marks := make([]string, len(names)) - updateValues := make([]interface{}, 0) - updates := make([]string, len(names)) - - for i, v := range names { - marks[i] = "?" - valueStr := argsMap[strings.ToLower(v)] - if valueStr != "" { - updates[i] = "`" + v + "`" + "=" + valueStr - } else { - updates[i] = "`" + v + "`" + "=?" - updateValues = append(updateValues, values[i]) - } - } - - values = append(values, updateValues...) - - sep := fmt.Sprintf("%s, %s", Q, Q) - qmarks := strings.Join(marks, ", ") - qupdates := strings.Join(updates, ", ") - columns := strings.Join(names, sep) - - multi := len(values) / len(names) - - if isMulti { - qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks - } - //conflitValue maybe is a int,can`t use fmt.Sprintf - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) - - d.ins.ReplaceMarks(&query) - - if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) - if err == nil { - if isMulti { - return res.RowsAffected() - } - return res.LastInsertId() - } - return 0, err - } - - row := q.QueryRow(query, values...) - var id int64 - err = row.Scan(&id) - return id, err -} - -// create new mysql dbBaser. -func newdbBaseMysql() dbBaser { - b := new(dbBaseMysql) - b.ins = b - return b -} diff --git a/orm/db_oracle.go b/orm/db_oracle.go deleted file mode 100644 index ed2ec74c..00000000 --- a/orm/db_oracle.go +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strings" -) - -// oracle operators. -var oracleOperators = map[string]string{ - "exact": "= ?", - "=": "= ?", - "gt": "> ?", - ">": "> ?", - "gte": ">= ?", - ">=": ">= ?", - "lt": "< ?", - "<": "< ?", - "lte": "<= ?", - "<=": "<= ?", - "//iendswith": "LIKE ?", -} - -// oracle column field types. -var oracleTypes = map[string]string{ - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "VARCHAR2(%d)", - "string-char": "CHAR(%d)", - "string-text": "VARCHAR2(%d)", - "time.Time-date": "DATE", - "time.Time": "TIMESTAMP", - "int8": "INTEGER", - "int16": "INTEGER", - "int32": "INTEGER", - "int64": "INTEGER", - "uint8": "INTEGER", - "uint16": "INTEGER", - "uint32": "INTEGER", - "uint64": "INTEGER", - "float64": "NUMBER", - "float64-decimal": "NUMBER(%d, %d)", -} - -// oracle dbBaser -type dbBaseOracle struct { - dbBase -} - -var _ dbBaser = new(dbBaseOracle) - -// create oracle dbBaser. -func newdbBaseOracle() dbBaser { - b := new(dbBaseOracle) - b.ins = b - return b -} - -// OperatorSQL get oracle operator. -func (d *dbBaseOracle) OperatorSQL(operator string) string { - return oracleOperators[operator] -} - -// DbTypes get oracle table field types. -func (d *dbBaseOracle) DbTypes() map[string]string { - return oracleTypes -} - -//ShowTablesQuery show all the tables in database -func (d *dbBaseOracle) ShowTablesQuery() string { - return "SELECT TABLE_NAME FROM USER_TABLES" -} - -// Oracle -func (d *dbBaseOracle) ShowColumnsQuery(table string) string { - return fmt.Sprintf("SELECT COLUMN_NAME FROM ALL_TAB_COLUMNS "+ - "WHERE TABLE_NAME ='%s'", strings.ToUpper(table)) -} - -// check index is exist -func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ - "WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ - "AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) - - var cnt int - row.Scan(&cnt) - return cnt > 0 -} - -// execute insert sql with given struct and given values. -// insert the given values, not the field values in struct. -func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { - Q := d.ins.TableQuote() - - marks := make([]string, len(names)) - for i := range marks { - marks[i] = ":" + names[i] - } - - sep := fmt.Sprintf("%s, %s", Q, Q) - qmarks := strings.Join(marks, ", ") - columns := strings.Join(names, sep) - - multi := len(values) / len(names) - - if isMulti { - qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks - } - - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) - - d.ins.ReplaceMarks(&query) - - if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) - if err == nil { - if isMulti { - return res.RowsAffected() - } - return res.LastInsertId() - } - return 0, err - } - row := q.QueryRow(query, values...) - var id int64 - err := row.Scan(&id) - return id, err -} diff --git a/orm/db_postgres.go b/orm/db_postgres.go deleted file mode 100644 index 7eb88d7a..00000000 --- a/orm/db_postgres.go +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strconv" -) - -// postgresql operators. -var postgresOperators = map[string]string{ - "exact": "= ?", - "iexact": "= UPPER(?)", - "contains": "LIKE ?", - "icontains": "LIKE UPPER(?)", - "gt": "> ?", - ">": "> ?", - "gte": ">= ?", - ">=": ">= ?", - "lt": "< ?", - "<": "< ?", - "lte": "<= ?", - "<=": "<= ?", - "eq": "= ?", - "=": "= ?", - "ne": "!= ?", - "!=": "!= ?", - "startswith": "LIKE ?", - "endswith": "LIKE ?", - "istartswith": "LIKE UPPER(?)", - "iendswith": "LIKE UPPER(?)", -} - -// postgresql column field types. -var postgresTypes = map[string]string{ - "auto": "serial NOT NULL PRIMARY KEY", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "char(%d)", - "string-text": "text", - "time.Time-date": "date", - "time.Time": "timestamp with time zone", - "int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`, - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`, - "uint16": `integer CHECK("%COL%" >= 0)`, - "uint32": `bigint CHECK("%COL%" >= 0)`, - "uint64": `bigint CHECK("%COL%" >= 0)`, - "float64": "double precision", - "float64-decimal": "numeric(%d, %d)", - "json": "json", - "jsonb": "jsonb", -} - -// postgresql dbBaser. -type dbBasePostgres struct { - dbBase -} - -var _ dbBaser = new(dbBasePostgres) - -// get postgresql operator. -func (d *dbBasePostgres) OperatorSQL(operator string) string { - return postgresOperators[operator] -} - -// generate functioned sql string, such as contains(text). -func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { - switch operator { - case "contains", "startswith", "endswith": - *leftCol = fmt.Sprintf("%s::text", *leftCol) - case "iexact", "icontains", "istartswith", "iendswith": - *leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol) - } -} - -// postgresql unsupports updating joined record. -func (d *dbBasePostgres) SupportUpdateJoin() bool { - return false -} - -func (d *dbBasePostgres) MaxLimit() uint64 { - return 0 -} - -// postgresql quote is ". -func (d *dbBasePostgres) TableQuote() string { - return `"` -} - -// postgresql value placeholder is $n. -// replace default ? to $n. -func (d *dbBasePostgres) ReplaceMarks(query *string) { - q := *query - num := 0 - for _, c := range q { - if c == '?' { - num++ - } - } - if num == 0 { - return - } - data := make([]byte, 0, len(q)+num) - num = 1 - for i := 0; i < len(q); i++ { - c := q[i] - if c == '?' { - data = append(data, '$') - data = append(data, []byte(strconv.Itoa(num))...) - num++ - } else { - data = append(data, c) - } - } - *query = string(data) -} - -// make returning sql support for postgresql. -func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { - fi := mi.fields.pk - if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 { - return false - } - - if query != nil { - *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column) - } - return true -} - -// sync auto key -func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { - if len(autoFields) == 0 { - return nil - } - - Q := d.ins.TableQuote() - for _, name := range autoFields { - query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));", - mi.table, name, - Q, name, Q, - Q, mi.table, Q) - if _, err := db.Exec(query); err != nil { - return err - } - } - return nil -} - -// show table sql for postgresql. -func (d *dbBasePostgres) ShowTablesQuery() string { - return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" -} - -// show table columns sql for postgresql. -func (d *dbBasePostgres) ShowColumnsQuery(table string) string { - return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) -} - -// get column types of postgresql. -func (d *dbBasePostgres) DbTypes() map[string]string { - return postgresTypes -} - -// check index exist in postgresql. -func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { - query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) - row := db.QueryRow(query) - var cnt int - row.Scan(&cnt) - return cnt > 0 -} - -// create new postgresql dbBaser. -func newdbBasePostgres() dbBaser { - b := new(dbBasePostgres) - b.ins = b - return b -} diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go deleted file mode 100644 index bd9f5d3b..00000000 --- a/orm/db_sqlite.go +++ /dev/null @@ -1,167 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "database/sql" - "fmt" - "reflect" - "time" -) - -// sqlite operators. -var sqliteOperators = map[string]string{ - "exact": "= ?", - "iexact": "LIKE ? ESCAPE '\\'", - "contains": "LIKE ? ESCAPE '\\'", - "icontains": "LIKE ? ESCAPE '\\'", - "gt": "> ?", - ">": "> ?", - "gte": ">= ?", - ">=": ">= ?", - "lt": "< ?", - "<": "< ?", - "lte": "<= ?", - "<=": "<= ?", - "eq": "= ?", - "=": "= ?", - "ne": "!= ?", - "!=": "!= ?", - "startswith": "LIKE ? ESCAPE '\\'", - "endswith": "LIKE ? ESCAPE '\\'", - "istartswith": "LIKE ? ESCAPE '\\'", - "iendswith": "LIKE ? ESCAPE '\\'", -} - -// sqlite column types. -var sqliteTypes = map[string]string{ - "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "character(%d)", - "string-text": "text", - "time.Time-date": "date", - "time.Time": "datetime", - "int8": "tinyint", - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": "tinyint unsigned", - "uint16": "smallint unsigned", - "uint32": "integer unsigned", - "uint64": "bigint unsigned", - "float64": "real", - "float64-decimal": "decimal", -} - -// sqlite dbBaser. -type dbBaseSqlite struct { - dbBase -} - -var _ dbBaser = new(dbBaseSqlite) - -// override base db read for update behavior as SQlite does not support syntax -func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { - if isForUpdate { - DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") - } - return d.dbBase.Read(q, mi, ind, tz, cols, false) -} - -// get sqlite operator. -func (d *dbBaseSqlite) OperatorSQL(operator string) string { - return sqliteOperators[operator] -} - -// generate functioned sql for sqlite. -// only support DATE(text). -func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { - if fi.fieldType == TypeDateField { - *leftCol = fmt.Sprintf("DATE(%s)", *leftCol) - } -} - -// unable updating joined record in sqlite. -func (d *dbBaseSqlite) SupportUpdateJoin() bool { - return false -} - -// max int in sqlite. -func (d *dbBaseSqlite) MaxLimit() uint64 { - return 9223372036854775807 -} - -// get column types in sqlite. -func (d *dbBaseSqlite) DbTypes() map[string]string { - return sqliteTypes -} - -// get show tables sql in sqlite. -func (d *dbBaseSqlite) ShowTablesQuery() string { - return "SELECT name FROM sqlite_master WHERE type = 'table'" -} - -// get columns in sqlite. -func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { - query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) - if err != nil { - return nil, err - } - - columns := make(map[string][3]string) - for rows.Next() { - var tmp, name, typ, null sql.NullString - err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp) - if err != nil { - return nil, err - } - columns[name.String] = [3]string{name.String, typ.String, null.String} - } - - return columns, nil -} - -// get show columns sql in sqlite. -func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { - return fmt.Sprintf("pragma table_info('%s')", table) -} - -// check index exist in sqlite. -func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { - query := fmt.Sprintf("PRAGMA index_list('%s')", table) - rows, err := db.Query(query) - if err != nil { - panic(err) - } - defer rows.Close() - for rows.Next() { - var tmp, index sql.NullString - rows.Scan(&tmp, &index, &tmp, &tmp, &tmp) - if name == index.String { - return true - } - } - return false -} - -// create new sqlite dbBaser. -func newdbBaseSqlite() dbBaser { - b := new(dbBaseSqlite) - b.ins = b - return b -} diff --git a/orm/db_tables.go b/orm/db_tables.go deleted file mode 100644 index 4b21a6fc..00000000 --- a/orm/db_tables.go +++ /dev/null @@ -1,482 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strings" - "time" -) - -// table info struct. -type dbTable struct { - id int - index string - name string - names []string - sel bool - inner bool - mi *modelInfo - fi *fieldInfo - jtl *dbTable -} - -// tables collection struct, contains some tables. -type dbTables struct { - tablesM map[string]*dbTable - tables []*dbTable - mi *modelInfo - base dbBaser - skipEnd bool -} - -// set table info to collection. -// if not exist, create new. -func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { - name := strings.Join(names, ExprSep) - if j, ok := t.tablesM[name]; ok { - j.name = name - j.mi = mi - j.fi = fi - j.inner = inner - } else { - i := len(t.tables) + 1 - jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} - t.tablesM[name] = jt - t.tables = append(t.tables, jt) - } - return t.tablesM[name] -} - -// add table info to collection. -func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { - name := strings.Join(names, ExprSep) - if _, ok := t.tablesM[name]; !ok { - i := len(t.tables) + 1 - jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} - t.tablesM[name] = jt - t.tables = append(t.tables, jt) - return jt, true - } - return t.tablesM[name], false -} - -// get table info in collection. -func (t *dbTables) get(name string) (*dbTable, bool) { - j, ok := t.tablesM[name] - return j, ok -} - -// get related fields info in recursive depth loop. -// loop once, depth decreases one. -func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { - if depth < 0 || fi.fieldType == RelManyToMany { - return related - } - - if prefix == "" { - prefix = fi.name - } else { - prefix = prefix + ExprSep + fi.name - } - related = append(related, prefix) - - depth-- - for _, fi := range fi.relModelInfo.fields.fieldsRel { - related = t.loopDepth(depth, prefix, fi, related) - } - - return related -} - -// parse related fields. -func (t *dbTables) parseRelated(rels []string, depth int) { - - relsNum := len(rels) - related := make([]string, relsNum) - copy(related, rels) - - relDepth := depth - - if relsNum != 0 { - relDepth = 0 - } - - relDepth-- - for _, fi := range t.mi.fields.fieldsRel { - related = t.loopDepth(relDepth, "", fi, related) - } - - for i, s := range related { - var ( - exs = strings.Split(s, ExprSep) - names = make([]string, 0, len(exs)) - mmi = t.mi - cancel = true - jtl *dbTable - ) - - inner := true - - for _, ex := range exs { - if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { - names = append(names, fi.name) - mmi = fi.relModelInfo - - if fi.null || t.skipEnd { - inner = false - } - - jt := t.set(names, mmi, fi, inner) - jt.jtl = jtl - - if fi.reverse { - cancel = false - } - - if cancel { - jt.sel = depth > 0 - - if i < relsNum { - jt.sel = true - } - } - - jtl = jt - - } else { - panic(fmt.Errorf("unknown model/table name `%s`", ex)) - } - } - } -} - -// generate join string. -func (t *dbTables) getJoinSQL() (join string) { - Q := t.base.TableQuote() - - for _, jt := range t.tables { - if jt.inner { - join += "INNER JOIN " - } else { - join += "LEFT OUTER JOIN " - } - var ( - table string - t1, t2 string - c1, c2 string - ) - t1 = "T0" - if jt.jtl != nil { - t1 = jt.jtl.index - } - t2 = jt.index - table = jt.mi.table - - switch { - case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: - c1 = jt.fi.mi.fields.pk.column - for _, ffi := range jt.mi.fields.fieldsRel { - if jt.fi.mi == ffi.relModelInfo { - c2 = ffi.column - break - } - } - default: - c1 = jt.fi.column - c2 = jt.fi.relModelInfo.fields.pk.column - - if jt.fi.reverse { - c1 = jt.mi.fields.pk.column - c2 = jt.fi.reverseFieldInfo.column - } - } - - join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2, - t2, Q, c2, Q, t1, Q, c1, Q) - } - return -} - -// parse orm model struct field tag expression. -func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { - var ( - jtl *dbTable - fi *fieldInfo - fiN *fieldInfo - mmi = mi - ) - - num := len(exprs) - 1 - var names []string - - inner := true - -loopFor: - for i, ex := range exprs { - - var ok, okN bool - - if fiN != nil { - fi = fiN - ok = true - fiN = nil - } - - if i == 0 { - fi, ok = mmi.fields.GetByAny(ex) - } - - _ = okN - - if ok { - - isRel := fi.rel || fi.reverse - - names = append(names, fi.name) - - switch { - case fi.rel: - mmi = fi.relModelInfo - if fi.fieldType == RelManyToMany { - mmi = fi.relThroughModelInfo - } - case fi.reverse: - mmi = fi.reverseFieldInfo.mi - } - - if i < num { - fiN, okN = mmi.fields.GetByAny(exprs[i+1]) - } - - if isRel && (!fi.mi.isThrough || num != i) { - if fi.null || t.skipEnd { - inner = false - } - - if t.skipEnd && okN || !t.skipEnd { - if t.skipEnd && okN && fiN.pk { - goto loopEnd - } - - jt, _ := t.add(names, mmi, fi, inner) - jt.jtl = jtl - jtl = jt - } - - } - - if num != i { - continue - } - - loopEnd: - - if i == 0 || jtl == nil { - index = "T0" - } else { - index = jtl.index - } - - info = fi - - if jtl == nil { - name = fi.name - } else { - name = jtl.name + ExprSep + fi.name - } - - switch { - case fi.rel: - - case fi.reverse: - switch fi.reverseFieldInfo.fieldType { - case RelOneToOne, RelForeignKey: - index = jtl.index - info = fi.reverseFieldInfo.mi.fields.pk - name = info.name - } - } - - break loopFor - - } else { - index = "" - name = "" - info = nil - success = false - return - } - } - - success = index != "" && info != nil - return -} - -// generate condition sql. -func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { - if cond == nil || cond.IsEmpty() { - return - } - - Q := t.base.TableQuote() - - mi := t.mi - - for i, p := range cond.params { - if i > 0 { - if p.isOr { - where += "OR " - } else { - where += "AND " - } - } - if p.isNot { - where += "NOT " - } - if p.isCond { - w, ps := t.getCondSQL(p.cond, true, tz) - if w != "" { - w = fmt.Sprintf("( %s) ", w) - } - where += w - params = append(params, ps...) - } else { - exprs := p.exprs - - num := len(exprs) - 1 - operator := "" - if operators[exprs[num]] { - operator = exprs[num] - exprs = exprs[:num] - } - - index, _, fi, suc := t.parseExprs(mi, exprs) - if !suc { - panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) - } - - if operator == "" { - operator = "exact" - } - - var operSQL string - var args []interface{} - if p.isRaw { - operSQL = p.sql - } else { - operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) - } - - leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) - t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) - - where += fmt.Sprintf("%s %s ", leftCol, operSQL) - params = append(params, args...) - - } - } - - if !sub && where != "" { - where = "WHERE " + where - } - - return -} - -// generate group sql. -func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { - if len(groups) == 0 { - return - } - - Q := t.base.TableQuote() - - groupSqls := make([]string, 0, len(groups)) - for _, group := range groups { - exprs := strings.Split(group, ExprSep) - - index, _, fi, suc := t.parseExprs(t.mi, exprs) - if !suc { - panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) - } - - groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)) - } - - groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", ")) - return -} - -// generate order sql. -func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { - if len(orders) == 0 { - return - } - - Q := t.base.TableQuote() - - orderSqls := make([]string, 0, len(orders)) - for _, order := range orders { - asc := "ASC" - if order[0] == '-' { - asc = "DESC" - order = order[1:] - } - exprs := strings.Split(order, ExprSep) - - index, _, fi, suc := t.parseExprs(t.mi, exprs) - if !suc { - panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) - } - - orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) - } - - orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) - return -} - -// generate limit sql. -func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) { - if limit == 0 { - limit = int64(DefaultRowsLimit) - } - if limit < 0 { - // no limit - if offset > 0 { - maxLimit := t.base.MaxLimit() - if maxLimit == 0 { - limits = fmt.Sprintf("OFFSET %d", offset) - } else { - limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset) - } - } - } else if offset <= 0 { - limits = fmt.Sprintf("LIMIT %d", limit) - } else { - limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) - } - return -} - -// crete new tables collection. -func newDbTables(mi *modelInfo, base dbBaser) *dbTables { - tables := &dbTables{} - tables.tablesM = make(map[string]*dbTable) - tables.mi = mi - tables.base = base - return tables -} diff --git a/orm/db_tidb.go b/orm/db_tidb.go deleted file mode 100644 index 6020a488..00000000 --- a/orm/db_tidb.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2015 TiDB Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" -) - -// mysql dbBaser implementation. -type dbBaseTidb struct { - dbBase -} - -var _ dbBaser = new(dbBaseTidb) - -// get mysql operator. -func (d *dbBaseTidb) OperatorSQL(operator string) string { - return mysqlOperators[operator] -} - -// get mysql table field types. -func (d *dbBaseTidb) DbTypes() map[string]string { - return mysqlTypes -} - -// show table sql for mysql. -func (d *dbBaseTidb) ShowTablesQuery() string { - return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" -} - -// show columns sql of table for mysql. -func (d *dbBaseTidb) ShowColumnsQuery(table string) string { - return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ - "WHERE table_schema = DATABASE() AND table_name = '%s'", table) -} - -// execute sql to check index exist. -func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ - "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) - var cnt int - row.Scan(&cnt) - return cnt > 0 -} - -// create new mysql dbBaser. -func newdbBaseTidb() dbBaser { - b := new(dbBaseTidb) - b.ins = b - return b -} diff --git a/orm/db_utils.go b/orm/db_utils.go deleted file mode 100644 index 7ae10ca5..00000000 --- a/orm/db_utils.go +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "reflect" - "time" -) - -// get table alias. -func getDbAlias(name string) *alias { - if al, ok := dataBaseCache.get(name); ok { - return al - } - panic(fmt.Errorf("unknown DataBase alias name %s", name)) -} - -// get pk column info. -func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { - fi := mi.fields.pk - - v := ind.FieldByIndex(fi.fieldIndex) - if fi.fieldType&IsPositiveIntegerField > 0 { - vu := v.Uint() - exist = vu > 0 - value = vu - } else if fi.fieldType&IsIntegerField > 0 { - vu := v.Int() - exist = true - value = vu - } else if fi.fieldType&IsRelField > 0 { - _, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v)) - } else { - vu := v.String() - exist = vu != "" - value = vu - } - - column = fi.column - return -} - -// get fields description as flatted string. -func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { - -outFor: - for _, arg := range args { - val := reflect.ValueOf(arg) - - if arg == nil { - params = append(params, arg) - continue - } - - kind := val.Kind() - if kind == reflect.Ptr { - val = val.Elem() - kind = val.Kind() - arg = val.Interface() - } - - switch kind { - case reflect.String: - v := val.String() - if fi != nil { - if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { - var t time.Time - var err error - if len(v) >= 19 { - s := v[:19] - t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc) - } else if len(v) >= 10 { - s := v - if len(v) > 10 { - s = v[:10] - } - t, err = time.ParseInLocation(formatDate, s, tz) - } else { - s := v - if len(s) > 8 { - s = v[:8] - } - t, err = time.ParseInLocation(formatTime, s, tz) - } - if err == nil { - if fi.fieldType == TypeDateField { - v = t.In(tz).Format(formatDate) - } else if fi.fieldType == TypeDateTimeField { - v = t.In(tz).Format(formatDateTime) - } else { - v = t.In(tz).Format(formatTime) - } - } - } - } - arg = v - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - arg = val.Int() - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - arg = val.Uint() - case reflect.Float32: - arg, _ = StrTo(ToStr(arg)).Float64() - case reflect.Float64: - arg = val.Float() - case reflect.Bool: - arg = val.Bool() - case reflect.Slice, reflect.Array: - if _, ok := arg.([]byte); ok { - continue outFor - } - - var args []interface{} - for i := 0; i < val.Len(); i++ { - v := val.Index(i) - - var vu interface{} - if v.CanInterface() { - vu = v.Interface() - } - - if vu == nil { - continue - } - - args = append(args, vu) - } - - if len(args) > 0 { - p := getFlatParams(fi, args, tz) - params = append(params, p...) - } - continue outFor - case reflect.Struct: - if v, ok := arg.(time.Time); ok { - if fi != nil && fi.fieldType == TypeDateField { - arg = v.In(tz).Format(formatDate) - } else if fi != nil && fi.fieldType == TypeDateTimeField { - arg = v.In(tz).Format(formatDateTime) - } else if fi != nil && fi.fieldType == TypeTimeField { - arg = v.In(tz).Format(formatTime) - } else { - arg = v.In(tz).Format(formatDateTime) - } - } else { - typ := val.Type() - name := getFullName(typ) - var value interface{} - if mmi, ok := modelCache.getByFullName(name); ok { - if _, vu, exist := getExistPk(mmi, val); exist { - value = vu - } - } - arg = value - - if arg == nil { - panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name)) - } - } - } - - params = append(params, arg) - } - return -} diff --git a/orm/models.go b/orm/models.go deleted file mode 100644 index 4776bcba..00000000 --- a/orm/models.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "sync" -) - -const ( - odCascade = "cascade" - odSetNULL = "set_null" - odSetDefault = "set_default" - odDoNothing = "do_nothing" - defaultStructTagName = "orm" - defaultStructTagDelim = ";" -) - -var ( - modelCache = &_modelCache{ - cache: make(map[string]*modelInfo), - cacheByFullName: make(map[string]*modelInfo), - } -) - -// model info collection -type _modelCache struct { - sync.RWMutex // only used outsite for bootStrap - orders []string - cache map[string]*modelInfo - cacheByFullName map[string]*modelInfo - done bool -} - -// get all model info -func (mc *_modelCache) all() map[string]*modelInfo { - m := make(map[string]*modelInfo, len(mc.cache)) - for k, v := range mc.cache { - m[k] = v - } - return m -} - -// get ordered model info -func (mc *_modelCache) allOrdered() []*modelInfo { - m := make([]*modelInfo, 0, len(mc.orders)) - for _, table := range mc.orders { - m = append(m, mc.cache[table]) - } - return m -} - -// get model info by table name -func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { - mi, ok = mc.cache[table] - return -} - -// get model info by full name -func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) { - mi, ok = mc.cacheByFullName[name] - return -} - -// set model info to collection -func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { - mii := mc.cache[table] - mc.cache[table] = mi - mc.cacheByFullName[mi.fullName] = mi - if mii == nil { - mc.orders = append(mc.orders, table) - } - return mii -} - -// clean all model info. -func (mc *_modelCache) clean() { - mc.orders = make([]string, 0) - mc.cache = make(map[string]*modelInfo) - mc.cacheByFullName = make(map[string]*modelInfo) - mc.done = false -} - -// ResetModelCache Clean model cache. Then you can re-RegisterModel. -// Common use this api for test case. -func ResetModelCache() { - modelCache.clean() -} diff --git a/orm/models_boot.go b/orm/models_boot.go deleted file mode 100644 index 8c56b3c4..00000000 --- a/orm/models_boot.go +++ /dev/null @@ -1,347 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "os" - "reflect" - "runtime/debug" - "strings" -) - -// register models. -// PrefixOrSuffix means table name prefix or suffix. -// isPrefix whether the prefix is prefix or suffix -func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) { - val := reflect.ValueOf(model) - typ := reflect.Indirect(val).Type() - - if val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) - } - // For this case: - // u := &User{} - // registerModel(&u) - if typ.Kind() == reflect.Ptr { - panic(fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)) - } - - table := getTableName(val) - - if PrefixOrSuffix != "" { - if isPrefix { - table = PrefixOrSuffix + table - } else { - table = table + PrefixOrSuffix - } - } - // models's fullname is pkgpath + struct name - name := getFullName(typ) - if _, ok := modelCache.getByFullName(name); ok { - fmt.Printf(" model `%s` repeat register, must be unique\n", name) - os.Exit(2) - } - - if _, ok := modelCache.get(table); ok { - fmt.Printf(" table name `%s` repeat register, must be unique\n", table) - os.Exit(2) - } - - mi := newModelInfo(val) - if mi.fields.pk == nil { - outFor: - for _, fi := range mi.fields.fieldsDB { - if strings.ToLower(fi.name) == "id" { - switch fi.addrValue.Elem().Kind() { - case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: - fi.auto = true - fi.pk = true - mi.fields.pk = fi - break outFor - } - } - } - - if mi.fields.pk == nil { - fmt.Printf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) - os.Exit(2) - } - - } - - mi.table = table - mi.pkg = typ.PkgPath() - mi.model = model - mi.manual = true - - modelCache.set(table, mi) -} - -// bootstrap models -func bootStrap() { - if modelCache.done { - return - } - var ( - err error - models map[string]*modelInfo - ) - if dataBaseCache.getDefault() == nil { - err = fmt.Errorf("must have one register DataBase alias named `default`") - goto end - } - - // set rel and reverse model - // RelManyToMany set the relTable - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.columns { - if fi.rel || fi.reverse { - elm := fi.addrValue.Type().Elem() - if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { - elm = elm.Elem() - } - // check the rel or reverse model already register - name := getFullName(elm) - mii, ok := modelCache.getByFullName(name) - if !ok || mii.pkg != elm.PkgPath() { - err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) - goto end - } - fi.relModelInfo = mii - - switch fi.fieldType { - case RelManyToMany: - if fi.relThrough != "" { - if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { - pn := fi.relThrough[:i] - rmi, ok := modelCache.getByFullName(fi.relThrough) - if !ok || pn != rmi.pkg { - err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) - goto end - } - fi.relThroughModelInfo = rmi - fi.relTable = rmi.table - } else { - err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) - goto end - } - } else { - i := newM2MModelInfo(mi, mii) - if fi.relTable != "" { - i.table = fi.relTable - } - if v := modelCache.set(i.table, i); v != nil { - err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) - goto end - } - fi.relTable = i.table - fi.relThroughModelInfo = i - } - - fi.relThroughModelInfo.isThrough = true - } - } - } - } - - // check the rel filed while the relModelInfo also has filed point to current model - // if not exist, add a new field to the relModelInfo - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsRel { - switch fi.fieldType { - case RelForeignKey, RelOneToOne, RelManyToMany: - inModel := false - for _, ffi := range fi.relModelInfo.fields.fieldsReverse { - if ffi.relModelInfo == mi { - inModel = true - break - } - } - if !inModel { - rmi := fi.relModelInfo - ffi := new(fieldInfo) - ffi.name = mi.name - ffi.column = ffi.name - ffi.fullName = rmi.fullName + "." + ffi.name - ffi.reverse = true - ffi.relModelInfo = mi - ffi.mi = rmi - if fi.fieldType == RelOneToOne { - ffi.fieldType = RelReverseOne - } else { - ffi.fieldType = RelReverseMany - } - if !rmi.fields.Add(ffi) { - added := false - for cnt := 0; cnt < 5; cnt++ { - ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) - ffi.column = ffi.name - ffi.fullName = rmi.fullName + "." + ffi.name - if added = rmi.fields.Add(ffi); added { - break - } - } - if !added { - panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) - } - } - } - } - } - } - - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsRel { - switch fi.fieldType { - case RelManyToMany: - for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { - switch ffi.fieldType { - case RelOneToOne, RelForeignKey: - if ffi.relModelInfo == fi.relModelInfo { - fi.reverseFieldInfoTwo = ffi - } - if ffi.relModelInfo == mi { - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - } - } - } - if fi.reverseFieldInfoTwo == nil { - err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", - fi.relThroughModelInfo.fullName) - goto end - } - } - } - } - - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsReverse { - switch fi.fieldType { - case RelReverseOne: - found := false - mForA: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { - if ffi.relModelInfo == mi { - found = true - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - - ffi.reverseField = fi.name - ffi.reverseFieldInfo = fi - break mForA - } - } - if !found { - err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) - goto end - } - case RelReverseMany: - found := false - mForB: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { - if ffi.relModelInfo == mi { - found = true - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - - ffi.reverseField = fi.name - ffi.reverseFieldInfo = fi - - break mForB - } - } - if !found { - mForC: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { - conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || - fi.relTable != "" && fi.relTable == ffi.relTable || - fi.relThrough == "" && fi.relTable == "" - if ffi.relModelInfo == mi && conditions { - found = true - - fi.reverseField = ffi.reverseFieldInfoTwo.name - fi.reverseFieldInfo = ffi.reverseFieldInfoTwo - fi.relThroughModelInfo = ffi.relThroughModelInfo - fi.reverseFieldInfoTwo = ffi.reverseFieldInfo - fi.reverseFieldInfoM2M = ffi - ffi.reverseFieldInfoM2M = fi - - break mForC - } - } - } - if !found { - err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) - goto end - } - } - } - } - -end: - if err != nil { - fmt.Println(err) - debug.PrintStack() - os.Exit(2) - } -} - -// RegisterModel register models -func RegisterModel(models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModel must be run before BootStrap")) - } - RegisterModelWithPrefix("", models...) -} - -// RegisterModelWithPrefix register models with a prefix -func RegisterModelWithPrefix(prefix string, models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModelWithPrefix must be run before BootStrap")) - } - - for _, model := range models { - registerModel(prefix, model, true) - } -} - -// RegisterModelWithSuffix register models with a suffix -func RegisterModelWithSuffix(suffix string, models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModelWithSuffix must be run before BootStrap")) - } - - for _, model := range models { - registerModel(suffix, model, false) - } -} - -// BootStrap bootstrap models. -// make all model parsed and can not add more models -func BootStrap() { - modelCache.Lock() - defer modelCache.Unlock() - if modelCache.done { - return - } - bootStrap() - modelCache.done = true -} diff --git a/orm/models_fields.go b/orm/models_fields.go deleted file mode 100644 index b4fad94f..00000000 --- a/orm/models_fields.go +++ /dev/null @@ -1,783 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strconv" - "time" -) - -// Define the Type enum -const ( - TypeBooleanField = 1 << iota - TypeVarCharField - TypeCharField - TypeTextField - TypeTimeField - TypeDateField - TypeDateTimeField - TypeBitField - TypeSmallIntegerField - TypeIntegerField - TypeBigIntegerField - TypePositiveBitField - TypePositiveSmallIntegerField - TypePositiveIntegerField - TypePositiveBigIntegerField - TypeFloatField - TypeDecimalField - TypeJSONField - TypeJsonbField - RelForeignKey - RelOneToOne - RelManyToMany - RelReverseOne - RelReverseMany -) - -// Define some logic enum -const ( - IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7 - IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11 - IsRelField = ^-RelReverseMany >> 18 << 19 - IsFieldType = ^-RelReverseMany<<1 + 1 -) - -// BooleanField A true/false field. -type BooleanField bool - -// Value return the BooleanField -func (e BooleanField) Value() bool { - return bool(e) -} - -// Set will set the BooleanField -func (e *BooleanField) Set(d bool) { - *e = BooleanField(d) -} - -// String format the Bool to string -func (e *BooleanField) String() string { - return strconv.FormatBool(e.Value()) -} - -// FieldType return BooleanField the type -func (e *BooleanField) FieldType() int { - return TypeBooleanField -} - -// SetRaw set the interface to bool -func (e *BooleanField) SetRaw(value interface{}) error { - switch d := value.(type) { - case bool: - e.Set(d) - case string: - v, err := StrTo(d).Bool() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the current value -func (e *BooleanField) RawValue() interface{} { - return e.Value() -} - -// verify the BooleanField implement the Fielder interface -var _ Fielder = new(BooleanField) - -// CharField A string field -// required values tag: size -// The size is enforced at the database level and in models’s validation. -// eg: `orm:"size(120)"` -type CharField string - -// Value return the CharField's Value -func (e CharField) Value() string { - return string(e) -} - -// Set CharField value -func (e *CharField) Set(d string) { - *e = CharField(d) -} - -// String return the CharField -func (e *CharField) String() string { - return e.Value() -} - -// FieldType return the enum type -func (e *CharField) FieldType() int { - return TypeVarCharField -} - -// SetRaw set the interface to string -func (e *CharField) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - e.Set(d) - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the CharField value -func (e *CharField) RawValue() interface{} { - return e.Value() -} - -// verify CharField implement Fielder -var _ Fielder = new(CharField) - -// TimeField A time, represented in go by a time.Time instance. -// only time values like 10:00:00 -// Has a few extra, optional attr tag: -// -// auto_now: -// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. -// Note that the current date is always used; it’s not just a default value that you can override. -// -// auto_now_add: -// Automatically set the field to now when the object is first created. Useful for creation of timestamps. -// Note that the current date is always used; it’s not just a default value that you can override. -// -// eg: `orm:"auto_now"` or `orm:"auto_now_add"` -type TimeField time.Time - -// Value return the time.Time -func (e TimeField) Value() time.Time { - return time.Time(e) -} - -// Set set the TimeField's value -func (e *TimeField) Set(d time.Time) { - *e = TimeField(d) -} - -// String convert time to string -func (e *TimeField) String() string { - return e.Value().String() -} - -// FieldType return enum type Date -func (e *TimeField) FieldType() int { - return TypeDateField -} - -// SetRaw convert the interface to time.Time. Allow string and time.Time -func (e *TimeField) SetRaw(value interface{}) error { - switch d := value.(type) { - case time.Time: - e.Set(d) - case string: - v, err := timeParse(d, formatTime) - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return time value -func (e *TimeField) RawValue() interface{} { - return e.Value() -} - -var _ Fielder = new(TimeField) - -// DateField A date, represented in go by a time.Time instance. -// only date values like 2006-01-02 -// Has a few extra, optional attr tag: -// -// auto_now: -// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. -// Note that the current date is always used; it’s not just a default value that you can override. -// -// auto_now_add: -// Automatically set the field to now when the object is first created. Useful for creation of timestamps. -// Note that the current date is always used; it’s not just a default value that you can override. -// -// eg: `orm:"auto_now"` or `orm:"auto_now_add"` -type DateField time.Time - -// Value return the time.Time -func (e DateField) Value() time.Time { - return time.Time(e) -} - -// Set set the DateField's value -func (e *DateField) Set(d time.Time) { - *e = DateField(d) -} - -// String convert datetime to string -func (e *DateField) String() string { - return e.Value().String() -} - -// FieldType return enum type Date -func (e *DateField) FieldType() int { - return TypeDateField -} - -// SetRaw convert the interface to time.Time. Allow string and time.Time -func (e *DateField) SetRaw(value interface{}) error { - switch d := value.(type) { - case time.Time: - e.Set(d) - case string: - v, err := timeParse(d, formatDate) - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return Date value -func (e *DateField) RawValue() interface{} { - return e.Value() -} - -// verify DateField implement fielder interface -var _ Fielder = new(DateField) - -// DateTimeField A date, represented in go by a time.Time instance. -// datetime values like 2006-01-02 15:04:05 -// Takes the same extra arguments as DateField. -type DateTimeField time.Time - -// Value return the datetime value -func (e DateTimeField) Value() time.Time { - return time.Time(e) -} - -// Set set the time.Time to datetime -func (e *DateTimeField) Set(d time.Time) { - *e = DateTimeField(d) -} - -// String return the time's String -func (e *DateTimeField) String() string { - return e.Value().String() -} - -// FieldType return the enum TypeDateTimeField -func (e *DateTimeField) FieldType() int { - return TypeDateTimeField -} - -// SetRaw convert the string or time.Time to DateTimeField -func (e *DateTimeField) SetRaw(value interface{}) error { - switch d := value.(type) { - case time.Time: - e.Set(d) - case string: - v, err := timeParse(d, formatDateTime) - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the datetime value -func (e *DateTimeField) RawValue() interface{} { - return e.Value() -} - -// verify datetime implement fielder -var _ Fielder = new(DateTimeField) - -// FloatField A floating-point number represented in go by a float32 value. -type FloatField float64 - -// Value return the FloatField value -func (e FloatField) Value() float64 { - return float64(e) -} - -// Set the Float64 -func (e *FloatField) Set(d float64) { - *e = FloatField(d) -} - -// String return the string -func (e *FloatField) String() string { - return ToStr(e.Value(), -1, 32) -} - -// FieldType return the enum type -func (e *FloatField) FieldType() int { - return TypeFloatField -} - -// SetRaw converter interface Float64 float32 or string to FloatField -func (e *FloatField) SetRaw(value interface{}) error { - switch d := value.(type) { - case float32: - e.Set(float64(d)) - case float64: - e.Set(d) - case string: - v, err := StrTo(d).Float64() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the FloatField value -func (e *FloatField) RawValue() interface{} { - return e.Value() -} - -// verify FloatField implement Fielder -var _ Fielder = new(FloatField) - -// SmallIntegerField -32768 to 32767 -type SmallIntegerField int16 - -// Value return int16 value -func (e SmallIntegerField) Value() int16 { - return int16(e) -} - -// Set the SmallIntegerField value -func (e *SmallIntegerField) Set(d int16) { - *e = SmallIntegerField(d) -} - -// String convert smallint to string -func (e *SmallIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type SmallIntegerField -func (e *SmallIntegerField) FieldType() int { - return TypeSmallIntegerField -} - -// SetRaw convert interface int16/string to int16 -func (e *SmallIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case int16: - e.Set(d) - case string: - v, err := StrTo(d).Int16() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return smallint value -func (e *SmallIntegerField) RawValue() interface{} { - return e.Value() -} - -// verify SmallIntegerField implement Fielder -var _ Fielder = new(SmallIntegerField) - -// IntegerField -2147483648 to 2147483647 -type IntegerField int32 - -// Value return the int32 -func (e IntegerField) Value() int32 { - return int32(e) -} - -// Set IntegerField value -func (e *IntegerField) Set(d int32) { - *e = IntegerField(d) -} - -// String convert Int32 to string -func (e *IntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return the enum type -func (e *IntegerField) FieldType() int { - return TypeIntegerField -} - -// SetRaw convert interface int32/string to int32 -func (e *IntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case int32: - e.Set(d) - case string: - v, err := StrTo(d).Int32() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return IntegerField value -func (e *IntegerField) RawValue() interface{} { - return e.Value() -} - -// verify IntegerField implement Fielder -var _ Fielder = new(IntegerField) - -// BigIntegerField -9223372036854775808 to 9223372036854775807. -type BigIntegerField int64 - -// Value return int64 -func (e BigIntegerField) Value() int64 { - return int64(e) -} - -// Set the BigIntegerField value -func (e *BigIntegerField) Set(d int64) { - *e = BigIntegerField(d) -} - -// String convert BigIntegerField to string -func (e *BigIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type -func (e *BigIntegerField) FieldType() int { - return TypeBigIntegerField -} - -// SetRaw convert interface int64/string to int64 -func (e *BigIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case int64: - e.Set(d) - case string: - v, err := StrTo(d).Int64() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return BigIntegerField value -func (e *BigIntegerField) RawValue() interface{} { - return e.Value() -} - -// verify BigIntegerField implement Fielder -var _ Fielder = new(BigIntegerField) - -// PositiveSmallIntegerField 0 to 65535 -type PositiveSmallIntegerField uint16 - -// Value return uint16 -func (e PositiveSmallIntegerField) Value() uint16 { - return uint16(e) -} - -// Set PositiveSmallIntegerField value -func (e *PositiveSmallIntegerField) Set(d uint16) { - *e = PositiveSmallIntegerField(d) -} - -// String convert uint16 to string -func (e *PositiveSmallIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type -func (e *PositiveSmallIntegerField) FieldType() int { - return TypePositiveSmallIntegerField -} - -// SetRaw convert Interface uint16/string to uint16 -func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case uint16: - e.Set(d) - case string: - v, err := StrTo(d).Uint16() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue returns PositiveSmallIntegerField value -func (e *PositiveSmallIntegerField) RawValue() interface{} { - return e.Value() -} - -// verify PositiveSmallIntegerField implement Fielder -var _ Fielder = new(PositiveSmallIntegerField) - -// PositiveIntegerField 0 to 4294967295 -type PositiveIntegerField uint32 - -// Value return PositiveIntegerField value. Uint32 -func (e PositiveIntegerField) Value() uint32 { - return uint32(e) -} - -// Set the PositiveIntegerField value -func (e *PositiveIntegerField) Set(d uint32) { - *e = PositiveIntegerField(d) -} - -// String convert PositiveIntegerField to string -func (e *PositiveIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type -func (e *PositiveIntegerField) FieldType() int { - return TypePositiveIntegerField -} - -// SetRaw convert interface uint32/string to Uint32 -func (e *PositiveIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case uint32: - e.Set(d) - case string: - v, err := StrTo(d).Uint32() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the PositiveIntegerField Value -func (e *PositiveIntegerField) RawValue() interface{} { - return e.Value() -} - -// verify PositiveIntegerField implement Fielder -var _ Fielder = new(PositiveIntegerField) - -// PositiveBigIntegerField 0 to 18446744073709551615 -type PositiveBigIntegerField uint64 - -// Value return uint64 -func (e PositiveBigIntegerField) Value() uint64 { - return uint64(e) -} - -// Set PositiveBigIntegerField value -func (e *PositiveBigIntegerField) Set(d uint64) { - *e = PositiveBigIntegerField(d) -} - -// String convert PositiveBigIntegerField to string -func (e *PositiveBigIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type -func (e *PositiveBigIntegerField) FieldType() int { - return TypePositiveIntegerField -} - -// SetRaw convert interface uint64/string to Uint64 -func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case uint64: - e.Set(d) - case string: - v, err := StrTo(d).Uint64() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return PositiveBigIntegerField value -func (e *PositiveBigIntegerField) RawValue() interface{} { - return e.Value() -} - -// verify PositiveBigIntegerField implement Fielder -var _ Fielder = new(PositiveBigIntegerField) - -// TextField A large text field. -type TextField string - -// Value return TextField value -func (e TextField) Value() string { - return string(e) -} - -// Set the TextField value -func (e *TextField) Set(d string) { - *e = TextField(d) -} - -// String convert TextField to string -func (e *TextField) String() string { - return e.Value() -} - -// FieldType return enum type -func (e *TextField) FieldType() int { - return TypeTextField -} - -// SetRaw convert interface string to string -func (e *TextField) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - e.Set(d) - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return TextField value -func (e *TextField) RawValue() interface{} { - return e.Value() -} - -// verify TextField implement Fielder -var _ Fielder = new(TextField) - -// JSONField postgres json field. -type JSONField string - -// Value return JSONField value -func (j JSONField) Value() string { - return string(j) -} - -// Set the JSONField value -func (j *JSONField) Set(d string) { - *j = JSONField(d) -} - -// String convert JSONField to string -func (j *JSONField) String() string { - return j.Value() -} - -// FieldType return enum type -func (j *JSONField) FieldType() int { - return TypeJSONField -} - -// SetRaw convert interface string to string -func (j *JSONField) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - j.Set(d) - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return JSONField value -func (j *JSONField) RawValue() interface{} { - return j.Value() -} - -// verify JSONField implement Fielder -var _ Fielder = new(JSONField) - -// JsonbField postgres json field. -type JsonbField string - -// Value return JsonbField value -func (j JsonbField) Value() string { - return string(j) -} - -// Set the JsonbField value -func (j *JsonbField) Set(d string) { - *j = JsonbField(d) -} - -// String convert JsonbField to string -func (j *JsonbField) String() string { - return j.Value() -} - -// FieldType return enum type -func (j *JsonbField) FieldType() int { - return TypeJsonbField -} - -// SetRaw convert interface string to string -func (j *JsonbField) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - j.Set(d) - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return JsonbField value -func (j *JsonbField) RawValue() interface{} { - return j.Value() -} - -// verify JsonbField implement Fielder -var _ Fielder = new(JsonbField) diff --git a/orm/models_info_f.go b/orm/models_info_f.go deleted file mode 100644 index 7044b0bd..00000000 --- a/orm/models_info_f.go +++ /dev/null @@ -1,473 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "errors" - "fmt" - "reflect" - "strings" -) - -var errSkipField = errors.New("skip field") - -// field info collection -type fields struct { - pk *fieldInfo - columns map[string]*fieldInfo - fields map[string]*fieldInfo - fieldsLow map[string]*fieldInfo - fieldsByType map[int][]*fieldInfo - fieldsRel []*fieldInfo - fieldsReverse []*fieldInfo - fieldsDB []*fieldInfo - rels []*fieldInfo - orders []string - dbcols []string -} - -// add field info -func (f *fields) Add(fi *fieldInfo) (added bool) { - if f.fields[fi.name] == nil && f.columns[fi.column] == nil { - f.columns[fi.column] = fi - f.fields[fi.name] = fi - f.fieldsLow[strings.ToLower(fi.name)] = fi - } else { - return - } - if _, ok := f.fieldsByType[fi.fieldType]; !ok { - f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0) - } - f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi) - f.orders = append(f.orders, fi.column) - if fi.dbcol { - f.dbcols = append(f.dbcols, fi.column) - f.fieldsDB = append(f.fieldsDB, fi) - } - if fi.rel { - f.fieldsRel = append(f.fieldsRel, fi) - } - if fi.reverse { - f.fieldsReverse = append(f.fieldsReverse, fi) - } - return true -} - -// get field info by name -func (f *fields) GetByName(name string) *fieldInfo { - return f.fields[name] -} - -// get field info by column name -func (f *fields) GetByColumn(column string) *fieldInfo { - return f.columns[column] -} - -// get field info by string, name is prior -func (f *fields) GetByAny(name string) (*fieldInfo, bool) { - if fi, ok := f.fields[name]; ok { - return fi, ok - } - if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok { - return fi, ok - } - if fi, ok := f.columns[name]; ok { - return fi, ok - } - return nil, false -} - -// create new field info collection -func newFields() *fields { - f := new(fields) - f.fields = make(map[string]*fieldInfo) - f.fieldsLow = make(map[string]*fieldInfo) - f.columns = make(map[string]*fieldInfo) - f.fieldsByType = make(map[int][]*fieldInfo) - return f -} - -// single field info -type fieldInfo struct { - mi *modelInfo - fieldIndex []int - fieldType int - dbcol bool // table column fk and onetoone - inModel bool - name string - fullName string - column string - addrValue reflect.Value - sf reflect.StructField - auto bool - pk bool - null bool - index bool - unique bool - colDefault bool // whether has default tag - initial StrTo // store the default value - size int - toText bool - autoNow bool - autoNowAdd bool - rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true - reverse bool - reverseField string - reverseFieldInfo *fieldInfo - reverseFieldInfoTwo *fieldInfo - reverseFieldInfoM2M *fieldInfo - relTable string - relThrough string - relThroughModelInfo *modelInfo - relModelInfo *modelInfo - digits int - decimals int - isFielder bool // implement Fielder interface - onDelete string - description string -} - -// new field info -func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *fieldInfo, err error) { - var ( - tag string - tagValue string - initial StrTo // store the default value - fieldType int - attrs map[string]bool - tags map[string]string - addrField reflect.Value - ) - - fi = new(fieldInfo) - - // if field which CanAddr is the follow type - // A value is addressable if it is an element of a slice, - // an element of an addressable array, a field of an - // addressable struct, or the result of dereferencing a pointer. - addrField = field - if field.CanAddr() && field.Kind() != reflect.Ptr { - addrField = field.Addr() - if _, ok := addrField.Interface().(Fielder); !ok { - if field.Kind() == reflect.Slice { - addrField = field - } - } - } - - attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName)) - - if _, ok := attrs["-"]; ok { - return nil, errSkipField - } - - digits := tags["digits"] - decimals := tags["decimals"] - size := tags["size"] - onDelete := tags["on_delete"] - - initial.Clear() - if v, ok := tags["default"]; ok { - initial.Set(v) - } - -checkType: - switch f := addrField.Interface().(type) { - case Fielder: - fi.isFielder = true - if field.Kind() == reflect.Ptr { - err = fmt.Errorf("the model Fielder can not be use ptr") - goto end - } - fieldType = f.FieldType() - if fieldType&IsRelField > 0 { - err = fmt.Errorf("unsupport type custom field, please refer to https://github.com/astaxie/beego/blob/master/orm/models_fields.go#L24-L42") - goto end - } - default: - tag = "rel" - tagValue = tags[tag] - if tagValue != "" { - switch tagValue { - case "fk": - fieldType = RelForeignKey - break checkType - case "one": - fieldType = RelOneToOne - break checkType - case "m2m": - fieldType = RelManyToMany - if tv := tags["rel_table"]; tv != "" { - fi.relTable = tv - } else if tv := tags["rel_through"]; tv != "" { - fi.relThrough = tv - } - break checkType - default: - err = fmt.Errorf("rel only allow these value: fk, one, m2m") - goto wrongTag - } - } - tag = "reverse" - tagValue = tags[tag] - if tagValue != "" { - switch tagValue { - case "one": - fieldType = RelReverseOne - break checkType - case "many": - fieldType = RelReverseMany - if tv := tags["rel_table"]; tv != "" { - fi.relTable = tv - } else if tv := tags["rel_through"]; tv != "" { - fi.relThrough = tv - } - break checkType - default: - err = fmt.Errorf("reverse only allow these value: one, many") - goto wrongTag - } - } - - fieldType, err = getFieldType(addrField) - if err != nil { - goto end - } - if fieldType == TypeVarCharField { - switch tags["type"] { - case "char": - fieldType = TypeCharField - case "text": - fieldType = TypeTextField - case "json": - fieldType = TypeJSONField - case "jsonb": - fieldType = TypeJsonbField - } - } - if fieldType == TypeFloatField && (digits != "" || decimals != "") { - fieldType = TypeDecimalField - } - if fieldType == TypeDateTimeField && tags["type"] == "date" { - fieldType = TypeDateField - } - if fieldType == TypeTimeField && tags["type"] == "time" { - fieldType = TypeTimeField - } - } - - // check the rel and reverse type - // rel should Ptr - // reverse should slice []*struct - switch fieldType { - case RelForeignKey, RelOneToOne, RelReverseOne: - if field.Kind() != reflect.Ptr { - err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name()) - goto end - } - case RelManyToMany, RelReverseMany: - if field.Kind() != reflect.Slice { - err = fmt.Errorf("rel/reverse:many field must be slice") - goto end - } else { - if field.Type().Elem().Kind() != reflect.Ptr { - err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name()) - goto end - } - } - } - - if fieldType&IsFieldType == 0 { - err = fmt.Errorf("wrong field type") - goto end - } - - fi.fieldType = fieldType - fi.name = sf.Name - fi.column = getColumnName(fieldType, addrField, sf, tags["column"]) - fi.addrValue = addrField - fi.sf = sf - fi.fullName = mi.fullName + mName + "." + sf.Name - - fi.description = tags["description"] - fi.null = attrs["null"] - fi.index = attrs["index"] - fi.auto = attrs["auto"] - fi.pk = attrs["pk"] - fi.unique = attrs["unique"] - - // Mark object property if there is attribute "default" in the orm configuration - if _, ok := tags["default"]; ok { - fi.colDefault = true - } - - switch fieldType { - case RelManyToMany, RelReverseMany, RelReverseOne: - fi.null = false - fi.index = false - fi.auto = false - fi.pk = false - fi.unique = false - default: - fi.dbcol = true - } - - switch fieldType { - case RelForeignKey, RelOneToOne, RelManyToMany: - fi.rel = true - if fieldType == RelOneToOne { - fi.unique = true - } - case RelReverseMany, RelReverseOne: - fi.reverse = true - } - - if fi.rel && fi.dbcol { - switch onDelete { - case odCascade, odDoNothing: - case odSetDefault: - if !initial.Exist() { - err = errors.New("on_delete: set_default need set field a default value") - goto end - } - case odSetNULL: - if !fi.null { - err = errors.New("on_delete: set_null need set field null") - goto end - } - default: - if onDelete == "" { - onDelete = odCascade - } else { - err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete) - goto end - } - } - - fi.onDelete = onDelete - } - - switch fieldType { - case TypeBooleanField: - case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField: - if size != "" { - v, e := StrTo(size).Int32() - if e != nil { - err = fmt.Errorf("wrong size value `%s`", size) - } else { - fi.size = int(v) - } - } else { - fi.size = 255 - fi.toText = true - } - case TypeTextField: - fi.index = false - fi.unique = false - case TypeTimeField, TypeDateField, TypeDateTimeField: - if attrs["auto_now"] { - fi.autoNow = true - } else if attrs["auto_now_add"] { - fi.autoNowAdd = true - } - case TypeFloatField: - case TypeDecimalField: - d1 := digits - d2 := decimals - v1, er1 := StrTo(d1).Int8() - v2, er2 := StrTo(d2).Int8() - if er1 != nil || er2 != nil { - err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1) - goto end - } - fi.digits = int(v1) - fi.decimals = int(v2) - default: - switch { - case fieldType&IsIntegerField > 0: - case fieldType&IsRelField > 0: - } - } - - if fieldType&IsIntegerField == 0 { - if fi.auto { - err = fmt.Errorf("non-integer type cannot set auto") - goto end - } - } - - if fi.auto || fi.pk { - if fi.auto { - switch addrField.Elem().Kind() { - case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: - default: - err = fmt.Errorf("auto primary key only support int, int32, int64, uint, uint32, uint64 but found `%s`", addrField.Elem().Kind()) - goto end - } - fi.pk = true - } - fi.null = false - fi.index = false - fi.unique = false - } - - if fi.unique { - fi.index = false - } - - // can not set default for these type - if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField { - initial.Clear() - } - - if initial.Exist() { - v := initial - switch fieldType { - case TypeBooleanField: - _, err = v.Bool() - case TypeFloatField, TypeDecimalField: - _, err = v.Float64() - case TypeBitField: - _, err = v.Int8() - case TypeSmallIntegerField: - _, err = v.Int16() - case TypeIntegerField: - _, err = v.Int32() - case TypeBigIntegerField: - _, err = v.Int64() - case TypePositiveBitField: - _, err = v.Uint8() - case TypePositiveSmallIntegerField: - _, err = v.Uint16() - case TypePositiveIntegerField: - _, err = v.Uint32() - case TypePositiveBigIntegerField: - _, err = v.Uint64() - } - if err != nil { - tag, tagValue = "default", tags["default"] - goto wrongTag - } - } - - fi.initial = initial -end: - if err != nil { - return nil, err - } - return -wrongTag: - return nil, fmt.Errorf("wrong tag format: `%s:\"%s\"`, %s", tag, tagValue, err) -} diff --git a/orm/models_info_m.go b/orm/models_info_m.go deleted file mode 100644 index a4d733b6..00000000 --- a/orm/models_info_m.go +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "os" - "reflect" -) - -// single model info -type modelInfo struct { - pkg string - name string - fullName string - table string - model interface{} - fields *fields - manual bool - addrField reflect.Value //store the original struct value - uniques []string - isThrough bool -} - -// new model info -func newModelInfo(val reflect.Value) (mi *modelInfo) { - mi = &modelInfo{} - mi.fields = newFields() - ind := reflect.Indirect(val) - mi.addrField = val - mi.name = ind.Type().Name() - mi.fullName = getFullName(ind.Type()) - addModelFields(mi, ind, "", []int{}) - return -} - -// index: FieldByIndex returns the nested field corresponding to index -func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) { - var ( - err error - fi *fieldInfo - sf reflect.StructField - ) - - for i := 0; i < ind.NumField(); i++ { - field := ind.Field(i) - sf = ind.Type().Field(i) - // if the field is unexported skip - if sf.PkgPath != "" { - continue - } - // add anonymous struct fields - if sf.Anonymous { - addModelFields(mi, field, mName+"."+sf.Name, append(index, i)) - continue - } - - fi, err = newFieldInfo(mi, field, sf, mName) - if err == errSkipField { - err = nil - continue - } else if err != nil { - break - } - //record current field index - fi.fieldIndex = append(fi.fieldIndex, index...) - fi.fieldIndex = append(fi.fieldIndex, i) - fi.mi = mi - fi.inModel = true - if !mi.fields.Add(fi) { - err = fmt.Errorf("duplicate column name: %s", fi.column) - break - } - if fi.pk { - if mi.fields.pk != nil { - err = fmt.Errorf("one model must have one pk field only") - break - } else { - mi.fields.pk = fi - } - } - } - - if err != nil { - fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err)) - os.Exit(2) - } -} - -// combine related model info to new model info. -// prepare for relation models query. -func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) { - mi = new(modelInfo) - mi.fields = newFields() - mi.table = m1.table + "_" + m2.table + "s" - mi.name = camelString(mi.table) - mi.fullName = m1.pkg + "." + mi.name - - fa := new(fieldInfo) // pk - f1 := new(fieldInfo) // m1 table RelForeignKey - f2 := new(fieldInfo) // m2 table RelForeignKey - fa.fieldType = TypeBigIntegerField - fa.auto = true - fa.pk = true - fa.dbcol = true - fa.name = "Id" - fa.column = "id" - fa.fullName = mi.fullName + "." + fa.name - - f1.dbcol = true - f2.dbcol = true - f1.fieldType = RelForeignKey - f2.fieldType = RelForeignKey - f1.name = camelString(m1.table) - f2.name = camelString(m2.table) - f1.fullName = mi.fullName + "." + f1.name - f2.fullName = mi.fullName + "." + f2.name - f1.column = m1.table + "_id" - f2.column = m2.table + "_id" - f1.rel = true - f2.rel = true - f1.relTable = m1.table - f2.relTable = m2.table - f1.relModelInfo = m1 - f2.relModelInfo = m2 - f1.mi = mi - f2.mi = mi - - mi.fields.Add(fa) - mi.fields.Add(f1) - mi.fields.Add(f2) - mi.fields.pk = fa - - mi.uniques = []string{f1.column, f2.column} - return -} diff --git a/orm/models_test.go b/orm/models_test.go deleted file mode 100644 index e3a635f2..00000000 --- a/orm/models_test.go +++ /dev/null @@ -1,497 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "database/sql" - "encoding/json" - "fmt" - "os" - "strings" - "time" - - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - // As tidb can't use go get, so disable the tidb testing now - // _ "github.com/pingcap/tidb" -) - -// A slice string field. -type SliceStringField []string - -func (e SliceStringField) Value() []string { - return []string(e) -} - -func (e *SliceStringField) Set(d []string) { - *e = SliceStringField(d) -} - -func (e *SliceStringField) Add(v string) { - *e = append(*e, v) -} - -func (e *SliceStringField) String() string { - return strings.Join(e.Value(), ",") -} - -func (e *SliceStringField) FieldType() int { - return TypeVarCharField -} - -func (e *SliceStringField) SetRaw(value interface{}) error { - switch d := value.(type) { - case []string: - e.Set(d) - case string: - if len(d) > 0 { - parts := strings.Split(d, ",") - v := make([]string, 0, len(parts)) - for _, p := range parts { - v = append(v, strings.TrimSpace(p)) - } - e.Set(v) - } - default: - return fmt.Errorf(" unknown value `%v`", value) - } - return nil -} - -func (e *SliceStringField) RawValue() interface{} { - return e.String() -} - -var _ Fielder = new(SliceStringField) - -// A json field. -type JSONFieldTest struct { - Name string - Data string -} - -func (e *JSONFieldTest) String() string { - data, _ := json.Marshal(e) - return string(data) -} - -func (e *JSONFieldTest) FieldType() int { - return TypeTextField -} - -func (e *JSONFieldTest) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - return json.Unmarshal([]byte(d), e) - default: - return fmt.Errorf(" unknown value `%v`", value) - } -} - -func (e *JSONFieldTest) RawValue() interface{} { - return e.String() -} - -var _ Fielder = new(JSONFieldTest) - -type Data struct { - ID int `orm:"column(id)"` - Boolean bool - Char string `orm:"size(50)"` - Text string `orm:"type(text)"` - JSON string `orm:"type(json);default({\"name\":\"json\"})"` - Jsonb string `orm:"type(jsonb)"` - Time time.Time `orm:"type(time)"` - Date time.Time `orm:"type(date)"` - DateTime time.Time `orm:"column(datetime)"` - Byte byte - Rune rune - Int int - Int8 int8 - Int16 int16 - Int32 int32 - Int64 int64 - Uint uint - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - Float32 float32 - Float64 float64 - Decimal float64 `orm:"digits(8);decimals(4)"` -} - -type DataNull struct { - ID int `orm:"column(id)"` - Boolean bool `orm:"null"` - Char string `orm:"null;size(50)"` - Text string `orm:"null;type(text)"` - JSON string `orm:"type(json);null"` - Jsonb string `orm:"type(jsonb);null"` - Time time.Time `orm:"null;type(time)"` - Date time.Time `orm:"null;type(date)"` - DateTime time.Time `orm:"null;column(datetime)"` - Byte byte `orm:"null"` - Rune rune `orm:"null"` - Int int `orm:"null"` - Int8 int8 `orm:"null"` - Int16 int16 `orm:"null"` - Int32 int32 `orm:"null"` - Int64 int64 `orm:"null"` - Uint uint `orm:"null"` - Uint8 uint8 `orm:"null"` - Uint16 uint16 `orm:"null"` - Uint32 uint32 `orm:"null"` - Uint64 uint64 `orm:"null"` - Float32 float32 `orm:"null"` - Float64 float64 `orm:"null"` - Decimal float64 `orm:"digits(8);decimals(4);null"` - NullString sql.NullString `orm:"null"` - NullBool sql.NullBool `orm:"null"` - NullFloat64 sql.NullFloat64 `orm:"null"` - NullInt64 sql.NullInt64 `orm:"null"` - BooleanPtr *bool `orm:"null"` - CharPtr *string `orm:"null;size(50)"` - TextPtr *string `orm:"null;type(text)"` - BytePtr *byte `orm:"null"` - RunePtr *rune `orm:"null"` - IntPtr *int `orm:"null"` - Int8Ptr *int8 `orm:"null"` - Int16Ptr *int16 `orm:"null"` - Int32Ptr *int32 `orm:"null"` - Int64Ptr *int64 `orm:"null"` - UintPtr *uint `orm:"null"` - Uint8Ptr *uint8 `orm:"null"` - Uint16Ptr *uint16 `orm:"null"` - Uint32Ptr *uint32 `orm:"null"` - Uint64Ptr *uint64 `orm:"null"` - Float32Ptr *float32 `orm:"null"` - Float64Ptr *float64 `orm:"null"` - DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` - TimePtr *time.Time `orm:"null;type(time)"` - DatePtr *time.Time `orm:"null;type(date)"` - DateTimePtr *time.Time `orm:"null"` -} - -type String string -type Boolean bool -type Byte byte -type Rune rune -type Int int -type Int8 int8 -type Int16 int16 -type Int32 int32 -type Int64 int64 -type Uint uint -type Uint8 uint8 -type Uint16 uint16 -type Uint32 uint32 -type Uint64 uint64 -type Float32 float64 -type Float64 float64 - -type DataCustom struct { - ID int `orm:"column(id)"` - Boolean Boolean - Char string `orm:"size(50)"` - Text string `orm:"type(text)"` - Byte Byte - Rune Rune - Int Int - Int8 Int8 - Int16 Int16 - Int32 Int32 - Int64 Int64 - Uint Uint - Uint8 Uint8 - Uint16 Uint16 - Uint32 Uint32 - Uint64 Uint64 - Float32 Float32 - Float64 Float64 - Decimal Float64 `orm:"digits(8);decimals(4)"` -} - -// only for mysql -type UserBig struct { - ID uint64 `orm:"column(id)"` - Name string -} - -type User struct { - ID int `orm:"column(id)"` - UserName string `orm:"size(30);unique"` - Email string `orm:"size(100)"` - Password string `orm:"size(100)"` - Status int16 `orm:"column(Status)"` - IsStaff bool - IsActive bool `orm:"default(true)"` - Created time.Time `orm:"auto_now_add;type(date)"` - Updated time.Time `orm:"auto_now"` - Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` - Posts []*Post `orm:"reverse(many)" json:"-"` - ShouldSkip string `orm:"-"` - Nums int - Langs SliceStringField `orm:"size(100)"` - Extra JSONFieldTest `orm:"type(text)"` - unexport bool `orm:"-"` - unexportBool bool -} - -func (u *User) TableIndex() [][]string { - return [][]string{ - {"Id", "UserName"}, - {"Id", "Created"}, - } -} - -func (u *User) TableUnique() [][]string { - return [][]string{ - {"UserName", "Email"}, - } -} - -func NewUser() *User { - obj := new(User) - return obj -} - -type Profile struct { - ID int `orm:"column(id)"` - Age int16 - Money float64 - User *User `orm:"reverse(one)" json:"-"` - BestPost *Post `orm:"rel(one);null"` -} - -func (u *Profile) TableName() string { - return "user_profile" -} - -func NewProfile() *Profile { - obj := new(Profile) - return obj -} - -type Post struct { - ID int `orm:"column(id)"` - User *User `orm:"rel(fk)"` - Title string `orm:"size(60)"` - Content string `orm:"type(text)"` - Created time.Time `orm:"auto_now_add"` - Updated time.Time `orm:"auto_now"` - Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.PostTags)"` -} - -func (u *Post) TableIndex() [][]string { - return [][]string{ - {"Id", "Created"}, - } -} - -func NewPost() *Post { - obj := new(Post) - return obj -} - -type Tag struct { - ID int `orm:"column(id)"` - Name string `orm:"size(30)"` - BestPost *Post `orm:"rel(one);null"` - Posts []*Post `orm:"reverse(many)" json:"-"` -} - -func NewTag() *Tag { - obj := new(Tag) - return obj -} - -type PostTags struct { - ID int `orm:"column(id)"` - Post *Post `orm:"rel(fk)"` - Tag *Tag `orm:"rel(fk)"` -} - -func (m *PostTags) TableName() string { - return "prefix_post_tags" -} - -type Comment struct { - ID int `orm:"column(id)"` - Post *Post `orm:"rel(fk);column(post)"` - Content string `orm:"type(text)"` - Parent *Comment `orm:"null;rel(fk)"` - Created time.Time `orm:"auto_now_add"` -} - -func NewComment() *Comment { - obj := new(Comment) - return obj -} - -type Group struct { - ID int `orm:"column(gid);size(32)"` - Name string - Permissions []*Permission `orm:"reverse(many)" json:"-"` -} - -type Permission struct { - ID int `orm:"column(id)"` - Name string - Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"` -} - -type GroupPermissions struct { - ID int `orm:"column(id)"` - Group *Group `orm:"rel(fk)"` - Permission *Permission `orm:"rel(fk)"` -} - -type ModelID struct { - ID int64 -} - -type ModelBase struct { - ModelID - - Created time.Time `orm:"auto_now_add;type(datetime)"` - Updated time.Time `orm:"auto_now;type(datetime)"` -} - -type InLine struct { - // Common Fields - ModelBase - - // Other Fields - Name string `orm:"unique"` - Email string -} - -func NewInLine() *InLine { - return new(InLine) -} - -type InLineOneToOne struct { - // Common Fields - ModelBase - - Note string - InLine *InLine `orm:"rel(fk);column(inline)"` -} - -func NewInLineOneToOne() *InLineOneToOne { - return new(InLineOneToOne) -} - -type IntegerPk struct { - ID int64 `orm:"pk"` - Value string -} - -type UintPk struct { - ID uint32 `orm:"pk"` - Name string -} - -type PtrPk struct { - ID *IntegerPk `orm:"pk;rel(one)"` - Positive bool -} - -var DBARGS = struct { - Driver string - Source string - Debug string -}{ - os.Getenv("ORM_DRIVER"), - os.Getenv("ORM_SOURCE"), - os.Getenv("ORM_DEBUG"), -} - -var ( - IsMysql = DBARGS.Driver == "mysql" - IsSqlite = DBARGS.Driver == "sqlite3" - IsPostgres = DBARGS.Driver == "postgres" - IsTidb = DBARGS.Driver == "tidb" -) - -var ( - dORM Ormer - dDbBaser dbBaser -) - -var ( - helpinfo = `need driver and source! - - Default DB Drivers. - - driver: url - mysql: https://github.com/go-sql-driver/mysql - sqlite3: https://github.com/mattn/go-sqlite3 - postgres: https://github.com/lib/pq - tidb: https://github.com/pingcap/tidb - - usage: - - go get -u github.com/astaxie/beego/orm - go get -u github.com/go-sql-driver/mysql - go get -u github.com/mattn/go-sqlite3 - go get -u github.com/lib/pq - go get -u github.com/pingcap/tidb - - #### MySQL - mysql -u root -e 'create database orm_test;' - export ORM_DRIVER=mysql - export ORM_SOURCE="root:@/orm_test?charset=utf8" - go test -v github.com/astaxie/beego/orm - - - #### Sqlite3 - export ORM_DRIVER=sqlite3 - export ORM_SOURCE='file:memory_test?mode=memory' - go test -v github.com/astaxie/beego/orm - - - #### PostgreSQL - psql -c 'create database orm_test;' -U postgres - export ORM_DRIVER=postgres - export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - go test -v github.com/astaxie/beego/orm - - #### TiDB - export ORM_DRIVER=tidb - export ORM_SOURCE='memory://test/test' - go test -v github.com/astaxie/beego/orm - - ` -) - -func init() { - Debug, _ = StrTo(DBARGS.Debug).Bool() - - if DBARGS.Driver == "" || DBARGS.Source == "" { - fmt.Println(helpinfo) - os.Exit(2) - } - - RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) - - alias := getDbAlias("default") - if alias.Driver == DRMySQL { - alias.Engine = "INNODB" - } - -} diff --git a/orm/models_utils.go b/orm/models_utils.go deleted file mode 100644 index 71127a6b..00000000 --- a/orm/models_utils.go +++ /dev/null @@ -1,227 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "database/sql" - "fmt" - "reflect" - "strings" - "time" -) - -// 1 is attr -// 2 is tag -var supportTag = map[string]int{ - "-": 1, - "null": 1, - "index": 1, - "unique": 1, - "pk": 1, - "auto": 1, - "auto_now": 1, - "auto_now_add": 1, - "size": 2, - "column": 2, - "default": 2, - "rel": 2, - "reverse": 2, - "rel_table": 2, - "rel_through": 2, - "digits": 2, - "decimals": 2, - "on_delete": 2, - "type": 2, - "description": 2, -} - -// get reflect.Type name with package path. -func getFullName(typ reflect.Type) string { - return typ.PkgPath() + "." + typ.Name() -} - -// getTableName get struct table name. -// If the struct implement the TableName, then get the result as tablename -// else use the struct name which will apply snakeString. -func getTableName(val reflect.Value) string { - if fun := val.MethodByName("TableName"); fun.IsValid() { - vals := fun.Call([]reflect.Value{}) - // has return and the first val is string - if len(vals) > 0 && vals[0].Kind() == reflect.String { - return vals[0].String() - } - } - return snakeString(reflect.Indirect(val).Type().Name()) -} - -// get table engine, myisam or innodb. -func getTableEngine(val reflect.Value) string { - fun := val.MethodByName("TableEngine") - if fun.IsValid() { - vals := fun.Call([]reflect.Value{}) - if len(vals) > 0 && vals[0].Kind() == reflect.String { - return vals[0].String() - } - } - return "" -} - -// get table index from method. -func getTableIndex(val reflect.Value) [][]string { - fun := val.MethodByName("TableIndex") - if fun.IsValid() { - vals := fun.Call([]reflect.Value{}) - if len(vals) > 0 && vals[0].CanInterface() { - if d, ok := vals[0].Interface().([][]string); ok { - return d - } - } - } - return nil -} - -// get table unique from method -func getTableUnique(val reflect.Value) [][]string { - fun := val.MethodByName("TableUnique") - if fun.IsValid() { - vals := fun.Call([]reflect.Value{}) - if len(vals) > 0 && vals[0].CanInterface() { - if d, ok := vals[0].Interface().([][]string); ok { - return d - } - } - } - return nil -} - -// get snaked column name -func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { - column := col - if col == "" { - column = nameStrategyMap[nameStrategy](sf.Name) - } - switch ft { - case RelForeignKey, RelOneToOne: - if len(col) == 0 { - column = column + "_id" - } - case RelManyToMany, RelReverseMany, RelReverseOne: - column = sf.Name - } - return column -} - -// return field type as type constant from reflect.Value -func getFieldType(val reflect.Value) (ft int, err error) { - switch val.Type() { - case reflect.TypeOf(new(int8)): - ft = TypeBitField - case reflect.TypeOf(new(int16)): - ft = TypeSmallIntegerField - case reflect.TypeOf(new(int32)), - reflect.TypeOf(new(int)): - ft = TypeIntegerField - case reflect.TypeOf(new(int64)): - ft = TypeBigIntegerField - case reflect.TypeOf(new(uint8)): - ft = TypePositiveBitField - case reflect.TypeOf(new(uint16)): - ft = TypePositiveSmallIntegerField - case reflect.TypeOf(new(uint32)), - reflect.TypeOf(new(uint)): - ft = TypePositiveIntegerField - case reflect.TypeOf(new(uint64)): - ft = TypePositiveBigIntegerField - case reflect.TypeOf(new(float32)), - reflect.TypeOf(new(float64)): - ft = TypeFloatField - case reflect.TypeOf(new(bool)): - ft = TypeBooleanField - case reflect.TypeOf(new(string)): - ft = TypeVarCharField - case reflect.TypeOf(new(time.Time)): - ft = TypeDateTimeField - default: - elm := reflect.Indirect(val) - switch elm.Kind() { - case reflect.Int8: - ft = TypeBitField - case reflect.Int16: - ft = TypeSmallIntegerField - case reflect.Int32, reflect.Int: - ft = TypeIntegerField - case reflect.Int64: - ft = TypeBigIntegerField - case reflect.Uint8: - ft = TypePositiveBitField - case reflect.Uint16: - ft = TypePositiveSmallIntegerField - case reflect.Uint32, reflect.Uint: - ft = TypePositiveIntegerField - case reflect.Uint64: - ft = TypePositiveBigIntegerField - case reflect.Float32, reflect.Float64: - ft = TypeFloatField - case reflect.Bool: - ft = TypeBooleanField - case reflect.String: - ft = TypeVarCharField - default: - if elm.Interface() == nil { - panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val)) - } - switch elm.Interface().(type) { - case sql.NullInt64: - ft = TypeBigIntegerField - case sql.NullFloat64: - ft = TypeFloatField - case sql.NullBool: - ft = TypeBooleanField - case sql.NullString: - ft = TypeVarCharField - case time.Time: - ft = TypeDateTimeField - } - } - } - if ft&IsFieldType == 0 { - err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val) - } - return -} - -// parse struct tag string -func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) { - attrs = make(map[string]bool) - tags = make(map[string]string) - for _, v := range strings.Split(data, defaultStructTagDelim) { - if v == "" { - continue - } - v = strings.TrimSpace(v) - if t := strings.ToLower(v); supportTag[t] == 1 { - attrs[t] = true - } else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 { - name := t[:i] - if supportTag[name] == 2 { - v = v[i+1 : len(v)-1] - tags[name] = v - } - } else { - DebugLog.Println("unsupport orm tag", v) - } - } - return -} diff --git a/orm/orm.go b/orm/orm.go deleted file mode 100644 index c7566b9a..00000000 --- a/orm/orm.go +++ /dev/null @@ -1,602 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build go1.8 - -// Package orm provide ORM for MySQL/PostgreSQL/sqlite -// Simple Usage -// -// package main -// -// import ( -// "fmt" -// "github.com/astaxie/beego/orm" -// _ "github.com/go-sql-driver/mysql" // import your used driver -// ) -// -// // Model Struct -// type User struct { -// Id int `orm:"auto"` -// Name string `orm:"size(100)"` -// } -// -// func init() { -// orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) -// } -// -// func main() { -// o := orm.NewOrm() -// user := User{Name: "slene"} -// // insert -// id, err := o.Insert(&user) -// // update -// user.Name = "astaxie" -// num, err := o.Update(&user) -// // read one -// u := User{Id: user.Id} -// err = o.Read(&u) -// // delete -// num, err = o.Delete(&u) -// } -// -// more docs: http://beego.me/docs/mvc/model/overview.md -package orm - -import ( - "context" - "database/sql" - "errors" - "fmt" - "os" - "reflect" - "sync" - "time" -) - -// DebugQueries define the debug -const ( - DebugQueries = iota -) - -// Define common vars -var ( - Debug = false - DebugLog = NewLog(os.Stdout) - DefaultRowsLimit = -1 - DefaultRelsDepth = 2 - DefaultTimeLoc = time.Local - ErrTxHasBegan = errors.New(" transaction already begin") - ErrTxDone = errors.New(" transaction not begin") - ErrMultiRows = errors.New(" return multi rows") - ErrNoRows = errors.New(" no row found") - ErrStmtClosed = errors.New(" stmt already closed") - ErrArgs = errors.New(" args error may be empty") - ErrNotImplement = errors.New("have not implement") -) - -// Params stores the Params -type Params map[string]interface{} - -// ParamsList stores paramslist -type ParamsList []interface{} - -type orm struct { - alias *alias - db dbQuerier - isTx bool -} - -var _ Ormer = new(orm) - -// get model info and model reflect value -func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { - val := reflect.ValueOf(md) - ind = reflect.Indirect(val) - typ := ind.Type() - if needPtr && val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) - } - name := getFullName(typ) - if mi, ok := modelCache.getByFullName(name); ok { - return mi, ind - } - panic(fmt.Errorf(" table: `%s` not found, make sure it was registered with `RegisterModel()`", name)) -} - -// get field info from model info by given field name -func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { - fi, ok := mi.fields.GetByAny(name) - if !ok { - panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.fullName)) - } - return fi -} - -// read data to model -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Read(md interface{}, cols ...string) error { - mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) -} - -// read data to model, like Read(), but use "SELECT FOR UPDATE" form -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { - mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) -} - -// Try to read a row from the database, or insert one if it doesn't exist -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { - cols = append([]string{col1}, cols...) - mi, ind := o.getMiInd(md, true) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) - if err == ErrNoRows { - // Create - id, err := o.Insert(md) - return (err == nil), id, err - } - - id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) - if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - id = int64(vid.Uint()) - } else if mi.fields.pk.rel { - return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) - } else { - id = vid.Int() - } - - return false, id, err -} - -// insert model data to database -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Insert(md interface{}) (int64, error) { - mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) - if err != nil { - return id, err - } - - o.setPk(mi, ind, id) - - return id, nil -} - -// set auto pk field -func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { - if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) - } else { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id) - } - } -} - -// insert some models to database -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { - var cnt int64 - - sind := reflect.Indirect(reflect.ValueOf(mds)) - - switch sind.Kind() { - case reflect.Array, reflect.Slice: - if sind.Len() == 0 { - return cnt, ErrArgs - } - default: - return cnt, ErrArgs - } - - if bulk <= 1 { - for i := 0; i < sind.Len(); i++ { - ind := reflect.Indirect(sind.Index(i)) - mi, _ := o.getMiInd(ind.Interface(), false) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) - if err != nil { - return cnt, err - } - - o.setPk(mi, ind, id) - - cnt++ - } - } else { - mi, _ := o.getMiInd(sind.Index(0).Interface(), false) - return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) - } - return cnt, nil -} - -// InsertOrUpdate data to database -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) - if err != nil { - return id, err - } - - o.setPk(mi, ind, id) - - return id, nil -} - -// update model to database. -// cols set the columns those want to update. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Update(md interface{}, cols ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) -} - -// delete model in database -// cols shows the delete conditions values read from. default is pk -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) - num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) - if err != nil { - return num, err - } - if num > 0 { - o.setPk(mi, ind, 0) - } - return num, nil -} - -// create a models to models queryer -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { - mi, ind := o.getMiInd(md, true) - fi := o.getFieldInfo(mi, name) - - switch { - case fi.fieldType == RelManyToMany: - case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough: - default: - panic(fmt.Errorf(" model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName)) - } - - return newQueryM2M(md, o, mi, fi, ind) -} - -// load related models to md model. -// args are limit, offset int and order string. -// -// example: -// orm.LoadRelated(post,"Tags") -// for _,tag := range post.Tags{...} -// -// make sure the relation is defined in model struct tags. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { - _, fi, ind, qseter := o.queryRelated(md, name) - - qs := qseter.(*querySet) - - var relDepth int - var limit, offset int64 - var order string - for i, arg := range args { - switch i { - case 0: - if v, ok := arg.(bool); ok { - if v { - relDepth = DefaultRelsDepth - } - } else if v, ok := arg.(int); ok { - relDepth = v - } - case 1: - limit = ToInt64(arg) - case 2: - offset = ToInt64(arg) - case 3: - order, _ = arg.(string) - } - } - - switch fi.fieldType { - case RelOneToOne, RelForeignKey, RelReverseOne: - limit = 1 - offset = 0 - } - - qs.limit = limit - qs.offset = offset - qs.relDepth = relDepth - - if len(order) > 0 { - qs.orders = []string{order} - } - - find := ind.FieldByIndex(fi.fieldIndex) - - var nums int64 - var err error - switch fi.fieldType { - case RelOneToOne, RelForeignKey, RelReverseOne: - val := reflect.New(find.Type().Elem()) - container := val.Interface() - err = qs.One(container) - if err == nil { - find.Set(val) - nums = 1 - } - default: - nums, err = qs.All(find.Addr().Interface()) - } - - return nums, err -} - -// return a QuerySeter for related models to md model. -// it can do all, update, delete in QuerySeter. -// example: -// qs := orm.QueryRelated(post,"Tag") -// qs.All(&[]*Tag{}) -// -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { - // is this api needed ? - _, _, _, qs := o.queryRelated(md, name) - return qs -} - -// get QuerySeter for related models to md model -func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { - mi, ind := o.getMiInd(md, true) - fi := o.getFieldInfo(mi, name) - - _, _, exist := getExistPk(mi, ind) - if !exist { - panic(ErrMissPK) - } - - var qs *querySet - - switch fi.fieldType { - case RelOneToOne, RelForeignKey, RelManyToMany: - if !fi.inModel { - break - } - qs = o.getRelQs(md, mi, fi) - case RelReverseOne, RelReverseMany: - if !fi.inModel { - break - } - qs = o.getReverseQs(md, mi, fi) - } - - if qs == nil { - panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel/reverse field", md, name)) - } - - return mi, fi, ind, qs -} - -// get reverse relation QuerySeter -func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { - switch fi.fieldType { - case RelReverseOne, RelReverseMany: - default: - panic(fmt.Errorf(" name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName)) - } - - var q *querySet - - if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough { - q = newQuerySet(o, fi.relModelInfo).(*querySet) - q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) - } else { - q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet) - q.cond = NewCondition().And(fi.reverseFieldInfo.column, md) - } - - return q -} - -// get relation QuerySeter -func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { - switch fi.fieldType { - case RelOneToOne, RelForeignKey, RelManyToMany: - default: - panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName)) - } - - q := newQuerySet(o, fi.relModelInfo).(*querySet) - q.cond = NewCondition() - - if fi.fieldType == RelManyToMany { - q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) - } else { - q.cond = q.cond.And(fi.reverseFieldInfo.column, md) - } - - return q -} - -// return a QuerySeter for table operations. -// table name can be string or struct. -// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { - var name string - if table, ok := ptrStructOrTableName.(string); ok { - name = nameStrategyMap[defaultNameStrategy](table) - if mi, ok := modelCache.get(name); ok { - qs = newQuerySet(o, mi) - } - } else { - name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) - if mi, ok := modelCache.getByFullName(name); ok { - qs = newQuerySet(o, mi) - } - } - if qs == nil { - panic(fmt.Errorf(" table name: `%s` not exists", name)) - } - return -} - -// switch to another registered database driver by given name. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -// Using NewOrmUsingDB(name) -func (o *orm) Using(name string) error { - if o.isTx { - panic(fmt.Errorf(" transaction has been start, cannot change db")) - } - if al, ok := dataBaseCache.get(name); ok { - o.alias = al - if Debug { - o.db = newDbQueryLog(al, al.DB) - } else { - o.db = al.DB - } - } else { - return fmt.Errorf(" unknown db alias name `%s`", name) - } - return nil -} - -// begin transaction -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Begin() error { - return o.BeginTx(context.Background(), nil) -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error { - if o.isTx { - return ErrTxHasBegan - } - var tx *sql.Tx - tx, err := o.db.(txer).BeginTx(ctx, opts) - if err != nil { - return err - } - o.isTx = true - if Debug { - o.db.(*dbQueryLog).SetDB(tx) - } else { - o.db = tx - } - return nil -} - -// commit transaction -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Commit() error { - if !o.isTx { - return ErrTxDone - } - err := o.db.(txEnder).Commit() - if err == nil { - o.isTx = false - o.Using(o.alias.Name) - } else if err == sql.ErrTxDone { - return ErrTxDone - } - return err -} - -// rollback transaction -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Rollback() error { - if !o.isTx { - return ErrTxDone - } - err := o.db.(txEnder).Rollback() - if err == nil { - o.isTx = false - o.Using(o.alias.Name) - } else if err == sql.ErrTxDone { - return ErrTxDone - } - return err -} - -// return a raw query seter for raw sql string. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Raw(query string, args ...interface{}) RawSeter { - return newRawSet(o, query, args) -} - -// return current using database Driver -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Driver() Driver { - return driver(o.alias.Name) -} - -// return sql.DBStats for current database -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) DBStats() *sql.DBStats { - if o.alias != nil && o.alias.DB != nil { - stats := o.alias.DB.DB.Stats() - return &stats - } - return nil -} - -// NewOrm create new orm -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func NewOrm() Ormer { - BootStrap() // execute only once - - o := new(orm) - err := o.Using("default") - if err != nil { - panic(err) - } - return o -} - -// NewOrmWithDB create a new ormer object with specify *sql.DB for query -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { - var al *alias - - if dr, ok := drivers[driverName]; ok { - al = new(alias) - al.DbBaser = dbBasers[dr] - al.Driver = dr - } else { - return nil, fmt.Errorf("driver name `%s` have not registered", driverName) - } - - al.Name = aliasName - al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), - } - - detectTZ(al) - - o := new(orm) - o.alias = al - - if Debug { - o.db = newDbQueryLog(o.alias, db) - } else { - o.db = db - } - - return o, nil -} diff --git a/orm/orm_conds.go b/orm/orm_conds.go deleted file mode 100644 index f3fd66f0..00000000 --- a/orm/orm_conds.go +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strings" -) - -// ExprSep define the expression separation -const ( - ExprSep = "__" -) - -type condValue struct { - exprs []string - args []interface{} - cond *Condition - isOr bool - isNot bool - isCond bool - isRaw bool - sql string -} - -// Condition struct. -// work for WHERE conditions. -type Condition struct { - params []condValue -} - -// NewCondition return new condition struct -func NewCondition() *Condition { - c := &Condition{} - return c -} - -// Raw add raw sql to condition -func (c Condition) Raw(expr string, sql string) *Condition { - if len(sql) == 0 { - panic(fmt.Errorf(" sql cannot empty")) - } - c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), sql: sql, isRaw: true}) - return &c -} - -// And add expression to condition -func (c Condition) And(expr string, args ...interface{}) *Condition { - if expr == "" || len(args) == 0 { - panic(fmt.Errorf(" args cannot empty")) - } - c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args}) - return &c -} - -// AndNot add NOT expression to condition -func (c Condition) AndNot(expr string, args ...interface{}) *Condition { - if expr == "" || len(args) == 0 { - panic(fmt.Errorf(" args cannot empty")) - } - c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true}) - return &c -} - -// AndCond combine a condition to current condition -func (c *Condition) AndCond(cond *Condition) *Condition { - c = c.clone() - if c == cond { - panic(fmt.Errorf(" cannot use self as sub cond")) - } - if cond != nil { - c.params = append(c.params, condValue{cond: cond, isCond: true}) - } - return c -} - -// AndNotCond combine a AND NOT condition to current condition -func (c *Condition) AndNotCond(cond *Condition) *Condition { - c = c.clone() - if c == cond { - panic(fmt.Errorf(" cannot use self as sub cond")) - } - - if cond != nil { - c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true}) - } - return c -} - -// Or add OR expression to condition -func (c Condition) Or(expr string, args ...interface{}) *Condition { - if expr == "" || len(args) == 0 { - panic(fmt.Errorf(" args cannot empty")) - } - c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true}) - return &c -} - -// OrNot add OR NOT expression to condition -func (c Condition) OrNot(expr string, args ...interface{}) *Condition { - if expr == "" || len(args) == 0 { - panic(fmt.Errorf(" args cannot empty")) - } - c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true}) - return &c -} - -// OrCond combine a OR condition to current condition -func (c *Condition) OrCond(cond *Condition) *Condition { - c = c.clone() - if c == cond { - panic(fmt.Errorf(" cannot use self as sub cond")) - } - if cond != nil { - c.params = append(c.params, condValue{cond: cond, isCond: true, isOr: true}) - } - return c -} - -// OrNotCond combine a OR NOT condition to current condition -func (c *Condition) OrNotCond(cond *Condition) *Condition { - c = c.clone() - if c == cond { - panic(fmt.Errorf(" cannot use self as sub cond")) - } - - if cond != nil { - c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true, isOr: true}) - } - return c -} - -// IsEmpty check the condition arguments are empty or not. -func (c *Condition) IsEmpty() bool { - return len(c.params) == 0 -} - -// clone clone a condition -func (c Condition) clone() *Condition { - return &c -} diff --git a/orm/orm_log.go b/orm/orm_log.go deleted file mode 100644 index 5bb3a24f..00000000 --- a/orm/orm_log.go +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "context" - "database/sql" - "fmt" - "io" - "log" - "strings" - "time" -) - -// Log implement the log.Logger -type Log struct { - *log.Logger -} - -//costomer log func -var LogFunc func(query map[string]interface{}) - -// NewLog set io.Writer to create a Logger. -func NewLog(out io.Writer) *Log { - d := new(Log) - d.Logger = log.New(out, "[ORM]", log.LstdFlags) - return d -} - -func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { - var logMap = make(map[string]interface{}) - sub := time.Now().Sub(t) / 1e5 - elsp := float64(int(sub)) / 10.0 - logMap["cost_time"] = elsp - flag := " OK" - if err != nil { - flag = "FAIL" - } - logMap["flag"] = flag - con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query) - cons := make([]string, 0, len(args)) - for _, arg := range args { - cons = append(cons, fmt.Sprintf("%v", arg)) - } - if len(cons) > 0 { - con += fmt.Sprintf(" - `%s`", strings.Join(cons, "`, `")) - } - if err != nil { - con += " - " + err.Error() - } - logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `")) - if LogFunc != nil { - LogFunc(logMap) - } - DebugLog.Println(con) -} - -// statement query logger struct. -// if dev mode, use stmtQueryLog, or use stmtQuerier. -type stmtQueryLog struct { - alias *alias - query string - stmt stmtQuerier -} - -var _ stmtQuerier = new(stmtQueryLog) - -func (d *stmtQueryLog) Close() error { - a := time.Now() - err := d.stmt.Close() - debugLogQueies(d.alias, "st.Close", d.query, a, err) - return err -} - -func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) { - a := time.Now() - res, err := d.stmt.Exec(args...) - debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...) - return res, err -} - -func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) { - a := time.Now() - res, err := d.stmt.Query(args...) - debugLogQueies(d.alias, "st.Query", d.query, a, err, args...) - return res, err -} - -func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row { - a := time.Now() - res := d.stmt.QueryRow(args...) - debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...) - return res -} - -func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier { - d := new(stmtQueryLog) - d.stmt = stmt - d.alias = alias - d.query = query - return d -} - -// database query logger struct. -// if dev mode, use dbQueryLog, or use dbQuerier. -type dbQueryLog struct { - alias *alias - db dbQuerier - tx txer - txe txEnder -} - -var _ dbQuerier = new(dbQueryLog) -var _ txer = new(dbQueryLog) -var _ txEnder = new(dbQueryLog) - -func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) { - a := time.Now() - stmt, err := d.db.Prepare(query) - debugLogQueies(d.alias, "db.Prepare", query, a, err) - return stmt, err -} - -func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { - a := time.Now() - stmt, err := d.db.PrepareContext(ctx, query) - debugLogQueies(d.alias, "db.Prepare", query, a, err) - return stmt, err -} - -func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { - a := time.Now() - res, err := d.db.Exec(query, args...) - debugLogQueies(d.alias, "db.Exec", query, a, err, args...) - return res, err -} - -func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - a := time.Now() - res, err := d.db.ExecContext(ctx, query, args...) - debugLogQueies(d.alias, "db.Exec", query, a, err, args...) - return res, err -} - -func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { - a := time.Now() - res, err := d.db.Query(query, args...) - debugLogQueies(d.alias, "db.Query", query, a, err, args...) - return res, err -} - -func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { - a := time.Now() - res, err := d.db.QueryContext(ctx, query, args...) - debugLogQueies(d.alias, "db.Query", query, a, err, args...) - return res, err -} - -func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { - a := time.Now() - res := d.db.QueryRow(query, args...) - debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) - return res -} - -func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - a := time.Now() - res := d.db.QueryRowContext(ctx, query, args...) - debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) - return res -} - -func (d *dbQueryLog) Begin() (*sql.Tx, error) { - a := time.Now() - tx, err := d.db.(txer).Begin() - debugLogQueies(d.alias, "db.Begin", "START TRANSACTION", a, err) - return tx, err -} - -func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { - a := time.Now() - tx, err := d.db.(txer).BeginTx(ctx, opts) - debugLogQueies(d.alias, "db.BeginTx", "START TRANSACTION", a, err) - return tx, err -} - -func (d *dbQueryLog) Commit() error { - a := time.Now() - err := d.db.(txEnder).Commit() - debugLogQueies(d.alias, "tx.Commit", "COMMIT", a, err) - return err -} - -func (d *dbQueryLog) Rollback() error { - a := time.Now() - err := d.db.(txEnder).Rollback() - debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err) - return err -} - -func (d *dbQueryLog) SetDB(db dbQuerier) { - d.db = db -} - -func newDbQueryLog(alias *alias, db dbQuerier) dbQuerier { - d := new(dbQueryLog) - d.alias = alias - d.db = db - return d -} diff --git a/orm/orm_object.go b/orm/orm_object.go deleted file mode 100644 index de3181ce..00000000 --- a/orm/orm_object.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "reflect" -) - -// an insert queryer struct -type insertSet struct { - mi *modelInfo - orm *orm - stmt stmtQuerier - closed bool -} - -var _ Inserter = new(insertSet) - -// insert model ignore it's registered or not. -func (o *insertSet) Insert(md interface{}) (int64, error) { - if o.closed { - return 0, ErrStmtClosed - } - val := reflect.ValueOf(md) - ind := reflect.Indirect(val) - typ := ind.Type() - name := getFullName(typ) - if val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", name)) - } - if name != o.mi.fullName { - panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.fullName, name)) - } - id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) - if err != nil { - return id, err - } - if id > 0 { - if o.mi.fields.pk.auto { - if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id)) - } else { - ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id) - } - } - } - return id, nil -} - -// close insert queryer statement -func (o *insertSet) Close() error { - if o.closed { - return ErrStmtClosed - } - o.closed = true - return o.stmt.Close() -} - -// create new insert queryer. -func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { - bi := new(insertSet) - bi.orm = orm - bi.mi = mi - st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) - if err != nil { - return nil, err - } - if Debug { - bi.stmt = newStmtQueryLog(orm.alias, st, query) - } else { - bi.stmt = st - } - return bi, nil -} diff --git a/orm/orm_querym2m.go b/orm/orm_querym2m.go deleted file mode 100644 index 6a270a0d..00000000 --- a/orm/orm_querym2m.go +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import "reflect" - -// model to model struct -type queryM2M struct { - md interface{} - mi *modelInfo - fi *fieldInfo - qs *querySet - ind reflect.Value -} - -// add models to origin models when creating queryM2M. -// example: -// m2m := orm.QueryM2M(post,"Tag") -// m2m.Add(&Tag1{},&Tag2{}) -// for _,tag := range post.Tags{} -// -// make sure the relation is defined in post model struct tag. -func (o *queryM2M) Add(mds ...interface{}) (int64, error) { - fi := o.fi - mi := fi.relThroughModelInfo - mfi := fi.reverseFieldInfo - rfi := fi.reverseFieldInfoTwo - - orm := o.qs.orm - dbase := orm.alias.DbBaser - - var models []interface{} - var otherValues []interface{} - var otherNames []string - - for _, colname := range mi.fields.dbcols { - if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column && - mi.fields.columns[colname] != mi.fields.pk { - otherNames = append(otherNames, colname) - } - } - for i, md := range mds { - if reflect.Indirect(reflect.ValueOf(md)).Kind() != reflect.Struct && i > 0 { - otherValues = append(otherValues, md) - mds = append(mds[:i], mds[i+1:]...) - } - } - for _, md := range mds { - val := reflect.ValueOf(md) - if val.Kind() == reflect.Slice || val.Kind() == reflect.Array { - for i := 0; i < val.Len(); i++ { - v := val.Index(i) - if v.CanInterface() { - models = append(models, v.Interface()) - } - } - } else { - models = append(models, md) - } - } - - _, v1, exist := getExistPk(o.mi, o.ind) - if !exist { - panic(ErrMissPK) - } - - names := []string{mfi.column, rfi.column} - - values := make([]interface{}, 0, len(models)*2) - for _, md := range models { - - ind := reflect.Indirect(reflect.ValueOf(md)) - var v2 interface{} - if ind.Kind() != reflect.Struct { - v2 = ind.Interface() - } else { - _, v2, exist = getExistPk(fi.relModelInfo, ind) - if !exist { - panic(ErrMissPK) - } - } - values = append(values, v1, v2) - - } - names = append(names, otherNames...) - values = append(values, otherValues...) - return dbase.InsertValue(orm.db, mi, true, names, values) -} - -// remove models following the origin model relationship -func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { - fi := o.fi - qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) - - return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete() -} - -// check model is existed in relationship of origin model -func (o *queryM2M) Exist(md interface{}) bool { - fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md). - Filter(fi.reverseFieldInfoTwo.name, md).Exist() -} - -// clean all models in related of origin model -func (o *queryM2M) Clear() (int64, error) { - fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() -} - -// count all related models of origin model -func (o *queryM2M) Count() (int64, error) { - fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() -} - -var _ QueryM2Mer = new(queryM2M) - -// create new M2M queryer. -func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { - qm2m := new(queryM2M) - qm2m.md = md - qm2m.mi = mi - qm2m.fi = fi - qm2m.ind = ind - qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet) - return qm2m -} diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go deleted file mode 100644 index 878b836b..00000000 --- a/orm/orm_queryset.go +++ /dev/null @@ -1,300 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "context" - "fmt" -) - -type colValue struct { - value int64 - opt operator -} - -type operator int - -// define Col operations -const ( - ColAdd operator = iota - ColMinus - ColMultiply - ColExcept - ColBitAnd - ColBitRShift - ColBitLShift - ColBitXOR - ColBitOr -) - -// ColValue do the field raw changes. e.g Nums = Nums + 10. usage: -// Params{ -// "Nums": ColValue(Col_Add, 10), -// } -func ColValue(opt operator, value interface{}) interface{} { - switch opt { - case ColAdd, ColMinus, ColMultiply, ColExcept, ColBitAnd, ColBitRShift, - ColBitLShift, ColBitXOR, ColBitOr: - default: - panic(fmt.Errorf("orm.ColValue wrong operator")) - } - v, err := StrTo(ToStr(value)).Int64() - if err != nil { - panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err)) - } - var val colValue - val.value = v - val.opt = opt - return val -} - -// real query struct -type querySet struct { - mi *modelInfo - cond *Condition - related []string - relDepth int - limit int64 - offset int64 - groups []string - orders []string - distinct bool - forupdate bool - orm *orm - ctx context.Context - forContext bool -} - -var _ QuerySeter = new(querySet) - -// add condition expression to QuerySeter. -func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { - if o.cond == nil { - o.cond = NewCondition() - } - o.cond = o.cond.And(expr, args...) - return &o -} - -// add raw sql to querySeter. -func (o querySet) FilterRaw(expr string, sql string) QuerySeter { - if o.cond == nil { - o.cond = NewCondition() - } - o.cond = o.cond.Raw(expr, sql) - return &o -} - -// add NOT condition to querySeter. -func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { - if o.cond == nil { - o.cond = NewCondition() - } - o.cond = o.cond.AndNot(expr, args...) - return &o -} - -// set offset number -func (o *querySet) setOffset(num interface{}) { - o.offset = ToInt64(num) -} - -// add LIMIT value. -// args[0] means offset, e.g. LIMIT num,offset. -func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { - o.limit = ToInt64(limit) - if len(args) > 0 { - o.setOffset(args[0]) - } - return &o -} - -// add OFFSET value -func (o querySet) Offset(offset interface{}) QuerySeter { - o.setOffset(offset) - return &o -} - -// add GROUP expression -func (o querySet) GroupBy(exprs ...string) QuerySeter { - o.groups = exprs - return &o -} - -// add ORDER expression. -// "column" means ASC, "-column" means DESC. -func (o querySet) OrderBy(exprs ...string) QuerySeter { - o.orders = exprs - return &o -} - -// add DISTINCT to SELECT -func (o querySet) Distinct() QuerySeter { - o.distinct = true - return &o -} - -// add FOR UPDATE to SELECT -func (o querySet) ForUpdate() QuerySeter { - o.forupdate = true - return &o -} - -// set relation model to query together. -// it will query relation models and assign to parent model. -func (o querySet) RelatedSel(params ...interface{}) QuerySeter { - if len(params) == 0 { - o.relDepth = DefaultRelsDepth - } else { - for _, p := range params { - switch val := p.(type) { - case string: - o.related = append(o.related, val) - case int: - o.relDepth = val - default: - panic(fmt.Errorf(" wrong param kind: %v", val)) - } - } - } - return &o -} - -// set condition to QuerySeter. -func (o querySet) SetCond(cond *Condition) QuerySeter { - o.cond = cond - return &o -} - -// get condition from QuerySeter -func (o querySet) GetCond() *Condition { - return o.cond -} - -// return QuerySeter execution result number -func (o *querySet) Count() (int64, error) { - return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) -} - -// check result empty or not after QuerySeter executed -func (o *querySet) Exist() bool { - cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) - return cnt > 0 -} - -// execute update with parameters -func (o *querySet) Update(values Params) (int64, error) { - return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) -} - -// execute delete -func (o *querySet) Delete() (int64, error) { - return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) -} - -// return a insert queryer. -// it can be used in times. -// example: -// i,err := sq.PrepareInsert() -// i.Add(&user1{},&user2{}) -func (o *querySet) PrepareInsert() (Inserter, error) { - return newInsertSet(o.orm, o.mi) -} - -// query all data and map to containers. -// cols means the columns when querying. -func (o *querySet) All(container interface{}, cols ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) -} - -// query one row data and map to containers. -// cols means the columns when querying. -func (o *querySet) One(container interface{}, cols ...string) error { - o.limit = 1 - num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) - if err != nil { - return err - } - if num == 0 { - return ErrNoRows - } - - if num > 1 { - return ErrMultiRows - } - return nil -} - -// query all data and map to []map[string]interface. -// expres means condition expression. -// it converts data to []map[column]value. -func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) -} - -// query all data and map to [][]interface -// it converts data to [][column_index]value -func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) -} - -// query all data and map to []interface. -// it's designed for one row record set, auto change to []value, not [][column]value. -func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) -} - -// query all rows into map[string]interface with specify key and value column name. -// keyCol = "name", valueCol = "value" -// table data -// name | value -// total | 100 -// found | 200 -// to map[string]interface{}{ -// "total": 100, -// "found": 200, -// } -func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { - panic(ErrNotImplement) -} - -// query all rows into struct with specify key and value column name. -// keyCol = "name", valueCol = "value" -// table data -// name | value -// total | 100 -// found | 200 -// to struct { -// Total int -// Found int -// } -func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { - panic(ErrNotImplement) -} - -// set context to QuerySeter. -func (o querySet) WithContext(ctx context.Context) QuerySeter { - o.ctx = ctx - o.forContext = true - return &o -} - -// create new QuerySeter. -func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { - o := new(querySet) - o.mi = mi - o.orm = orm - return o -} diff --git a/orm/orm_raw.go b/orm/orm_raw.go deleted file mode 100644 index 3325a7ea..00000000 --- a/orm/orm_raw.go +++ /dev/null @@ -1,867 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "database/sql" - "fmt" - "reflect" - "time" -) - -// raw sql string prepared statement -type rawPrepare struct { - rs *rawSet - stmt stmtQuerier - closed bool -} - -func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) { - if o.closed { - return nil, ErrStmtClosed - } - return o.stmt.Exec(args...) -} - -func (o *rawPrepare) Close() error { - o.closed = true - return o.stmt.Close() -} - -func newRawPreparer(rs *rawSet) (RawPreparer, error) { - o := new(rawPrepare) - o.rs = rs - - query := rs.query - rs.orm.alias.DbBaser.ReplaceMarks(&query) - - st, err := rs.orm.db.Prepare(query) - if err != nil { - return nil, err - } - if Debug { - o.stmt = newStmtQueryLog(rs.orm.alias, st, query) - } else { - o.stmt = st - } - return o, nil -} - -// raw query seter -type rawSet struct { - query string - args []interface{} - orm *orm -} - -var _ RawSeter = new(rawSet) - -// set args for every query -func (o rawSet) SetArgs(args ...interface{}) RawSeter { - o.args = args - return &o -} - -// execute raw sql and return sql.Result -func (o *rawSet) Exec() (sql.Result, error) { - query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) - - args := getFlatParams(nil, o.args, o.orm.alias.TZ) - return o.orm.db.Exec(query, args...) -} - -// set field value to row container -func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { - switch ind.Kind() { - case reflect.Bool: - if value == nil { - ind.SetBool(false) - } else if v, ok := value.(bool); ok { - ind.SetBool(v) - } else { - v, _ := StrTo(ToStr(value)).Bool() - ind.SetBool(v) - } - - case reflect.String: - if value == nil { - ind.SetString("") - } else { - ind.SetString(ToStr(value)) - } - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if value == nil { - ind.SetInt(0) - } else { - val := reflect.ValueOf(value) - switch val.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - ind.SetInt(val.Int()) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - ind.SetInt(int64(val.Uint())) - default: - v, _ := StrTo(ToStr(value)).Int64() - ind.SetInt(v) - } - } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if value == nil { - ind.SetUint(0) - } else { - val := reflect.ValueOf(value) - switch val.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - ind.SetUint(uint64(val.Int())) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - ind.SetUint(val.Uint()) - default: - v, _ := StrTo(ToStr(value)).Uint64() - ind.SetUint(v) - } - } - case reflect.Float64, reflect.Float32: - if value == nil { - ind.SetFloat(0) - } else { - val := reflect.ValueOf(value) - switch val.Kind() { - case reflect.Float64: - ind.SetFloat(val.Float()) - default: - v, _ := StrTo(ToStr(value)).Float64() - ind.SetFloat(v) - } - } - - case reflect.Struct: - if value == nil { - ind.Set(reflect.Zero(ind.Type())) - return - } - switch ind.Interface().(type) { - case time.Time: - var str string - switch d := value.(type) { - case time.Time: - o.orm.alias.DbBaser.TimeFromDB(&d, o.orm.alias.TZ) - ind.Set(reflect.ValueOf(d)) - case []byte: - str = string(d) - case string: - str = d - } - if str != "" { - if len(str) >= 19 { - str = str[:19] - t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ) - if err == nil { - t = t.In(DefaultTimeLoc) - ind.Set(reflect.ValueOf(t)) - } - } else if len(str) >= 10 { - str = str[:10] - t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc) - if err == nil { - ind.Set(reflect.ValueOf(t)) - } - } - } - case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool: - indi := reflect.New(ind.Type()).Interface() - sc, ok := indi.(sql.Scanner) - if !ok { - return - } - err := sc.Scan(value) - if err == nil { - ind.Set(reflect.Indirect(reflect.ValueOf(sc))) - } - } - - case reflect.Ptr: - if value == nil { - ind.Set(reflect.Zero(ind.Type())) - break - } - ind.Set(reflect.New(ind.Type().Elem())) - o.setFieldValue(reflect.Indirect(ind), value) - } -} - -// set field value in loop for slice container -func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) { - nInds := *nIndsPtr - - cur := 0 - for i := 0; i < len(sInds); i++ { - sInd := sInds[i] - eTyp := eTyps[i] - - typ := eTyp - isPtr := false - if typ.Kind() == reflect.Ptr { - isPtr = true - typ = typ.Elem() - } - if typ.Kind() == reflect.Ptr { - isPtr = true - typ = typ.Elem() - } - - var nInd reflect.Value - if init { - nInd = reflect.New(sInd.Type()).Elem() - } else { - nInd = nInds[i] - } - - val := reflect.New(typ) - ind := val.Elem() - - tpName := ind.Type().String() - - if ind.Kind() == reflect.Struct { - if tpName == "time.Time" { - value := reflect.ValueOf(refs[cur]).Elem().Interface() - if isPtr && value == nil { - val = reflect.New(val.Type()).Elem() - } else { - o.setFieldValue(ind, value) - } - cur++ - } - - } else { - value := reflect.ValueOf(refs[cur]).Elem().Interface() - if isPtr && value == nil { - val = reflect.New(val.Type()).Elem() - } else { - o.setFieldValue(ind, value) - } - cur++ - } - - if nInd.Kind() == reflect.Slice { - if isPtr { - nInd = reflect.Append(nInd, val) - } else { - nInd = reflect.Append(nInd, ind) - } - } else { - if isPtr { - nInd.Set(val) - } else { - nInd.Set(ind) - } - } - - nInds[i] = nInd - } -} - -// query data and map to container -func (o *rawSet) QueryRow(containers ...interface{}) error { - var ( - refs = make([]interface{}, 0, len(containers)) - sInds []reflect.Value - eTyps []reflect.Type - sMi *modelInfo - ) - structMode := false - for _, container := range containers { - val := reflect.ValueOf(container) - ind := reflect.Indirect(val) - - if val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" all args must be use ptr")) - } - - etyp := ind.Type() - typ := etyp - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - sInds = append(sInds, ind) - eTyps = append(eTyps, etyp) - - if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { - if len(containers) > 1 { - panic(fmt.Errorf(" now support one struct only. see #384")) - } - - structMode = true - fn := getFullName(typ) - if mi, ok := modelCache.getByFullName(fn); ok { - sMi = mi - } - } else { - var ref interface{} - refs = append(refs, &ref) - } - } - - query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) - - args := getFlatParams(nil, o.args, o.orm.alias.TZ) - rows, err := o.orm.db.Query(query, args...) - if err != nil { - if err == sql.ErrNoRows { - return ErrNoRows - } - return err - } - - defer rows.Close() - - if rows.Next() { - if structMode { - columns, err := rows.Columns() - if err != nil { - return err - } - - columnsMp := make(map[string]interface{}, len(columns)) - - refs = make([]interface{}, 0, len(columns)) - for _, col := range columns { - var ref interface{} - columnsMp[col] = &ref - refs = append(refs, &ref) - } - - if err := rows.Scan(refs...); err != nil { - return err - } - - ind := sInds[0] - - if ind.Kind() == reflect.Ptr { - if ind.IsNil() || !ind.IsValid() { - ind.Set(reflect.New(eTyps[0].Elem())) - } - ind = ind.Elem() - } - - if sMi != nil { - for _, col := range columns { - if fi := sMi.fields.GetByColumn(col); fi != nil { - value := reflect.ValueOf(columnsMp[col]).Elem().Interface() - field := ind.FieldByIndex(fi.fieldIndex) - if fi.fieldType&IsRelField > 0 { - mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) - field.Set(mf) - field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) - } - o.setFieldValue(field, value) - } - } - } else { - for i := 0; i < ind.NumField(); i++ { - f := ind.Field(i) - fe := ind.Type().Field(i) - _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) - var col string - if col = tags["column"]; col == "" { - col = nameStrategyMap[nameStrategy](fe.Name) - } - if v, ok := columnsMp[col]; ok { - value := reflect.ValueOf(v).Elem().Interface() - o.setFieldValue(f, value) - } - } - } - - } else { - if err := rows.Scan(refs...); err != nil { - return err - } - - nInds := make([]reflect.Value, len(sInds)) - o.loopSetRefs(refs, sInds, &nInds, eTyps, true) - for i, sInd := range sInds { - nInd := nInds[i] - sInd.Set(nInd) - } - } - - } else { - return ErrNoRows - } - - return nil -} - -// query data rows and map to container -func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { - var ( - refs = make([]interface{}, 0, len(containers)) - sInds []reflect.Value - eTyps []reflect.Type - sMi *modelInfo - ) - structMode := false - for _, container := range containers { - val := reflect.ValueOf(container) - sInd := reflect.Indirect(val) - if val.Kind() != reflect.Ptr || sInd.Kind() != reflect.Slice { - panic(fmt.Errorf(" all args must be use ptr slice")) - } - - etyp := sInd.Type().Elem() - typ := etyp - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - sInds = append(sInds, sInd) - eTyps = append(eTyps, etyp) - - if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { - if len(containers) > 1 { - panic(fmt.Errorf(" now support one struct only. see #384")) - } - - structMode = true - fn := getFullName(typ) - if mi, ok := modelCache.getByFullName(fn); ok { - sMi = mi - } - } else { - var ref interface{} - refs = append(refs, &ref) - } - } - - query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) - - args := getFlatParams(nil, o.args, o.orm.alias.TZ) - rows, err := o.orm.db.Query(query, args...) - if err != nil { - return 0, err - } - - defer rows.Close() - - var cnt int64 - nInds := make([]reflect.Value, len(sInds)) - sInd := sInds[0] - - for rows.Next() { - - if structMode { - columns, err := rows.Columns() - if err != nil { - return 0, err - } - - columnsMp := make(map[string]interface{}, len(columns)) - - refs = make([]interface{}, 0, len(columns)) - for _, col := range columns { - var ref interface{} - columnsMp[col] = &ref - refs = append(refs, &ref) - } - - if err := rows.Scan(refs...); err != nil { - return 0, err - } - - if cnt == 0 && !sInd.IsNil() { - sInd.Set(reflect.New(sInd.Type()).Elem()) - } - - var ind reflect.Value - if eTyps[0].Kind() == reflect.Ptr { - ind = reflect.New(eTyps[0].Elem()) - } else { - ind = reflect.New(eTyps[0]) - } - - if ind.Kind() == reflect.Ptr { - ind = ind.Elem() - } - - if sMi != nil { - for _, col := range columns { - if fi := sMi.fields.GetByColumn(col); fi != nil { - value := reflect.ValueOf(columnsMp[col]).Elem().Interface() - field := ind.FieldByIndex(fi.fieldIndex) - if fi.fieldType&IsRelField > 0 { - mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) - field.Set(mf) - field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) - } - o.setFieldValue(field, value) - } - } - } else { - // define recursive function - var recursiveSetField func(rv reflect.Value) - recursiveSetField = func(rv reflect.Value) { - for i := 0; i < rv.NumField(); i++ { - f := rv.Field(i) - fe := rv.Type().Field(i) - - // check if the field is a Struct - // recursive the Struct type - if fe.Type.Kind() == reflect.Struct { - recursiveSetField(f) - } - - _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) - var col string - if col = tags["column"]; col == "" { - col = nameStrategyMap[nameStrategy](fe.Name) - } - if v, ok := columnsMp[col]; ok { - value := reflect.ValueOf(v).Elem().Interface() - o.setFieldValue(f, value) - } - } - } - - // init call the recursive function - recursiveSetField(ind) - } - - if eTyps[0].Kind() == reflect.Ptr { - ind = ind.Addr() - } - - sInd = reflect.Append(sInd, ind) - - } else { - if err := rows.Scan(refs...); err != nil { - return 0, err - } - - o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0) - } - - cnt++ - } - - if cnt > 0 { - - if structMode { - sInds[0].Set(sInd) - } else { - for i, sInd := range sInds { - nInd := nInds[i] - sInd.Set(nInd) - } - } - } - - return cnt, nil -} - -func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) { - var ( - maps []Params - lists []ParamsList - list ParamsList - ) - - typ := 0 - switch container.(type) { - case *[]Params: - typ = 1 - case *[]ParamsList: - typ = 2 - case *ParamsList: - typ = 3 - default: - panic(fmt.Errorf(" unsupport read values type `%T`", container)) - } - - query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) - - args := getFlatParams(nil, o.args, o.orm.alias.TZ) - - var rs *sql.Rows - rs, err := o.orm.db.Query(query, args...) - if err != nil { - return 0, err - } - - defer rs.Close() - - var ( - refs []interface{} - cnt int64 - cols []string - indexs []int - ) - - for rs.Next() { - if cnt == 0 { - columns, err := rs.Columns() - if err != nil { - return 0, err - } - if len(needCols) > 0 { - indexs = make([]int, 0, len(needCols)) - } else { - indexs = make([]int, 0, len(columns)) - } - - cols = columns - refs = make([]interface{}, len(cols)) - for i := range refs { - var ref sql.NullString - refs[i] = &ref - - if len(needCols) > 0 { - for _, c := range needCols { - if c == cols[i] { - indexs = append(indexs, i) - } - } - } else { - indexs = append(indexs, i) - } - } - } - - if err := rs.Scan(refs...); err != nil { - return 0, err - } - - switch typ { - case 1: - params := make(Params, len(cols)) - for _, i := range indexs { - ref := refs[i] - value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) - if value.Valid { - params[cols[i]] = value.String - } else { - params[cols[i]] = nil - } - } - maps = append(maps, params) - case 2: - params := make(ParamsList, 0, len(cols)) - for _, i := range indexs { - ref := refs[i] - value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) - if value.Valid { - params = append(params, value.String) - } else { - params = append(params, nil) - } - } - lists = append(lists, params) - case 3: - for _, i := range indexs { - ref := refs[i] - value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) - if value.Valid { - list = append(list, value.String) - } else { - list = append(list, nil) - } - } - } - - cnt++ - } - - switch v := container.(type) { - case *[]Params: - *v = maps - case *[]ParamsList: - *v = lists - case *ParamsList: - *v = list - } - - return cnt, nil -} - -func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) { - var ( - maps Params - ind *reflect.Value - ) - - var typ int - switch container.(type) { - case *Params: - typ = 1 - default: - typ = 2 - vl := reflect.ValueOf(container) - id := reflect.Indirect(vl) - if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct { - panic(fmt.Errorf(" RowsTo unsupport type `%T` need ptr struct", container)) - } - - ind = &id - } - - query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) - - args := getFlatParams(nil, o.args, o.orm.alias.TZ) - - rs, err := o.orm.db.Query(query, args...) - if err != nil { - return 0, err - } - - defer rs.Close() - - var ( - refs []interface{} - cnt int64 - cols []string - ) - - var ( - keyIndex = -1 - valueIndex = -1 - ) - - for rs.Next() { - if cnt == 0 { - columns, err := rs.Columns() - if err != nil { - return 0, err - } - cols = columns - refs = make([]interface{}, len(cols)) - for i := range refs { - if keyCol == cols[i] { - keyIndex = i - } - if typ == 1 || keyIndex == i { - var ref sql.NullString - refs[i] = &ref - } else { - var ref interface{} - refs[i] = &ref - } - if valueCol == cols[i] { - valueIndex = i - } - } - if keyIndex == -1 || valueIndex == -1 { - panic(fmt.Errorf(" RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol)) - } - } - - if err := rs.Scan(refs...); err != nil { - return 0, err - } - - if cnt == 0 { - switch typ { - case 1: - maps = make(Params) - } - } - - key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String - - switch typ { - case 1: - value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString) - if value.Valid { - maps[key] = value.String - } else { - maps[key] = nil - } - - default: - if id := ind.FieldByName(camelString(key)); id.IsValid() { - o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface()) - } - } - - cnt++ - } - - if typ == 1 { - v, _ := container.(*Params) - *v = maps - } - - return cnt, nil -} - -// query data to []map[string]interface -func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) { - return o.readValues(container, cols) -} - -// query data to [][]interface -func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) { - return o.readValues(container, cols) -} - -// query data to []interface -func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) { - return o.readValues(container, cols) -} - -// query all rows into map[string]interface with specify key and value column name. -// keyCol = "name", valueCol = "value" -// table data -// name | value -// total | 100 -// found | 200 -// to map[string]interface{}{ -// "total": 100, -// "found": 200, -// } -func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { - return o.queryRowsTo(result, keyCol, valueCol) -} - -// query all rows into struct with specify key and value column name. -// keyCol = "name", valueCol = "value" -// table data -// name | value -// total | 100 -// found | 200 -// to struct { -// Total int -// Found int -// } -func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { - return o.queryRowsTo(ptrStruct, keyCol, valueCol) -} - -// return prepared raw statement for used in times. -func (o *rawSet) Prepare() (RawPreparer, error) { - return newRawPreparer(o) -} - -func newRawSet(orm *orm, query string, args []interface{}) RawSeter { - o := new(rawSet) - o.query = query - o.args = args - o.orm = orm - return o -} diff --git a/orm/orm_test.go b/orm/orm_test.go deleted file mode 100644 index eac7b33a..00000000 --- a/orm/orm_test.go +++ /dev/null @@ -1,2500 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build go1.8 - -package orm - -import ( - "bytes" - "context" - "database/sql" - "fmt" - "io/ioutil" - "math" - "os" - "path/filepath" - "reflect" - "runtime" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -var _ = os.PathSeparator - -var ( - testDate = formatDate + " -0700" - testDateTime = formatDateTime + " -0700" - testTime = formatTime + " -0700" -) - -type argAny []interface{} - -// get interface by index from interface slice -func (a argAny) Get(i int, args ...interface{}) (r interface{}) { - if i >= 0 && i < len(a) { - r = a[i] - } - if len(args) > 0 { - r = args[0] - } - return -} - -func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) { - if len(args) == 0 { - return false, fmt.Errorf("miss args") - } - b := args[0] - arg := argAny(args) - - switch v := a.(type) { - case reflect.Kind: - ok = reflect.ValueOf(b).Kind() == v - case time.Time: - if v2, vo := b.(time.Time); vo { - if arg.Get(1) != nil { - format := ToStr(arg.Get(1)) - a = v.Format(format) - b = v2.Format(format) - ok = a == b - } else { - err = fmt.Errorf("compare datetime miss format") - goto wrongArg - } - } - default: - ok = ToStr(a) == ToStr(b) - } - ok = is && ok || !is && !ok - if !ok { - if is { - err = fmt.Errorf("expected: `%v`, get `%v`", b, a) - } else { - err = fmt.Errorf("expected: `%v`, get `%v`", b, a) - } - } - -wrongArg: - if err != nil { - return false, err - } - - return true, nil -} - -func AssertIs(a interface{}, args ...interface{}) error { - if ok, err := ValuesCompare(true, a, args...); !ok { - return err - } - return nil -} - -func AssertNot(a interface{}, args ...interface{}) error { - if ok, err := ValuesCompare(false, a, args...); !ok { - return err - } - return nil -} - -func getCaller(skip int) string { - pc, file, line, _ := runtime.Caller(skip) - fun := runtime.FuncForPC(pc) - _, fn := filepath.Split(file) - data, err := ioutil.ReadFile(file) - var codes []string - if err == nil { - lines := bytes.Split(data, []byte{'\n'}) - n := 10 - for i := 0; i < n; i++ { - o := line - n - if o < 0 { - continue - } - cur := o + i + 1 - flag := " " - if cur == line { - flag = ">>" - } - code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1)) - if code != "" { - codes = append(codes, code) - } - } - } - funName := fun.Name() - if i := strings.LastIndex(funName, "."); i > -1 { - funName = funName[i+1:] - } - return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) -} - -// Deprecated: Using stretchr/testify/assert -func throwFail(t *testing.T, err error, args ...interface{}) { - if err != nil { - con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) - if len(args) > 0 { - parts := make([]string, 0, len(args)) - for _, arg := range args { - parts = append(parts, fmt.Sprintf("%v", arg)) - } - con += " " + strings.Join(parts, ", ") - } - t.Error(con) - t.Fail() - } -} - -func throwFailNow(t *testing.T, err error, args ...interface{}) { - if err != nil { - con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) - if len(args) > 0 { - parts := make([]string, 0, len(args)) - for _, arg := range args { - parts = append(parts, fmt.Sprintf("%v", arg)) - } - con += " " + strings.Join(parts, ", ") - } - t.Error(con) - t.FailNow() - } -} - -func TestGetDB(t *testing.T) { - if db, err := GetDB(); err != nil { - throwFailNow(t, err) - } else { - err = db.Ping() - throwFailNow(t, err) - } -} - -func TestSyncDb(t *testing.T) { - RegisterModel(new(Data), new(DataNull), new(DataCustom)) - RegisterModel(new(User)) - RegisterModel(new(Profile)) - RegisterModel(new(Post)) - RegisterModel(new(Tag)) - RegisterModel(new(Comment)) - RegisterModel(new(UserBig)) - RegisterModel(new(PostTags)) - RegisterModel(new(Group)) - RegisterModel(new(Permission)) - RegisterModel(new(GroupPermissions)) - RegisterModel(new(InLine)) - RegisterModel(new(InLineOneToOne)) - RegisterModel(new(IntegerPk)) - RegisterModel(new(UintPk)) - RegisterModel(new(PtrPk)) - - err := RunSyncdb("default", true, Debug) - throwFail(t, err) - - modelCache.clean() -} - -func TestRegisterModels(t *testing.T) { - RegisterModel(new(Data), new(DataNull), new(DataCustom)) - RegisterModel(new(User)) - RegisterModel(new(Profile)) - RegisterModel(new(Post)) - RegisterModel(new(Tag)) - RegisterModel(new(Comment)) - RegisterModel(new(UserBig)) - RegisterModel(new(PostTags)) - RegisterModel(new(Group)) - RegisterModel(new(Permission)) - RegisterModel(new(GroupPermissions)) - RegisterModel(new(InLine)) - RegisterModel(new(InLineOneToOne)) - RegisterModel(new(IntegerPk)) - RegisterModel(new(UintPk)) - RegisterModel(new(PtrPk)) - - BootStrap() - - dORM = NewOrm() - dDbBaser = getDbAlias("default").DbBaser -} - -func TestModelSyntax(t *testing.T) { - user := &User{} - ind := reflect.ValueOf(user).Elem() - fn := getFullName(ind.Type()) - mi, ok := modelCache.getByFullName(fn) - throwFail(t, AssertIs(ok, true)) - - mi, ok = modelCache.get("user") - throwFail(t, AssertIs(ok, true)) - if ok { - throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) - } -} - -var DataValues = map[string]interface{}{ - "Boolean": true, - "Char": "char", - "Text": "text", - "JSON": `{"name":"json"}`, - "Jsonb": `{"name": "jsonb"}`, - "Time": time.Now(), - "Date": time.Now(), - "DateTime": time.Now(), - "Byte": byte(1<<8 - 1), - "Rune": rune(1<<31 - 1), - "Int": int(1<<31 - 1), - "Int8": int8(1<<7 - 1), - "Int16": int16(1<<15 - 1), - "Int32": int32(1<<31 - 1), - "Int64": int64(1<<63 - 1), - "Uint": uint(1<<32 - 1), - "Uint8": uint8(1<<8 - 1), - "Uint16": uint16(1<<16 - 1), - "Uint32": uint32(1<<32 - 1), - "Uint64": uint64(1<<63 - 1), // uint64 values with high bit set are not supported - "Float32": float32(100.1234), - "Float64": float64(100.1234), - "Decimal": float64(100.1234), -} - -func TestDataTypes(t *testing.T) { - d := Data{} - ind := reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - if name == "JSON" { - continue - } - e := ind.FieldByName(name) - e.Set(reflect.ValueOf(value)) - } - id, err := dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - d = Data{ID: 1} - err = dORM.Read(&d) - throwFail(t, err) - - ind = reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - e := ind.FieldByName(name) - vu := e.Interface() - switch name { - case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) - case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) - } - throwFail(t, AssertIs(vu == value, true), value, vu) - } -} - -func TestNullDataTypes(t *testing.T) { - d := DataNull{} - - if IsPostgres { - // can removed when this fixed - // https://github.com/lib/pq/pull/125 - d.DateTime = time.Now() - } - - id, err := dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - data := `{"ok":1,"data":{"arr":[1,2],"msg":"gopher"}}` - d = DataNull{ID: 1, JSON: data} - num, err := dORM.Update(&d) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - d = DataNull{ID: 1} - err = dORM.Read(&d) - throwFail(t, err) - - throwFail(t, AssertIs(d.JSON, data)) - - throwFail(t, AssertIs(d.NullBool.Valid, false)) - throwFail(t, AssertIs(d.NullString.Valid, false)) - throwFail(t, AssertIs(d.NullInt64.Valid, false)) - throwFail(t, AssertIs(d.NullFloat64.Valid, false)) - - throwFail(t, AssertIs(d.BooleanPtr, nil)) - throwFail(t, AssertIs(d.CharPtr, nil)) - throwFail(t, AssertIs(d.TextPtr, nil)) - throwFail(t, AssertIs(d.BytePtr, nil)) - throwFail(t, AssertIs(d.RunePtr, nil)) - throwFail(t, AssertIs(d.IntPtr, nil)) - throwFail(t, AssertIs(d.Int8Ptr, nil)) - throwFail(t, AssertIs(d.Int16Ptr, nil)) - throwFail(t, AssertIs(d.Int32Ptr, nil)) - throwFail(t, AssertIs(d.Int64Ptr, nil)) - throwFail(t, AssertIs(d.UintPtr, nil)) - throwFail(t, AssertIs(d.Uint8Ptr, nil)) - throwFail(t, AssertIs(d.Uint16Ptr, nil)) - throwFail(t, AssertIs(d.Uint32Ptr, nil)) - throwFail(t, AssertIs(d.Uint64Ptr, nil)) - throwFail(t, AssertIs(d.Float32Ptr, nil)) - throwFail(t, AssertIs(d.Float64Ptr, nil)) - throwFail(t, AssertIs(d.DecimalPtr, nil)) - throwFail(t, AssertIs(d.TimePtr, nil)) - throwFail(t, AssertIs(d.DatePtr, nil)) - throwFail(t, AssertIs(d.DateTimePtr, nil)) - - _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() - throwFail(t, err) - - d = DataNull{ID: 2} - err = dORM.Read(&d) - throwFail(t, err) - - booleanPtr := true - charPtr := string("test") - textPtr := string("test") - bytePtr := byte('t') - runePtr := rune('t') - intPtr := int(42) - int8Ptr := int8(42) - int16Ptr := int16(42) - int32Ptr := int32(42) - int64Ptr := int64(42) - uintPtr := uint(42) - uint8Ptr := uint8(42) - uint16Ptr := uint16(42) - uint32Ptr := uint32(42) - uint64Ptr := uint64(42) - float32Ptr := float32(42.0) - float64Ptr := float64(42.0) - decimalPtr := float64(42.0) - timePtr := time.Now() - datePtr := time.Now() - dateTimePtr := time.Now() - - d = DataNull{ - DateTime: time.Now(), - NullString: sql.NullString{String: "test", Valid: true}, - NullBool: sql.NullBool{Bool: true, Valid: true}, - NullInt64: sql.NullInt64{Int64: 42, Valid: true}, - NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, - BooleanPtr: &booleanPtr, - CharPtr: &charPtr, - TextPtr: &textPtr, - BytePtr: &bytePtr, - RunePtr: &runePtr, - IntPtr: &intPtr, - Int8Ptr: &int8Ptr, - Int16Ptr: &int16Ptr, - Int32Ptr: &int32Ptr, - Int64Ptr: &int64Ptr, - UintPtr: &uintPtr, - Uint8Ptr: &uint8Ptr, - Uint16Ptr: &uint16Ptr, - Uint32Ptr: &uint32Ptr, - Uint64Ptr: &uint64Ptr, - Float32Ptr: &float32Ptr, - Float64Ptr: &float64Ptr, - DecimalPtr: &decimalPtr, - TimePtr: &timePtr, - DatePtr: &datePtr, - DateTimePtr: &dateTimePtr, - } - - id, err = dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 3)) - - d = DataNull{ID: 3} - err = dORM.Read(&d) - throwFail(t, err) - - throwFail(t, AssertIs(d.NullBool.Valid, true)) - throwFail(t, AssertIs(d.NullBool.Bool, true)) - - throwFail(t, AssertIs(d.NullString.Valid, true)) - throwFail(t, AssertIs(d.NullString.String, "test")) - - throwFail(t, AssertIs(d.NullInt64.Valid, true)) - throwFail(t, AssertIs(d.NullInt64.Int64, 42)) - - throwFail(t, AssertIs(d.NullFloat64.Valid, true)) - throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) - - throwFail(t, AssertIs(*d.BooleanPtr, booleanPtr)) - throwFail(t, AssertIs(*d.CharPtr, charPtr)) - throwFail(t, AssertIs(*d.TextPtr, textPtr)) - throwFail(t, AssertIs(*d.BytePtr, bytePtr)) - throwFail(t, AssertIs(*d.RunePtr, runePtr)) - throwFail(t, AssertIs(*d.IntPtr, intPtr)) - throwFail(t, AssertIs(*d.Int8Ptr, int8Ptr)) - throwFail(t, AssertIs(*d.Int16Ptr, int16Ptr)) - throwFail(t, AssertIs(*d.Int32Ptr, int32Ptr)) - throwFail(t, AssertIs(*d.Int64Ptr, int64Ptr)) - throwFail(t, AssertIs(*d.UintPtr, uintPtr)) - throwFail(t, AssertIs(*d.Uint8Ptr, uint8Ptr)) - throwFail(t, AssertIs(*d.Uint16Ptr, uint16Ptr)) - throwFail(t, AssertIs(*d.Uint32Ptr, uint32Ptr)) - throwFail(t, AssertIs(*d.Uint64Ptr, uint64Ptr)) - throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) - throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) - throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) - - // in mysql, there are some precision problem, (*d.TimePtr).UTC() != timePtr.UTC() - assert.True(t, (*d.TimePtr).UTC().Sub(timePtr.UTC()) <= time.Second) - assert.True(t, (*d.DatePtr).UTC().Sub(datePtr.UTC()) <= time.Second) - assert.True(t, (*d.DateTimePtr).UTC().Sub(dateTimePtr.UTC()) <= time.Second) - - // test support for pointer fields using RawSeter.QueryRows() - var dnList []*DataNull - Q := dDbBaser.TableQuote() - num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - equal := reflect.DeepEqual(*dnList[0], d) - throwFailNow(t, AssertIs(equal, true)) -} - -func TestDataCustomTypes(t *testing.T) { - d := DataCustom{} - ind := reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - e := ind.FieldByName(name) - if !e.IsValid() { - continue - } - e.Set(reflect.ValueOf(value).Convert(e.Type())) - } - - id, err := dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - d = DataCustom{ID: 1} - err = dORM.Read(&d) - throwFail(t, err) - - ind = reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - e := ind.FieldByName(name) - if !e.IsValid() { - continue - } - vu := e.Interface() - value = reflect.ValueOf(value).Convert(e.Type()).Interface() - throwFail(t, AssertIs(vu == value, true), value, vu) - } -} - -func TestCRUD(t *testing.T) { - profile := NewProfile() - profile.Age = 30 - profile.Money = 1234.12 - id, err := dORM.Insert(profile) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - user := NewUser() - user.UserName = "slene" - user.Email = "vslene@gmail.com" - user.Password = "pass" - user.Status = 3 - user.IsStaff = true - user.IsActive = true - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - u := &User{ID: user.ID} - err = dORM.Read(u) - throwFail(t, err) - - throwFail(t, AssertIs(u.UserName, "slene")) - throwFail(t, AssertIs(u.Email, "vslene@gmail.com")) - throwFail(t, AssertIs(u.Password, "pass")) - throwFail(t, AssertIs(u.Status, 3)) - throwFail(t, AssertIs(u.IsStaff, true)) - throwFail(t, AssertIs(u.IsActive, true)) - - assert.True(t, u.Created.In(DefaultTimeLoc).Sub(user.Created.In(DefaultTimeLoc)) <= time.Second) - assert.True(t, u.Updated.In(DefaultTimeLoc).Sub(user.Updated.In(DefaultTimeLoc)) <= time.Second) - - user.UserName = "astaxie" - user.Profile = profile - num, err := dORM.Update(user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: user.ID} - err = dORM.Read(u) - throwFailNow(t, err) - throwFail(t, AssertIs(u.UserName, "astaxie")) - throwFail(t, AssertIs(u.Profile.ID, profile.ID)) - - u = &User{UserName: "astaxie", Password: "pass"} - err = dORM.Read(u, "UserName") - throwFailNow(t, err) - throwFailNow(t, AssertIs(id, 1)) - - u.UserName = "QQ" - u.Password = "111" - num, err = dORM.Update(u, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: user.ID} - err = dORM.Read(u) - throwFailNow(t, err) - throwFail(t, AssertIs(u.UserName, "QQ")) - throwFail(t, AssertIs(u.Password, "pass")) - - num, err = dORM.Delete(profile) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: user.ID} - err = dORM.Read(u) - throwFail(t, err) - throwFail(t, AssertIs(true, u.Profile == nil)) - - num, err = dORM.Delete(user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: 100} - err = dORM.Read(u) - throwFail(t, AssertIs(err, ErrNoRows)) - - ub := UserBig{} - ub.Name = "name" - id, err = dORM.Insert(&ub) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - ub = UserBig{ID: 1} - err = dORM.Read(&ub) - throwFail(t, err) - throwFail(t, AssertIs(ub.Name, "name")) - - num, err = dORM.Delete(&ub, "name") - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestInsertTestData(t *testing.T) { - var users []*User - - profile := NewProfile() - profile.Age = 28 - profile.Money = 1234.12 - - id, err := dORM.Insert(profile) - throwFail(t, err) - throwFail(t, AssertIs(id, 2)) - - user := NewUser() - user.UserName = "slene" - user.Email = "vslene@gmail.com" - user.Password = "pass" - user.Status = 1 - user.IsStaff = false - user.IsActive = true - user.Profile = profile - - users = append(users, user) - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 2)) - - profile = NewProfile() - profile.Age = 30 - profile.Money = 4321.09 - - id, err = dORM.Insert(profile) - throwFail(t, err) - throwFail(t, AssertIs(id, 3)) - - user = NewUser() - user.UserName = "astaxie" - user.Email = "astaxie@gmail.com" - user.Password = "password" - user.Status = 2 - user.IsStaff = true - user.IsActive = false - user.Profile = profile - - users = append(users, user) - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 3)) - - user = NewUser() - user.UserName = "nobody" - user.Email = "nobody@gmail.com" - user.Password = "nobody" - user.Status = 3 - user.IsStaff = false - user.IsActive = false - - users = append(users, user) - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 4)) - - tags := []*Tag{ - {Name: "golang", BestPost: &Post{ID: 2}}, - {Name: "example"}, - {Name: "format"}, - {Name: "c++"}, - } - - posts := []*Post{ - {User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory result—Java programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand. -This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, all of which you should read first.`}, - {User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`}, - {User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide. -With Go we take an unusual approach and let the machine take care of most formatting issues. The gofmt program (also available as go fmt, which operates at the package level rather than source file level) reads a Go program and emits the source in a standard style of indentation and vertical alignment, retaining and if necessary reformatting comments. If you want to know how to handle some new layout situation, run gofmt; if the answer doesn't seem right, rearrange your program (or file a bug about gofmt), don't work around it.`}, - {User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code. -The program—and web server—godoc processes Go source files to extract documentation about the contents of the package. Comments that appear before top-level declarations, with no intervening newlines, are extracted along with the declaration to serve as explanatory text for the item. The nature and style of these comments determines the quality of the documentation godoc produces.`}, - } - - comments := []*Comment{ - {Post: posts[0], Content: "a comment"}, - {Post: posts[1], Content: "yes"}, - {Post: posts[1]}, - {Post: posts[1]}, - {Post: posts[2]}, - {Post: posts[2]}, - } - - for _, tag := range tags { - id, err := dORM.Insert(tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - - for _, post := range posts { - id, err := dORM.Insert(post) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - num := len(post.Tags) - if num > 0 { - nums, err := dORM.QueryM2M(post, "tags").Add(post.Tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(nums, num)) - } - } - - for _, comment := range comments { - id, err := dORM.Insert(comment) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - - permissions := []*Permission{ - {Name: "writePosts"}, - {Name: "readComments"}, - {Name: "readPosts"}, - } - - groups := []*Group{ - { - Name: "admins", - Permissions: []*Permission{permissions[0], permissions[1], permissions[2]}, - }, - { - Name: "users", - Permissions: []*Permission{permissions[1], permissions[2]}, - }, - } - - for _, permission := range permissions { - id, err := dORM.Insert(permission) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - - for _, group := range groups { - _, err := dORM.Insert(group) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - num := len(group.Permissions) - if num > 0 { - nums, err := dORM.QueryM2M(group, "permissions").Add(group.Permissions) - throwFailNow(t, err) - throwFailNow(t, AssertIs(nums, num)) - } - } - -} - -func TestCustomField(t *testing.T) { - user := User{ID: 2} - err := dORM.Read(&user) - throwFailNow(t, err) - - user.Langs = append(user.Langs, "zh-CN", "en-US") - user.Extra.Name = "beego" - user.Extra.Data = "orm" - _, err = dORM.Update(&user, "Langs", "Extra") - throwFailNow(t, err) - - user = User{ID: 2} - err = dORM.Read(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(len(user.Langs), 2)) - throwFailNow(t, AssertIs(user.Langs[0], "zh-CN")) - throwFailNow(t, AssertIs(user.Langs[1], "en-US")) - - throwFailNow(t, AssertIs(user.Extra.Name, "beego")) - throwFailNow(t, AssertIs(user.Extra.Data, "orm")) -} - -func TestExpr(t *testing.T) { - user := &User{} - qs := dORM.QueryTable(user) - qs = dORM.QueryTable((*User)(nil)) - qs = dORM.QueryTable("User") - qs = dORM.QueryTable("user") - num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("created", time.Now()).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - // num, err = qs.Filter("created", time.Now().Format(format_Date)).Count() - // throwFail(t, err) - // throwFail(t, AssertIs(num, 3)) -} - -func TestOperators(t *testing.T) { - qs := dORM.QueryTable("user") - num, err := qs.Filter("user_name", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__exact", String("slene")).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__exact", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__iexact", "Slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__contains", "e").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - var shouldNum int - - if IsSqlite || IsTidb { - shouldNum = 2 - } else { - shouldNum = 0 - } - - num, err = qs.Filter("user_name__contains", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, shouldNum)) - - num, err = qs.Filter("user_name__icontains", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("user_name__icontains", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__gt", 1).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__gte", 1).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - num, err = qs.Filter("status__lt", Uint(3)).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__lte", Int(3)).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - num, err = qs.Filter("user_name__startswith", "s").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - if IsSqlite || IsTidb { - shouldNum = 1 - } else { - shouldNum = 0 - } - - num, err = qs.Filter("user_name__startswith", "S").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, shouldNum)) - - num, err = qs.Filter("user_name__istartswith", "S").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__endswith", "e").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - if IsSqlite || IsTidb { - shouldNum = 2 - } else { - shouldNum = 0 - } - - num, err = qs.Filter("user_name__endswith", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, shouldNum)) - - num, err = qs.Filter("user_name__iendswith", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("profile__isnull", true).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("status__in", 1, 2).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__in", []int{1, 2}).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - n1, n2 := 1, 2 - num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("id__between", 2, 3).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("id__between", []int{2, 3}).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.FilterRaw("user_name", "= 'slene'").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.FilterRaw("status", "IN (1, 2)").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.FilterRaw("profile_id", "IN (SELECT id FROM user_profile WHERE age=30)").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestSetCond(t *testing.T) { - cond := NewCondition() - cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) - - qs := dORM.QueryTable("user") - num, err := qs.SetCond(cond1).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - cond2 := cond.AndCond(cond1).OrCond(cond.And("user_name", "slene")) - num, err = qs.SetCond(cond2).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - cond3 := cond.AndNotCond(cond.And("status__in", 1)) - num, err = qs.SetCond(cond3).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - cond4 := cond.And("user_name", "slene").OrNotCond(cond.And("user_name", "slene")) - num, err = qs.SetCond(cond4).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - cond5 := cond.Raw("user_name", "= 'slene'").OrNotCond(cond.And("user_name", "slene")) - num, err = qs.SetCond(cond5).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) -} - -func TestLimit(t *testing.T) { - var posts []*Post - qs := dORM.QueryTable("post") - num, err := qs.Limit(1).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Limit(-1).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 4)) - - num, err = qs.Limit(-1, 2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Limit(0, 2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) -} - -func TestOffset(t *testing.T) { - var posts []*Post - qs := dORM.QueryTable("post") - num, err := qs.Limit(1).Offset(2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Offset(2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) -} - -func TestOrderBy(t *testing.T) { - qs := dORM.QueryTable("user") - num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.OrderBy("status").Filter("user_name", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestAll(t *testing.T) { - var users []*User - qs := dORM.QueryTable("user") - num, err := qs.OrderBy("Id").All(&users) - throwFail(t, err) - throwFailNow(t, AssertIs(num, 3)) - - throwFail(t, AssertIs(users[0].UserName, "slene")) - throwFail(t, AssertIs(users[1].UserName, "astaxie")) - throwFail(t, AssertIs(users[2].UserName, "nobody")) - - var users2 []User - qs = dORM.QueryTable("user") - num, err = qs.OrderBy("Id").All(&users2) - throwFail(t, err) - throwFailNow(t, AssertIs(num, 3)) - - throwFailNow(t, AssertIs(users2[0].UserName, "slene")) - throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) - throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) - - qs = dORM.QueryTable("user") - num, err = qs.OrderBy("Id").RelatedSel().All(&users2, "UserName") - throwFail(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(len(users2), 3)) - throwFailNow(t, AssertIs(users2[0].UserName, "slene")) - throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) - throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) - throwFailNow(t, AssertIs(users2[0].ID, 0)) - throwFailNow(t, AssertIs(users2[1].ID, 0)) - throwFailNow(t, AssertIs(users2[2].ID, 0)) - throwFailNow(t, AssertIs(users2[0].Profile == nil, false)) - throwFailNow(t, AssertIs(users2[1].Profile == nil, false)) - throwFailNow(t, AssertIs(users2[2].Profile == nil, true)) - - qs = dORM.QueryTable("user") - num, err = qs.Filter("user_name", "nothing").All(&users) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - - var users3 []*User - qs = dORM.QueryTable("user") - num, err = qs.Filter("user_name", "nothing").All(&users3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - throwFailNow(t, AssertIs(users3 == nil, false)) -} - -func TestOne(t *testing.T) { - var user User - qs := dORM.QueryTable("user") - err := qs.One(&user) - throwFail(t, err) - - user = User{} - err = qs.OrderBy("Id").Limit(1).One(&user) - throwFailNow(t, err) - throwFail(t, AssertIs(user.UserName, "slene")) - throwFail(t, AssertNot(err, ErrMultiRows)) - - user = User{} - err = qs.OrderBy("-Id").Limit(100).One(&user) - throwFailNow(t, err) - throwFail(t, AssertIs(user.UserName, "nobody")) - throwFail(t, AssertNot(err, ErrMultiRows)) - - err = qs.Filter("user_name", "nothing").One(&user) - throwFail(t, AssertIs(err, ErrNoRows)) - -} - -func TestValues(t *testing.T) { - var maps []Params - qs := dORM.QueryTable("user") - - num, err := qs.OrderBy("Id").Values(&maps) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(maps[0]["UserName"], "slene")) - throwFail(t, AssertIs(maps[2]["Profile"], nil)) - } - - num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(maps[0]["UserName"], "slene")) - throwFail(t, AssertIs(maps[0]["Profile__Age"], 28)) - throwFail(t, AssertIs(maps[2]["Profile__Age"], nil)) - } - - num, err = qs.Filter("UserName", "slene").Values(&maps) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestValuesList(t *testing.T) { - var list []ParamsList - qs := dORM.QueryTable("user") - - num, err := qs.OrderBy("Id").ValuesList(&list) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0][1], "slene")) - throwFail(t, AssertIs(list[2][9], nil)) - } - - num, err = qs.OrderBy("Id").ValuesList(&list, "UserName", "Profile__Age") - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0][0], "slene")) - throwFail(t, AssertIs(list[0][1], 28)) - throwFail(t, AssertIs(list[2][1], nil)) - } -} - -func TestValuesFlat(t *testing.T) { - var list ParamsList - qs := dORM.QueryTable("user") - - num, err := qs.OrderBy("id").ValuesFlat(&list, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0], "slene")) - throwFail(t, AssertIs(list[1], "astaxie")) - throwFail(t, AssertIs(list[2], "nobody")) - } -} - -func TestRelatedSel(t *testing.T) { - if IsTidb { - // Skip it. TiDB does not support relation now. - return - } - qs := dORM.QueryTable("user") - num, err := qs.Filter("profile__age", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("profile__age__gt", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("profile__user__profile__age__gt", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - var user User - err = qs.Filter("user_name", "slene").RelatedSel("profile").One(&user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertNot(user.Profile, nil)) - if user.Profile != nil { - throwFail(t, AssertIs(user.Profile.Age, 28)) - } - - err = qs.Filter("user_name", "slene").RelatedSel().One(&user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertNot(user.Profile, nil)) - if user.Profile != nil { - throwFail(t, AssertIs(user.Profile.Age, 28)) - } - - err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertIs(user.Profile, nil)) - - qs = dORM.QueryTable("user_profile") - num, err = qs.Filter("user__username", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - var posts []*Post - qs = dORM.QueryTable("post") - num, err = qs.RelatedSel().All(&posts) - throwFail(t, err) - throwFailNow(t, AssertIs(num, 4)) - - throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) - throwFailNow(t, AssertIs(posts[1].User.UserName, "astaxie")) - throwFailNow(t, AssertIs(posts[2].User.UserName, "astaxie")) - throwFailNow(t, AssertIs(posts[3].User.UserName, "nobody")) -} - -func TestReverseQuery(t *testing.T) { - var profile Profile - err := dORM.QueryTable("user_profile").Filter("User", 3).One(&profile) - throwFailNow(t, err) - throwFailNow(t, AssertIs(profile.Age, 30)) - - profile = Profile{} - err = dORM.QueryTable("user_profile").Filter("User__UserName", "astaxie").One(&profile) - throwFailNow(t, err) - throwFailNow(t, AssertIs(profile.Age, 30)) - - var user User - err = dORM.QueryTable("user").Filter("Posts__Title", "Examples").One(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(user.UserName, "astaxie")) - - user = User{} - err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").Limit(1).One(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(user.UserName, "astaxie")) - - user = User{} - err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").RelatedSel().Limit(1).One(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(user.UserName, "astaxie")) - throwFailNow(t, AssertIs(user.Profile == nil, false)) - throwFailNow(t, AssertIs(user.Profile.Age, 30)) - - var posts []*Post - num, err := dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").All(&posts) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) - - posts = []*Post{} - num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").Filter("User__UserName", "slene").All(&posts) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) - - posts = []*Post{} - num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang"). - Filter("User__UserName", "slene").RelatedSel().All(&posts) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(posts[0].User == nil, false)) - throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) - - var tags []*Tag - num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction").All(&tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(tags[0].Name, "golang")) - - tags = []*Tag{} - num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). - Filter("BestPost__User__UserName", "astaxie").All(&tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(tags[0].Name, "golang")) - - tags = []*Tag{} - num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). - Filter("BestPost__User__UserName", "astaxie").RelatedSel().All(&tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(tags[0].Name, "golang")) - throwFailNow(t, AssertIs(tags[0].BestPost == nil, false)) - throwFailNow(t, AssertIs(tags[0].BestPost.Title, "Examples")) - throwFailNow(t, AssertIs(tags[0].BestPost.User == nil, false)) - throwFailNow(t, AssertIs(tags[0].BestPost.User.UserName, "astaxie")) -} - -func TestLoadRelated(t *testing.T) { - // load reverse foreign key - user := User{ID: 3} - - err := dORM.Read(&user) - throwFailNow(t, err) - - num, err := dORM.LoadRelated(&user, "Posts") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(user.Posts), 2)) - throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) - - num, err = dORM.LoadRelated(&user, "Posts", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(user.Posts), 2)) - throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) - - num, err = dORM.LoadRelated(&user, "Posts", true, 1) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(user.Posts), 1)) - - num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(user.Posts), 2)) - throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) - - num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(user.Posts), 1)) - throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) - - // load reverse one to one - profile := Profile{ID: 3} - profile.BestPost = &Post{ID: 2} - num, err = dORM.Update(&profile, "BestPost") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - err = dORM.Read(&profile) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&profile, "User") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(profile.User == nil, false)) - throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) - - num, err = dORM.LoadRelated(&profile, "User", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(profile.User == nil, false)) - throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) - throwFailNow(t, AssertIs(profile.User.Profile.Age, profile.Age)) - - // load rel one to one - err = dORM.Read(&user) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&user, "Profile") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(user.Profile == nil, false)) - throwFailNow(t, AssertIs(user.Profile.Age, 30)) - - num, err = dORM.LoadRelated(&user, "Profile", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(user.Profile == nil, false)) - throwFailNow(t, AssertIs(user.Profile.Age, 30)) - throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false)) - throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples")) - - post := Post{ID: 2} - - // load rel foreign key - err = dORM.Read(&post) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&post, "User") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(post.User == nil, false)) - throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) - - num, err = dORM.LoadRelated(&post, "User", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(post.User == nil, false)) - throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) - throwFailNow(t, AssertIs(post.User.Profile == nil, false)) - throwFailNow(t, AssertIs(post.User.Profile.Age, 30)) - - // load rel m2m - post = Post{ID: 2} - - err = dORM.Read(&post) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&post, "Tags") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(post.Tags), 2)) - throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) - - num, err = dORM.LoadRelated(&post, "Tags", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(post.Tags), 2)) - throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) - throwFailNow(t, AssertIs(post.Tags[0].BestPost == nil, false)) - throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie")) - - // load reverse m2m - tag := Tag{ID: 1} - - err = dORM.Read(&tag) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&tag, "Posts") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) - throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) - throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) - - num, err = dORM.LoadRelated(&tag, "Posts", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) - throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) - throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene")) -} - -func TestQueryM2M(t *testing.T) { - post := Post{ID: 4} - m2m := dORM.QueryM2M(&post, "Tags") - - tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}} - tag2 := &Tag{Name: "TestTag3"} - tag3 := []interface{}{&Tag{Name: "TestTag4"}} - - tags := []interface{}{tag1[0], tag1[1], tag2, tag3[0]} - - for _, tag := range tags { - _, err := dORM.Insert(tag) - throwFailNow(t, err) - } - - num, err := m2m.Add(tag1) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Add(tag2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Add(tag3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 5)) - - num, err = m2m.Remove(tag3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 4)) - - exist := m2m.Exist(tag2) - throwFailNow(t, AssertIs(exist, true)) - - num, err = m2m.Remove(tag2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - exist = m2m.Exist(tag2) - throwFailNow(t, AssertIs(exist, false)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - - num, err = m2m.Clear() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - - tag := Tag{Name: "test"} - _, err = dORM.Insert(&tag) - throwFailNow(t, err) - - m2m = dORM.QueryM2M(&tag, "Posts") - - post1 := []*Post{{Title: "TestPost1"}, {Title: "TestPost2"}} - post2 := &Post{Title: "TestPost3"} - post3 := []interface{}{&Post{Title: "TestPost4"}} - - posts := []interface{}{post1[0], post1[1], post2, post3[0]} - - for _, post := range posts { - p := post.(*Post) - p.User = &User{ID: 1} - _, err := dORM.Insert(post) - throwFailNow(t, err) - } - - num, err = m2m.Add(post1) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Add(post2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Add(post3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 4)) - - num, err = m2m.Remove(post3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - - exist = m2m.Exist(post2) - throwFailNow(t, AssertIs(exist, true)) - - num, err = m2m.Remove(post2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - exist = m2m.Exist(post2) - throwFailNow(t, AssertIs(exist, false)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Clear() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - - num, err = dORM.Delete(&tag) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) -} - -func TestQueryRelate(t *testing.T) { - // post := &Post{Id: 2} - - // qs := dORM.QueryRelate(post, "Tags") - // num, err := qs.Count() - // throwFailNow(t, err) - // throwFailNow(t, AssertIs(num, 2)) - - // var tags []*Tag - // num, err = qs.All(&tags) - // throwFailNow(t, err) - // throwFailNow(t, AssertIs(num, 2)) - // throwFailNow(t, AssertIs(tags[0].Name, "golang")) - - // num, err = dORM.QueryTable("Tag").Filter("Posts__Post", 2).Count() - // throwFailNow(t, err) - // throwFailNow(t, AssertIs(num, 2)) -} - -func TestPkManyRelated(t *testing.T) { - permission := &Permission{Name: "readPosts"} - err := dORM.Read(permission, "Name") - throwFailNow(t, err) - - var groups []*Group - qs := dORM.QueryTable("Group") - num, err := qs.Filter("Permissions__Permission", permission.ID).All(&groups) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) -} - -func TestPrepareInsert(t *testing.T) { - qs := dORM.QueryTable("user") - i, err := qs.PrepareInsert() - throwFailNow(t, err) - - var user User - user.UserName = "testing1" - num, err := i.Insert(&user) - throwFail(t, err) - throwFail(t, AssertIs(num > 0, true)) - - user.UserName = "testing2" - num, err = i.Insert(&user) - throwFail(t, err) - throwFail(t, AssertIs(num > 0, true)) - - num, err = qs.Filter("user_name__in", "testing1", "testing2").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - err = i.Close() - throwFail(t, err) - err = i.Close() - throwFail(t, AssertIs(err, ErrStmtClosed)) -} - -func TestRawExec(t *testing.T) { - Q := dDbBaser.TableQuote() - - query := fmt.Sprintf("UPDATE %suser%s SET %suser_name%s = ? WHERE %suser_name%s = ?", Q, Q, Q, Q, Q, Q) - res, err := dORM.Raw(query, "testing", "slene").Exec() - throwFail(t, err) - num, err := res.RowsAffected() - throwFail(t, AssertIs(num, 1), err) - - res, err = dORM.Raw(query, "slene", "testing").Exec() - throwFail(t, err) - num, err = res.RowsAffected() - throwFail(t, AssertIs(num, 1), err) -} - -func TestRawQueryRow(t *testing.T) { - var ( - Boolean bool - Char string - Text string - Time time.Time - Date time.Time - DateTime time.Time - Byte byte - Rune rune - Int int - Int8 int - Int16 int16 - Int32 int32 - Int64 int64 - Uint uint - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - Float32 float32 - Float64 float64 - Decimal float64 - ) - - dataValues := make(map[string]interface{}, len(DataValues)) - - for k, v := range DataValues { - dataValues[strings.ToLower(k)] = v - } - - Q := dDbBaser.TableQuote() - - cols := []string{ - "id", "boolean", "char", "text", "time", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", - "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", - } - sep := fmt.Sprintf("%s, %s", Q, Q) - query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q) - var id int - values := []interface{}{ - &id, &Boolean, &Char, &Text, &Time, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32, - &Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal, - } - err := dORM.Raw(query, 1).QueryRow(values...) - throwFailNow(t, err) - for i, col := range cols { - vu := values[i] - v := reflect.ValueOf(vu).Elem().Interface() - switch col { - case "id": - throwFail(t, AssertIs(id, 1)) - case "time": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testTime)) - case "date": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testDate)) - case "datetime": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testDateTime)) - default: - throwFail(t, AssertIs(v, dataValues[col])) - } - } - - var ( - uid int - status *int - pid *int - ) - - cols = []string{ - "id", "Status", "profile_id", - } - query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q) - err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid) - throwFail(t, err) - throwFail(t, AssertIs(uid, 4)) - throwFail(t, AssertIs(*status, 3)) - throwFail(t, AssertIs(pid, nil)) - - // test for sql.Null* fields - nData := &DataNull{ - NullString: sql.NullString{String: "test sql.null", Valid: true}, - NullBool: sql.NullBool{Bool: true, Valid: true}, - NullInt64: sql.NullInt64{Int64: 42, Valid: true}, - NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, - } - newId, err := dORM.Insert(nData) - throwFailNow(t, err) - - var nd *DataNull - query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) - err = dORM.Raw(query, newId).QueryRow(&nd) - throwFailNow(t, err) - - throwFailNow(t, AssertNot(nd, nil)) - throwFail(t, AssertIs(nd.NullBool.Valid, true)) - throwFail(t, AssertIs(nd.NullBool.Bool, true)) - throwFail(t, AssertIs(nd.NullString.Valid, true)) - throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) - throwFail(t, AssertIs(nd.NullInt64.Valid, true)) - throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) - throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) - throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) -} - -// user_profile table -type userProfile struct { - User - Age int - Money float64 -} - -func TestQueryRows(t *testing.T) { - Q := dDbBaser.TableQuote() - - var datas []*Data - - query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) - num, err := dORM.Raw(query).QueryRows(&datas) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(datas), 1)) - - ind := reflect.Indirect(reflect.ValueOf(datas[0])) - - for name, value := range DataValues { - e := ind.FieldByName(name) - vu := e.Interface() - switch name { - case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) - case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) - case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - } - throwFail(t, AssertIs(vu == value, true), value, vu) - } - - var datas2 []Data - - query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) - num, err = dORM.Raw(query).QueryRows(&datas2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(datas2), 1)) - - ind = reflect.Indirect(reflect.ValueOf(datas2[0])) - - for name, value := range DataValues { - e := ind.FieldByName(name) - vu := e.Interface() - switch name { - case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) - case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) - case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - } - throwFail(t, AssertIs(vu == value, true), value, vu) - } - - var ids []int - var usernames []string - query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q) - num, err = dORM.Raw(query).QueryRows(&ids, &usernames) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(len(ids), 3)) - throwFailNow(t, AssertIs(ids[0], 2)) - throwFailNow(t, AssertIs(usernames[0], "slene")) - throwFailNow(t, AssertIs(ids[1], 3)) - throwFailNow(t, AssertIs(usernames[1], "astaxie")) - throwFailNow(t, AssertIs(ids[2], 4)) - throwFailNow(t, AssertIs(usernames[2], "nobody")) - - // test query rows by nested struct - var l []userProfile - query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q) - num, err = dORM.Raw(query).QueryRows(&l) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(l), 2)) - throwFailNow(t, AssertIs(l[0].UserName, "slene")) - throwFailNow(t, AssertIs(l[0].Age, 28)) - throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) - throwFailNow(t, AssertIs(l[1].Age, 30)) - - // test for sql.Null* fields - nData := &DataNull{ - NullString: sql.NullString{String: "test sql.null", Valid: true}, - NullBool: sql.NullBool{Bool: true, Valid: true}, - NullInt64: sql.NullInt64{Int64: 42, Valid: true}, - NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, - } - newId, err := dORM.Insert(nData) - throwFailNow(t, err) - - var nDataList []*DataNull - query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) - num, err = dORM.Raw(query, newId).QueryRows(&nDataList) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - nd := nDataList[0] - throwFailNow(t, AssertNot(nd, nil)) - throwFail(t, AssertIs(nd.NullBool.Valid, true)) - throwFail(t, AssertIs(nd.NullBool.Bool, true)) - throwFail(t, AssertIs(nd.NullString.Valid, true)) - throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) - throwFail(t, AssertIs(nd.NullInt64.Valid, true)) - throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) - throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) - throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) -} - -func TestRawValues(t *testing.T) { - Q := dDbBaser.TableQuote() - - var maps []Params - query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sStatus%s = ?", Q, Q, Q, Q, Q, Q) - num, err := dORM.Raw(query, 1).Values(&maps) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - if num == 1 { - throwFail(t, AssertIs(maps[0]["user_name"], "slene")) - } - - var lists []ParamsList - num, err = dORM.Raw(query, 1).ValuesList(&lists) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - if num == 1 { - throwFail(t, AssertIs(lists[0][0], "slene")) - } - - query = fmt.Sprintf("SELECT %sprofile_id%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q) - var list ParamsList - num, err = dORM.Raw(query).ValuesFlat(&list) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0], "2")) - throwFail(t, AssertIs(list[1], "3")) - throwFail(t, AssertIs(list[2], nil)) - } -} - -func TestRawPrepare(t *testing.T) { - switch { - case IsMysql || IsSqlite: - - pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() - throwFail(t, err) - if pre != nil { - r, err := pre.Exec("name1") - throwFail(t, err) - - tid, err := r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(tid > 0, true)) - - r, err = pre.Exec("name2") - throwFail(t, err) - - id, err := r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(id, tid+1)) - - r, err = pre.Exec("name3") - throwFail(t, err) - - id, err = r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(id, tid+2)) - - err = pre.Close() - throwFail(t, err) - - res, err := dORM.Raw("DELETE FROM tag WHERE name IN (?, ?, ?)", []string{"name1", "name2", "name3"}).Exec() - throwFail(t, err) - - num, err := res.RowsAffected() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - } - - case IsPostgres: - - pre, err := dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() - throwFail(t, err) - if pre != nil { - _, err := pre.Exec("name1") - throwFail(t, err) - - _, err = pre.Exec("name2") - throwFail(t, err) - - _, err = pre.Exec("name3") - throwFail(t, err) - - err = pre.Close() - throwFail(t, err) - - res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec() - throwFail(t, err) - - if err == nil { - num, err := res.RowsAffected() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - } - } - } -} - -func TestUpdate(t *testing.T) { - qs := dORM.QueryTable("user") - num, err := qs.Filter("user_name", "slene").Filter("is_staff", false).Update(Params{ - "is_staff": true, - "is_active": true, - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - // with join - num, err = qs.Filter("user_name", "slene").Filter("profile__age", 28).Filter("is_staff", true).Update(Params{ - "is_staff": false, - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColAdd, 100), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColMinus, 50), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColMultiply, 3), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColExcept, 5), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - user := User{UserName: "slene"} - err = dORM.Read(&user, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(user.Nums, 30)) -} - -func TestDelete(t *testing.T) { - qs := dORM.QueryTable("user_profile") - num, err := qs.Filter("user__user_name", "slene").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - qs = dORM.QueryTable("user") - num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - qs = dORM.QueryTable("comment") - num, err = qs.Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 6)) - - qs = dORM.QueryTable("post") - num, err = qs.Filter("Id", 3).Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - qs = dORM.QueryTable("comment") - num, err = qs.Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 4)) - - qs = dORM.QueryTable("comment") - num, err = qs.Filter("Post__User", 3).Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - qs = dORM.QueryTable("comment") - num, err = qs.Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestTransaction(t *testing.T) { - // this test worked when database support transaction - - o := NewOrm() - err := o.Begin() - throwFail(t, err) - - var names = []string{"1", "2", "3"} - - var tag Tag - tag.Name = names[0] - id, err := o.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - switch { - case IsMysql || IsSqlite: - res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() - throwFail(t, err) - if err == nil { - id, err = res.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - } - - err = o.Rollback() - throwFail(t, err) - - num, err = o.QueryTable("tag").Filter("name__in", names).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) - - err = o.Begin() - throwFail(t, err) - - tag.Name = "commit" - id, err = o.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - o.Commit() - throwFail(t, err) - - num, err = o.QueryTable("tag").Filter("name", "commit").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - -} - -func TestTransactionIsolationLevel(t *testing.T) { - // this test worked when database support transaction isolation level - if IsSqlite { - return - } - - o1 := NewOrm() - o2 := NewOrm() - - // start two transaction with isolation level repeatable read - err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - throwFail(t, err) - err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - throwFail(t, err) - - // o1 insert tag - var tag Tag - tag.Name = "test-transaction" - id, err := o1.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - // o2 query tag table, no result - num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) - - // o1 commit - o1.Commit() - - // o2 query tag table, still no result - num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) - - // o2 commit and query tag table, get the result - o2.Commit() - num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = o1.QueryTable("tag").Filter("name", "test-transaction").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestBeginTxWithContextCanceled(t *testing.T) { - o := NewOrm() - ctx, cancel := context.WithCancel(context.Background()) - o.BeginTx(ctx, nil) - id, err := o.Insert(&Tag{Name: "test-context"}) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - // cancel the context before commit to make it error - cancel() - err = o.Commit() - throwFail(t, AssertIs(err, context.Canceled)) -} - -func TestReadOrCreate(t *testing.T) { - u := &User{ - UserName: "Kyle", - Email: "kylemcc@gmail.com", - Password: "other_pass", - Status: 7, - IsStaff: false, - IsActive: true, - } - - created, pk, err := dORM.ReadOrCreate(u, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(created, true)) - throwFail(t, AssertIs(u.ID, pk)) - throwFail(t, AssertIs(u.UserName, "Kyle")) - throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com")) - throwFail(t, AssertIs(u.Password, "other_pass")) - throwFail(t, AssertIs(u.Status, 7)) - throwFail(t, AssertIs(u.IsStaff, false)) - throwFail(t, AssertIs(u.IsActive, true)) - throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime)) - - nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"} - created, pk, err = dORM.ReadOrCreate(nu, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(nu.ID, u.ID)) - throwFail(t, AssertIs(pk, u.ID)) - throwFail(t, AssertIs(nu.UserName, u.UserName)) - throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above - throwFail(t, AssertIs(nu.Password, u.Password)) - throwFail(t, AssertIs(nu.Status, u.Status)) - throwFail(t, AssertIs(nu.IsStaff, u.IsStaff)) - throwFail(t, AssertIs(nu.IsActive, u.IsActive)) - - dORM.Delete(u) -} - -func TestInLine(t *testing.T) { - name := "inline" - email := "hello@go.com" - inline := NewInLine() - inline.Name = name - inline.Email = email - - id, err := dORM.Insert(inline) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - il := NewInLine() - il.ID = 1 - err = dORM.Read(il) - throwFail(t, err) - - throwFail(t, AssertIs(il.Name, name)) - throwFail(t, AssertIs(il.Email, email)) - throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) -} - -func TestInLineOneToOne(t *testing.T) { - name := "121" - email := "121@go.com" - inline := NewInLine() - inline.Name = name - inline.Email = email - - id, err := dORM.Insert(inline) - throwFail(t, err) - throwFail(t, AssertIs(id, 2)) - - note := "one2one" - il121 := NewInLineOneToOne() - il121.Note = note - il121.InLine = inline - _, err = dORM.Insert(il121) - throwFail(t, err) - throwFail(t, AssertIs(il121.ID, 1)) - - il := NewInLineOneToOne() - err = dORM.QueryTable(il).Filter("Id", 1).RelatedSel().One(il) - - throwFail(t, err) - throwFail(t, AssertIs(il.Note, note)) - throwFail(t, AssertIs(il.InLine.ID, id)) - throwFail(t, AssertIs(il.InLine.Name, name)) - throwFail(t, AssertIs(il.InLine.Email, email)) - - rinline := NewInLine() - err = dORM.QueryTable(rinline).Filter("InLineOneToOne__Id", 1).One(rinline) - - throwFail(t, err) - throwFail(t, AssertIs(rinline.ID, id)) - throwFail(t, AssertIs(rinline.Name, name)) - throwFail(t, AssertIs(rinline.Email, email)) -} - -func TestIntegerPk(t *testing.T) { - its := []IntegerPk{ - {ID: math.MinInt64, Value: "-"}, - {ID: 0, Value: "0"}, - {ID: math.MaxInt64, Value: "+"}, - } - - num, err := dORM.InsertMulti(len(its), its) - throwFail(t, err) - throwFail(t, AssertIs(num, len(its))) - - for _, intPk := range its { - out := IntegerPk{ID: intPk.ID} - err = dORM.Read(&out) - throwFail(t, err) - throwFail(t, AssertIs(out.Value, intPk.Value)) - } - - num, err = dORM.InsertMulti(1, []*IntegerPk{{ - ID: 1, Value: "ok", - }}) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestInsertAuto(t *testing.T) { - u := &User{ - UserName: "autoPre", - Email: "autoPre@gmail.com", - } - - id, err := dORM.Insert(u) - throwFail(t, err) - - id += 100 - su := &User{ - ID: int(id), - UserName: "auto", - Email: "auto@gmail.com", - } - - nid, err := dORM.Insert(su) - throwFail(t, err) - throwFail(t, AssertIs(nid, id)) - - users := []User{ - {ID: int(id + 100), UserName: "auto_100"}, - {ID: int(id + 110), UserName: "auto_110"}, - {ID: int(id + 120), UserName: "auto_120"}, - } - num, err := dORM.InsertMulti(100, users) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - u = &User{ - UserName: "auto_121", - } - - nid, err = dORM.Insert(u) - throwFail(t, err) - throwFail(t, AssertIs(nid, id+120+1)) -} - -func TestUintPk(t *testing.T) { - name := "go" - u := &UintPk{ - ID: 8, - Name: name, - } - - created, _, err := dORM.ReadOrCreate(u, "ID") - throwFail(t, err) - throwFail(t, AssertIs(created, true)) - throwFail(t, AssertIs(u.Name, name)) - - nu := &UintPk{ID: 8} - created, pk, err := dORM.ReadOrCreate(nu, "ID") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(nu.ID, u.ID)) - throwFail(t, AssertIs(pk, u.ID)) - throwFail(t, AssertIs(nu.Name, name)) - - dORM.Delete(u) -} - -func TestPtrPk(t *testing.T) { - parent := &IntegerPk{ID: 10, Value: "10"} - - id, _ := dORM.Insert(parent) - if !IsMysql { - // MySql does not support last_insert_id in this case: see #2382 - throwFail(t, AssertIs(id, 10)) - } - - ptr := PtrPk{ID: parent, Positive: true} - num, err := dORM.InsertMulti(2, []PtrPk{ptr}) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertIs(ptr.ID, parent)) - - nptr := &PtrPk{ID: parent} - created, pk, err := dORM.ReadOrCreate(nptr, "ID") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(pk, 10)) - throwFail(t, AssertIs(nptr.ID, parent)) - throwFail(t, AssertIs(nptr.Positive, true)) - - nptr = &PtrPk{Positive: true} - created, pk, err = dORM.ReadOrCreate(nptr, "Positive") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(pk, 10)) - throwFail(t, AssertIs(nptr.ID, parent)) - - nptr.Positive = false - num, err = dORM.Update(nptr) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertIs(nptr.ID, parent)) - throwFail(t, AssertIs(nptr.Positive, false)) - - num, err = dORM.Delete(nptr) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestSnake(t *testing.T) { - cases := map[string]string{ - "i": "i", - "I": "i", - "iD": "i_d", - "ID": "i_d", - "NO": "n_o", - "NOO": "n_o_o", - "NOOooOOoo": "n_o_ooo_o_ooo", - "OrderNO": "order_n_o", - "tagName": "tag_name", - "tag_Name": "tag__name", - "tag_name": "tag_name", - "_tag_name": "_tag_name", - "tag_666name": "tag_666name", - "tag_666Name": "tag_666_name", - } - for name, want := range cases { - got := snakeString(name) - throwFail(t, AssertIs(got, want)) - } -} - -func TestIgnoreCaseTag(t *testing.T) { - type testTagModel struct { - ID int `orm:"pk"` - NOO string `orm:"column(n)"` - Name01 string `orm:"NULL"` - Name02 string `orm:"COLUMN(Name)"` - Name03 string `orm:"Column(name)"` - } - modelCache.clean() - RegisterModel(&testTagModel{}) - info, ok := modelCache.get("test_tag_model") - throwFail(t, AssertIs(ok, true)) - throwFail(t, AssertNot(info, nil)) - if t == nil { - return - } - throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n")) - throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true)) - throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) - throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) -} - -func TestInsertOrUpdate(t *testing.T) { - RegisterModel(new(User)) - user := User{UserName: "unique_username133", Status: 1, Password: "o"} - user1 := User{UserName: "unique_username133", Status: 2, Password: "o"} - user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"} - dORM.Insert(&user) - test := User{UserName: "unique_username133"} - fmt.Println(dORM.Driver().Name()) - if dORM.Driver().Name() == "sqlite3" { - fmt.Println("sqlite3 is nonsupport") - return - } - // test1 - _, err := dORM.InsertOrUpdate(&user1, "user_name") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user1.Status, test.Status)) - } - // test2 - _, err = dORM.InsertOrUpdate(&user2, "user_name") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user2.Status, test.Status)) - throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) - } - - // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values - if IsPostgres { - return - } - // test3 + - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user2.Status+1, test.Status)) - } - // test4 - - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status)) - } - // test5 * - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status)) - } - // test6 / - _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) - } -} diff --git a/orm/qb.go b/orm/qb.go deleted file mode 100644 index e0655a17..00000000 --- a/orm/qb.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import "errors" - -// QueryBuilder is the Query builder interface -type QueryBuilder interface { - Select(fields ...string) QueryBuilder - ForUpdate() QueryBuilder - From(tables ...string) QueryBuilder - InnerJoin(table string) QueryBuilder - LeftJoin(table string) QueryBuilder - RightJoin(table string) QueryBuilder - On(cond string) QueryBuilder - Where(cond string) QueryBuilder - And(cond string) QueryBuilder - Or(cond string) QueryBuilder - In(vals ...string) QueryBuilder - OrderBy(fields ...string) QueryBuilder - Asc() QueryBuilder - Desc() QueryBuilder - Limit(limit int) QueryBuilder - Offset(offset int) QueryBuilder - GroupBy(fields ...string) QueryBuilder - Having(cond string) QueryBuilder - Update(tables ...string) QueryBuilder - Set(kv ...string) QueryBuilder - Delete(tables ...string) QueryBuilder - InsertInto(table string, fields ...string) QueryBuilder - Values(vals ...string) QueryBuilder - Subquery(sub string, alias string) string - String() string -} - -// NewQueryBuilder return the QueryBuilder -func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { - if driver == "mysql" { - qb = new(MySQLQueryBuilder) - } else if driver == "tidb" { - qb = new(TiDBQueryBuilder) - } else if driver == "postgres" { - err = errors.New("postgres query builder is not supported yet") - } else if driver == "sqlite" { - err = errors.New("sqlite query builder is not supported yet") - } else { - err = errors.New("unknown driver for query builder") - } - return -} diff --git a/orm/qb_mysql.go b/orm/qb_mysql.go deleted file mode 100644 index 23bdc9ee..00000000 --- a/orm/qb_mysql.go +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strconv" - "strings" -) - -// CommaSpace is the separation -const CommaSpace = ", " - -// MySQLQueryBuilder is the SQL build -type MySQLQueryBuilder struct { - Tokens []string -} - -// Select will join the fields -func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) - return qb -} - -// ForUpdate add the FOR UPDATE clause -func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder { - qb.Tokens = append(qb.Tokens, "FOR UPDATE") - return qb -} - -// From join the tables -func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) - return qb -} - -// InnerJoin INNER JOIN the table -func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INNER JOIN", table) - return qb -} - -// LeftJoin LEFT JOIN the table -func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) - return qb -} - -// RightJoin RIGHT JOIN the table -func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) - return qb -} - -// On join with on cond -func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ON", cond) - return qb -} - -// Where join the Where cond -func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "WHERE", cond) - return qb -} - -// And join the and cond -func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "AND", cond) - return qb -} - -// Or join the or cond -func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OR", cond) - return qb -} - -// In join the IN (vals) -func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") - return qb -} - -// OrderBy join the Order by fields -func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) - return qb -} - -// Asc join the asc -func (qb *MySQLQueryBuilder) Asc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "ASC") - return qb -} - -// Desc join the desc -func (qb *MySQLQueryBuilder) Desc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "DESC") - return qb -} - -// Limit join the limit num -func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) - return qb -} - -// Offset join the offset num -func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) - return qb -} - -// GroupBy join the Group by fields -func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) - return qb -} - -// Having join the Having cond -func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "HAVING", cond) - return qb -} - -// Update join the update table -func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) - return qb -} - -// Set join the set kv -func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) - return qb -} - -// Delete join the Delete tables -func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "DELETE") - if len(tables) != 0 { - qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) - } - return qb -} - -// InsertInto join the insert SQL -func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INSERT INTO", table) - if len(fields) != 0 { - fieldsStr := strings.Join(fields, CommaSpace) - qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") - } - return qb -} - -// Values join the Values(vals) -func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { - valsStr := strings.Join(vals, CommaSpace) - qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") - return qb -} - -// Subquery join the sub as alias -func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { - return fmt.Sprintf("(%s) AS %s", sub, alias) -} - -// String join all Tokens -func (qb *MySQLQueryBuilder) String() string { - return strings.Join(qb.Tokens, " ") -} diff --git a/orm/qb_tidb.go b/orm/qb_tidb.go deleted file mode 100644 index 87b3ae84..00000000 --- a/orm/qb_tidb.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2015 TiDB Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strconv" - "strings" -) - -// TiDBQueryBuilder is the SQL build -type TiDBQueryBuilder struct { - Tokens []string -} - -// Select will join the fields -func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) - return qb -} - -// ForUpdate add the FOR UPDATE clause -func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder { - qb.Tokens = append(qb.Tokens, "FOR UPDATE") - return qb -} - -// From join the tables -func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) - return qb -} - -// InnerJoin INNER JOIN the table -func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INNER JOIN", table) - return qb -} - -// LeftJoin LEFT JOIN the table -func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) - return qb -} - -// RightJoin RIGHT JOIN the table -func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) - return qb -} - -// On join with on cond -func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ON", cond) - return qb -} - -// Where join the Where cond -func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "WHERE", cond) - return qb -} - -// And join the and cond -func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "AND", cond) - return qb -} - -// Or join the or cond -func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OR", cond) - return qb -} - -// In join the IN (vals) -func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") - return qb -} - -// OrderBy join the Order by fields -func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) - return qb -} - -// Asc join the asc -func (qb *TiDBQueryBuilder) Asc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "ASC") - return qb -} - -// Desc join the desc -func (qb *TiDBQueryBuilder) Desc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "DESC") - return qb -} - -// Limit join the limit num -func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) - return qb -} - -// Offset join the offset num -func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) - return qb -} - -// GroupBy join the Group by fields -func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) - return qb -} - -// Having join the Having cond -func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "HAVING", cond) - return qb -} - -// Update join the update table -func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) - return qb -} - -// Set join the set kv -func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) - return qb -} - -// Delete join the Delete tables -func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "DELETE") - if len(tables) != 0 { - qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) - } - return qb -} - -// InsertInto join the insert SQL -func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INSERT INTO", table) - if len(fields) != 0 { - fieldsStr := strings.Join(fields, CommaSpace) - qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") - } - return qb -} - -// Values join the Values(vals) -func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder { - valsStr := strings.Join(vals, CommaSpace) - qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") - return qb -} - -// Subquery join the sub as alias -func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string { - return fmt.Sprintf("(%s) AS %s", sub, alias) -} - -// String join all Tokens -func (qb *TiDBQueryBuilder) String() string { - return strings.Join(qb.Tokens, " ") -} diff --git a/orm/types.go b/orm/types.go deleted file mode 100644 index 75af7149..00000000 --- a/orm/types.go +++ /dev/null @@ -1,474 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "context" - "database/sql" - "reflect" - "time" -) - -// Driver define database driver - -type Driver interface { - Name() string - Type() DriverType -} - -// Fielder define field info -type Fielder interface { - String() string - FieldType() int - SetRaw(interface{}) error - RawValue() interface{} -} - -// Ormer define the orm interface -type Ormer interface { - // read data to model - // for example: - // this will find User by Id field - // u = &User{Id: user.Id} - // err = Ormer.Read(u) - // this will find User by UserName field - // u = &User{UserName: "astaxie", Password: "pass"} - // err = Ormer.Read(u, "UserName") - Read(md interface{}, cols ...string) error - // Like Read(), but with "FOR UPDATE" clause, useful in transaction. - // Some databases are not support this feature. - ReadForUpdate(md interface{}, cols ...string) error - // Try to read a row from the database, or insert one if it doesn't exist - ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) - // insert model data to database - // for example: - // user := new(User) - // id, err = Ormer.Insert(user) - // user must be a pointer and Insert will set user's pk field - Insert(interface{}) (int64, error) - // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") - // if colu type is integer : can use(+-*/), string : convert(colu,"value") - // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") - // if colu type is integer : can use(+-*/), string : colu || "value" - InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) - // insert some models to database - InsertMulti(bulk int, mds interface{}) (int64, error) - // update model to database. - // cols set the columns those want to update. - // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns - // for example: - // user := User{Id: 2} - // user.Langs = append(user.Langs, "zh-CN", "en-US") - // user.Extra.Name = "beego" - // user.Extra.Data = "orm" - // num, err = Ormer.Update(&user, "Langs", "Extra") - Update(md interface{}, cols ...string) (int64, error) - // delete model in database - Delete(md interface{}, cols ...string) (int64, error) - // load related models to md model. - // args are limit, offset int and order string. - // - // example: - // Ormer.LoadRelated(post,"Tags") - // for _,tag := range post.Tags{...} - //args[0] bool true useDefaultRelsDepth ; false depth 0 - //args[0] int loadRelationDepth - //args[1] int limit default limit 1000 - //args[2] int offset default offset 0 - //args[3] string order for example : "-Id" - // make sure the relation is defined in model struct tags. - LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) - // create a models to models queryer - // for example: - // post := Post{Id: 4} - // m2m := Ormer.QueryM2M(&post, "Tags") - QueryM2M(md interface{}, name string) QueryM2Mer - // return a QuerySeter for table operations. - // table name can be string or struct. - // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), - QueryTable(ptrStructOrTableName interface{}) QuerySeter - // switch to another registered database driver by given name. - Using(name string) error - // begin transaction - // for example: - // o := NewOrm() - // err := o.Begin() - // ... - // err = o.Rollback() - Begin() error - // begin transaction with provided context and option - // the provided context is used until the transaction is committed or rolled back. - // if the context is canceled, the transaction will be rolled back. - // the provided TxOptions is optional and may be nil if defaults should be used. - // if a non-default isolation level is used that the driver doesn't support, an error will be returned. - // for example: - // o := NewOrm() - // err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - // ... - // err = o.Rollback() - BeginTx(ctx context.Context, opts *sql.TxOptions) error - // commit transaction - Commit() error - // rollback transaction - Rollback() error - // return a raw query seter for raw sql string. - // for example: - // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() - // // update user testing's name to slene - Raw(query string, args ...interface{}) RawSeter - Driver() Driver - DBStats() *sql.DBStats -} - -// Inserter insert prepared statement -type Inserter interface { - Insert(interface{}) (int64, error) - Close() error -} - -// QuerySeter query seter -type QuerySeter interface { - // add condition expression to QuerySeter. - // for example: - // filter by UserName == 'slene' - // qs.Filter("UserName", "slene") - // sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28 - // Filter("profile__Age", 28) - // // time compare - // qs.Filter("created", time.Now()) - Filter(string, ...interface{}) QuerySeter - // add raw sql to querySeter. - // for example: - // qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)") - // //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18) - FilterRaw(string, string) QuerySeter - // add NOT condition to querySeter. - // have the same usage as Filter - Exclude(string, ...interface{}) QuerySeter - // set condition to QuerySeter. - // sql's where condition - // cond := orm.NewCondition() - // cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) - // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 - // num, err := qs.SetCond(cond1).Count() - SetCond(*Condition) QuerySeter - // get condition from QuerySeter. - // sql's where condition - // cond := orm.NewCondition() - // cond = cond.And("profile__isnull", false).AndNot("status__in", 1) - // qs = qs.SetCond(cond) - // cond = qs.GetCond() - // cond := cond.Or("profile__age__gt", 2000) - // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 - // num, err := qs.SetCond(cond).Count() - GetCond() *Condition - // add LIMIT value. - // args[0] means offset, e.g. LIMIT num,offset. - // if Limit <= 0 then Limit will be set to default limit ,eg 1000 - // if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000 - // for example: - // qs.Limit(10, 2) - // // sql-> limit 10 offset 2 - Limit(limit interface{}, args ...interface{}) QuerySeter - // add OFFSET value - // same as Limit function's args[0] - Offset(offset interface{}) QuerySeter - // add GROUP BY expression - // for example: - // qs.GroupBy("id") - GroupBy(exprs ...string) QuerySeter - // add ORDER expression. - // "column" means ASC, "-column" means DESC. - // for example: - // qs.OrderBy("-status") - OrderBy(exprs ...string) QuerySeter - // set relation model to query together. - // it will query relation models and assign to parent model. - // for example: - // // will load all related fields use left join . - // qs.RelatedSel().One(&user) - // // will load related field only profile - // qs.RelatedSel("profile").One(&user) - // user.Profile.Age = 32 - RelatedSel(params ...interface{}) QuerySeter - // Set Distinct - // for example: - // o.QueryTable("policy").Filter("Groups__Group__Users__User", user). - // Distinct(). - // All(&permissions) - Distinct() QuerySeter - // set FOR UPDATE to query. - // for example: - // o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users) - ForUpdate() QuerySeter - // return QuerySeter execution result number - // for example: - // num, err = qs.Filter("profile__age__gt", 28).Count() - Count() (int64, error) - // check result empty or not after QuerySeter executed - // the same as QuerySeter.Count > 0 - Exist() bool - // execute update with parameters - // for example: - // num, err = qs.Filter("user_name", "slene").Update(Params{ - // "Nums": ColValue(Col_Minus, 50), - // }) // user slene's Nums will minus 50 - // num, err = qs.Filter("UserName", "slene").Update(Params{ - // "user_name": "slene2" - // }) // user slene's name will change to slene2 - Update(values Params) (int64, error) - // delete from table - //for example: - // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() - // //delete two user who's name is testing1 or testing2 - Delete() (int64, error) - // return a insert queryer. - // it can be used in times. - // example: - // i,err := sq.PrepareInsert() - // num, err = i.Insert(&user1) // user table will add one record user1 at once - // num, err = i.Insert(&user2) // user table will add one record user2 at once - // err = i.Close() //don't forget call Close - PrepareInsert() (Inserter, error) - // query all data and map to containers. - // cols means the columns when querying. - // for example: - // var users []*User - // qs.All(&users) // users[0],users[1],users[2] ... - All(container interface{}, cols ...string) (int64, error) - // query one row data and map to containers. - // cols means the columns when querying. - // for example: - // var user User - // qs.One(&user) //user.UserName == "slene" - One(container interface{}, cols ...string) error - // query all data and map to []map[string]interface. - // expres means condition expression. - // it converts data to []map[column]value. - // for example: - // var maps []Params - // qs.Values(&maps) //maps[0]["UserName"]=="slene" - Values(results *[]Params, exprs ...string) (int64, error) - // query all data and map to [][]interface - // it converts data to [][column_index]value - // for example: - // var list []ParamsList - // qs.ValuesList(&list) // list[0][1] == "slene" - ValuesList(results *[]ParamsList, exprs ...string) (int64, error) - // query all data and map to []interface. - // it's designed for one column record set, auto change to []value, not [][column]value. - // for example: - // var list ParamsList - // qs.ValuesFlat(&list, "UserName") // list[0] == "slene" - ValuesFlat(result *ParamsList, expr string) (int64, error) - // query all rows into map[string]interface with specify key and value column name. - // keyCol = "name", valueCol = "value" - // table data - // name | value - // total | 100 - // found | 200 - // to map[string]interface{}{ - // "total": 100, - // "found": 200, - // } - RowsToMap(result *Params, keyCol, valueCol string) (int64, error) - // query all rows into struct with specify key and value column name. - // keyCol = "name", valueCol = "value" - // table data - // name | value - // total | 100 - // found | 200 - // to struct { - // Total int - // Found int - // } - RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) -} - -// QueryM2Mer model to model query struct -// all operations are on the m2m table only, will not affect the origin model table -type QueryM2Mer interface { - // add models to origin models when creating queryM2M. - // example: - // m2m := orm.QueryM2M(post,"Tag") - // m2m.Add(&Tag1{},&Tag2{}) - // for _,tag := range post.Tags{}{ ... } - // param could also be any of the follow - // []*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}} - // &Tag{Id:5,Name: "TestTag3"} - // []interface{}{&Tag{Id:6,Name: "TestTag4"}} - // insert one or more rows to m2m table - // make sure the relation is defined in post model struct tag. - Add(...interface{}) (int64, error) - // remove models following the origin model relationship - // only delete rows from m2m table - // for example: - //tag3 := &Tag{Id:5,Name: "TestTag3"} - //num, err = m2m.Remove(tag3) - Remove(...interface{}) (int64, error) - // check model is existed in relationship of origin model - Exist(interface{}) bool - // clean all models in related of origin model - Clear() (int64, error) - // count all related models of origin model - Count() (int64, error) -} - -// RawPreparer raw query statement -type RawPreparer interface { - Exec(...interface{}) (sql.Result, error) - Close() error -} - -// RawSeter raw query seter -// create From Ormer.Raw -// for example: -// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) -// rs := Ormer.Raw(sql, 1) -type RawSeter interface { - //execute sql and get result - Exec() (sql.Result, error) - //query data and map to container - //for example: - // var name string - // var id int - // rs.QueryRow(&id,&name) // id==2 name=="slene" - QueryRow(containers ...interface{}) error - - // query data rows and map to container - // var ids []int - // var names []int - // query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q) - // num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"} - QueryRows(containers ...interface{}) (int64, error) - SetArgs(...interface{}) RawSeter - // query data to []map[string]interface - // see QuerySeter's Values - Values(container *[]Params, cols ...string) (int64, error) - // query data to [][]interface - // see QuerySeter's ValuesList - ValuesList(container *[]ParamsList, cols ...string) (int64, error) - // query data to []interface - // see QuerySeter's ValuesFlat - ValuesFlat(container *ParamsList, cols ...string) (int64, error) - // query all rows into map[string]interface with specify key and value column name. - // keyCol = "name", valueCol = "value" - // table data - // name | value - // total | 100 - // found | 200 - // to map[string]interface{}{ - // "total": 100, - // "found": 200, - // } - RowsToMap(result *Params, keyCol, valueCol string) (int64, error) - // query all rows into struct with specify key and value column name. - // keyCol = "name", valueCol = "value" - // table data - // name | value - // total | 100 - // found | 200 - // to struct { - // Total int - // Found int - // } - RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) - - // return prepared raw statement for used in times. - // for example: - // pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() - // r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`) - Prepare() (RawPreparer, error) -} - -// stmtQuerier statement querier -type stmtQuerier interface { - Close() error - Exec(args ...interface{}) (sql.Result, error) - //ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) - Query(args ...interface{}) (*sql.Rows, error) - //QueryContext(args ...interface{}) (*sql.Rows, error) - QueryRow(args ...interface{}) *sql.Row - //QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row -} - -// db querier -type dbQuerier interface { - Prepare(query string) (*sql.Stmt, error) - PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) - Exec(query string, args ...interface{}) (sql.Result, error) - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row - QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row -} - -// type DB interface { -// Begin() (*sql.Tx, error) -// Prepare(query string) (stmtQuerier, error) -// Exec(query string, args ...interface{}) (sql.Result, error) -// Query(query string, args ...interface{}) (*sql.Rows, error) -// QueryRow(query string, args ...interface{}) *sql.Row -// } - -// transaction beginner -type txer interface { - Begin() (*sql.Tx, error) - BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) -} - -// transaction ending -type txEnder interface { - Commit() error - Rollback() error -} - -// base database struct -type dbBaser interface { - Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error - Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) - InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) - InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) - InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) - SupportUpdateJoin() bool - UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) - DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - OperatorSQL(string) string - GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) - GenerateOperatorLeftCol(*fieldInfo, string, *string) - PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) - ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) - RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) - MaxLimit() uint64 - TableQuote() string - ReplaceMarks(*string) - HasReturningID(*modelInfo, *string) bool - TimeFromDB(*time.Time, *time.Location) - TimeToDB(*time.Time, *time.Location) - DbTypes() map[string]string - GetTables(dbQuerier) (map[string]bool, error) - GetColumns(dbQuerier, string) (map[string][3]string, error) - ShowTablesQuery() string - ShowColumnsQuery(string) string - IndexExists(dbQuerier, string, string) bool - collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) - setval(dbQuerier, *modelInfo, []string) error -} diff --git a/orm/utils.go b/orm/utils.go deleted file mode 100644 index 3ff76772..00000000 --- a/orm/utils.go +++ /dev/null @@ -1,319 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "math/big" - "reflect" - "strconv" - "strings" - "time" -) - -type fn func(string) string - -var ( - nameStrategyMap = map[string]fn{ - defaultNameStrategy: snakeString, - SnakeAcronymNameStrategy: snakeStringWithAcronym, - } - defaultNameStrategy = "snakeString" - SnakeAcronymNameStrategy = "snakeStringWithAcronym" - nameStrategy = defaultNameStrategy -) - -// StrTo is the target string -type StrTo string - -// Set string -func (f *StrTo) Set(v string) { - if v != "" { - *f = StrTo(v) - } else { - f.Clear() - } -} - -// Clear string -func (f *StrTo) Clear() { - *f = StrTo(0x1E) -} - -// Exist check string exist -func (f StrTo) Exist() bool { - return string(f) != string(0x1E) -} - -// Bool string to bool -func (f StrTo) Bool() (bool, error) { - return strconv.ParseBool(f.String()) -} - -// Float32 string to float32 -func (f StrTo) Float32() (float32, error) { - v, err := strconv.ParseFloat(f.String(), 32) - return float32(v), err -} - -// Float64 string to float64 -func (f StrTo) Float64() (float64, error) { - return strconv.ParseFloat(f.String(), 64) -} - -// Int string to int -func (f StrTo) Int() (int, error) { - v, err := strconv.ParseInt(f.String(), 10, 32) - return int(v), err -} - -// Int8 string to int8 -func (f StrTo) Int8() (int8, error) { - v, err := strconv.ParseInt(f.String(), 10, 8) - return int8(v), err -} - -// Int16 string to int16 -func (f StrTo) Int16() (int16, error) { - v, err := strconv.ParseInt(f.String(), 10, 16) - return int16(v), err -} - -// Int32 string to int32 -func (f StrTo) Int32() (int32, error) { - v, err := strconv.ParseInt(f.String(), 10, 32) - return int32(v), err -} - -// Int64 string to int64 -func (f StrTo) Int64() (int64, error) { - v, err := strconv.ParseInt(f.String(), 10, 64) - if err != nil { - i := new(big.Int) - ni, ok := i.SetString(f.String(), 10) // octal - if !ok { - return v, err - } - return ni.Int64(), nil - } - return v, err -} - -// Uint string to uint -func (f StrTo) Uint() (uint, error) { - v, err := strconv.ParseUint(f.String(), 10, 32) - return uint(v), err -} - -// Uint8 string to uint8 -func (f StrTo) Uint8() (uint8, error) { - v, err := strconv.ParseUint(f.String(), 10, 8) - return uint8(v), err -} - -// Uint16 string to uint16 -func (f StrTo) Uint16() (uint16, error) { - v, err := strconv.ParseUint(f.String(), 10, 16) - return uint16(v), err -} - -// Uint32 string to uint32 -func (f StrTo) Uint32() (uint32, error) { - v, err := strconv.ParseUint(f.String(), 10, 32) - return uint32(v), err -} - -// Uint64 string to uint64 -func (f StrTo) Uint64() (uint64, error) { - v, err := strconv.ParseUint(f.String(), 10, 64) - if err != nil { - i := new(big.Int) - ni, ok := i.SetString(f.String(), 10) - if !ok { - return v, err - } - return ni.Uint64(), nil - } - return v, err -} - -// String string to string -func (f StrTo) String() string { - if f.Exist() { - return string(f) - } - return "" -} - -// ToStr interface to string -func ToStr(value interface{}, args ...int) (s string) { - switch v := value.(type) { - case bool: - s = strconv.FormatBool(v) - case float32: - s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) - case float64: - s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) - case int: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) - case int8: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) - case int16: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) - case int32: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) - case int64: - s = strconv.FormatInt(v, argInt(args).Get(0, 10)) - case uint: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) - case uint8: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) - case uint16: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) - case uint32: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) - case uint64: - s = strconv.FormatUint(v, argInt(args).Get(0, 10)) - case string: - s = v - case []byte: - s = string(v) - default: - s = fmt.Sprintf("%v", v) - } - return s -} - -// ToInt64 interface to int64 -func ToInt64(value interface{}) (d int64) { - val := reflect.ValueOf(value) - switch value.(type) { - case int, int8, int16, int32, int64: - d = val.Int() - case uint, uint8, uint16, uint32, uint64: - d = int64(val.Uint()) - default: - panic(fmt.Errorf("ToInt64 need numeric not `%T`", value)) - } - return -} - -func snakeStringWithAcronym(s string) string { - data := make([]byte, 0, len(s)*2) - num := len(s) - for i := 0; i < num; i++ { - d := s[i] - before := false - after := false - if i > 0 { - before = s[i-1] >= 'a' && s[i-1] <= 'z' - } - if i+1 < num { - after = s[i+1] >= 'a' && s[i+1] <= 'z' - } - if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { - data = append(data, '_') - } - data = append(data, d) - } - return strings.ToLower(string(data[:])) -} - -// snake string, XxYy to xx_yy , XxYY to xx_y_y -func snakeString(s string) string { - data := make([]byte, 0, len(s)*2) - j := false - num := len(s) - for i := 0; i < num; i++ { - d := s[i] - if i > 0 && d >= 'A' && d <= 'Z' && j { - data = append(data, '_') - } - if d != '_' { - j = true - } - data = append(data, d) - } - return strings.ToLower(string(data[:])) -} - -// SetNameStrategy set different name strategy -func SetNameStrategy(s string) { - if SnakeAcronymNameStrategy != s { - nameStrategy = defaultNameStrategy - } - nameStrategy = s -} - -// camel string, xx_yy to XxYy -func camelString(s string) string { - data := make([]byte, 0, len(s)) - flag, num := true, len(s)-1 - for i := 0; i <= num; i++ { - d := s[i] - if d == '_' { - flag = true - continue - } else if flag { - if d >= 'a' && d <= 'z' { - d = d - 32 - } - flag = false - } - data = append(data, d) - } - return string(data[:]) -} - -type argString []string - -// get string by index from string slice -func (a argString) Get(i int, args ...string) (r string) { - if i >= 0 && i < len(a) { - r = a[i] - } else if len(args) > 0 { - r = args[0] - } - return -} - -type argInt []int - -// get int by index from int slice -func (a argInt) Get(i int, args ...int) (r int) { - if i >= 0 && i < len(a) { - r = a[i] - } - if len(args) > 0 { - r = args[0] - } - return -} - -// parse time to string with location -func timeParse(dateString, format string) (time.Time, error) { - tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) - return tp, err -} - -// get pointer indirect type -func indirectType(v reflect.Type) reflect.Type { - switch v.Kind() { - case reflect.Ptr: - return indirectType(v.Elem()) - default: - return v - } -} diff --git a/orm/utils_test.go b/orm/utils_test.go deleted file mode 100644 index 7d94cada..00000000 --- a/orm/utils_test.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "testing" -) - -func TestCamelString(t *testing.T) { - snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} - camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} - - answer := make(map[string]string) - for i, v := range snake { - answer[v] = camel[i] - } - - for _, v := range snake { - res := camelString(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } - } -} - -func TestSnakeString(t *testing.T) { - camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} - snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} - - answer := make(map[string]string) - for i, v := range camel { - answer[v] = snake[i] - } - - for _, v := range camel { - res := snakeString(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } - } -} - -func TestSnakeStringWithAcronym(t *testing.T) { - camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} - snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} - - answer := make(map[string]string) - for i, v := range camel { - answer[v] = snake[i] - } - - for _, v := range camel { - res := snakeStringWithAcronym(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } - } -} diff --git a/parser.go b/parser.go deleted file mode 100644 index 3a311894..00000000 --- a/parser.go +++ /dev/null @@ -1,591 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "encoding/json" - "errors" - "fmt" - "go/ast" - "go/parser" - "go/token" - "io/ioutil" - "os" - "path/filepath" - "regexp" - "sort" - "strconv" - "strings" - "unicode" - - "github.com/astaxie/beego/context/param" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" -) - -var globalRouterTemplate = `package {{.routersDir}} - -import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/context/param"{{.globalimport}} -) - -func init() { -{{.globalinfo}} -} -` - -var ( - lastupdateFilename = "lastupdate.tmp" - commentFilename string - pkgLastupdate map[string]int64 - genInfoList map[string][]ControllerComments - - routerHooks = map[string]int{ - "beego.BeforeStatic": BeforeStatic, - "beego.BeforeRouter": BeforeRouter, - "beego.BeforeExec": BeforeExec, - "beego.AfterExec": AfterExec, - "beego.FinishRouter": FinishRouter, - } - - routerHooksMapping = map[int]string{ - BeforeStatic: "beego.BeforeStatic", - BeforeRouter: "beego.BeforeRouter", - BeforeExec: "beego.BeforeExec", - AfterExec: "beego.AfterExec", - FinishRouter: "beego.FinishRouter", - } -) - -const commentPrefix = "commentsRouter_" - -func init() { - pkgLastupdate = make(map[string]int64) -} - -func parserPkg(pkgRealpath, pkgpath string) error { - rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_") - commentFilename, _ = filepath.Rel(AppPath, pkgRealpath) - commentFilename = commentPrefix + rep.Replace(commentFilename) + ".go" - if !compareFile(pkgRealpath) { - logs.Info(pkgRealpath + " no changed") - return nil - } - genInfoList = make(map[string][]ControllerComments) - fileSet := token.NewFileSet() - astPkgs, err := parser.ParseDir(fileSet, pkgRealpath, func(info os.FileInfo) bool { - name := info.Name() - return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") - }, parser.ParseComments) - - if err != nil { - return err - } - for _, pkg := range astPkgs { - for _, fl := range pkg.Files { - for _, d := range fl.Decls { - switch specDecl := d.(type) { - case *ast.FuncDecl: - if specDecl.Recv != nil { - exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser - if ok { - parserComments(specDecl, fmt.Sprint(exp.X), pkgpath) - } - } - } - } - } - } - genRouterCode(pkgRealpath) - savetoFile(pkgRealpath) - return nil -} - -type parsedComment struct { - routerPath string - methods []string - params map[string]parsedParam - filters []parsedFilter - imports []parsedImport -} - -type parsedImport struct { - importPath string - importAlias string -} - -type parsedFilter struct { - pattern string - pos int - filter string - params []bool -} - -type parsedParam struct { - name string - datatype string - location string - defValue string - required bool -} - -func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { - if f.Doc != nil { - parsedComments, err := parseComment(f.Doc.List) - if err != nil { - return err - } - for _, parsedComment := range parsedComments { - if parsedComment.routerPath != "" { - key := pkgpath + ":" + controllerName - cc := ControllerComments{} - cc.Method = f.Name.String() - cc.Router = parsedComment.routerPath - cc.AllowHTTPMethods = parsedComment.methods - cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) - cc.FilterComments = buildFilters(parsedComment.filters) - cc.ImportComments = buildImports(parsedComment.imports) - genInfoList[key] = append(genInfoList[key], cc) - } - } - } - return nil -} - -func buildImports(pis []parsedImport) []*ControllerImportComments { - var importComments []*ControllerImportComments - - for _, pi := range pis { - importComments = append(importComments, &ControllerImportComments{ - ImportPath: pi.importPath, - ImportAlias: pi.importAlias, - }) - } - - return importComments -} - -func buildFilters(pfs []parsedFilter) []*ControllerFilterComments { - var filterComments []*ControllerFilterComments - - for _, pf := range pfs { - var ( - returnOnOutput bool - resetParams bool - ) - - if len(pf.params) >= 1 { - returnOnOutput = pf.params[0] - } - - if len(pf.params) >= 2 { - resetParams = pf.params[1] - } - - filterComments = append(filterComments, &ControllerFilterComments{ - Filter: pf.filter, - Pattern: pf.pattern, - Pos: pf.pos, - ReturnOnOutput: returnOnOutput, - ResetParams: resetParams, - }) - } - - return filterComments -} - -func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam { - result := make([]*param.MethodParam, 0, len(funcParams)) - for _, fparam := range funcParams { - for _, pName := range fparam.Names { - methodParam := buildMethodParam(fparam, pName.Name, pc) - result = append(result, methodParam) - } - } - return result -} - -func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam { - options := []param.MethodParamOption{} - if cparam, ok := pc.params[name]; ok { - //Build param from comment info - name = cparam.name - if cparam.required { - options = append(options, param.IsRequired) - } - switch cparam.location { - case "body": - options = append(options, param.InBody) - case "header": - options = append(options, param.InHeader) - case "path": - options = append(options, param.InPath) - } - if cparam.defValue != "" { - options = append(options, param.Default(cparam.defValue)) - } - } else { - if paramInPath(name, pc.routerPath) { - options = append(options, param.InPath) - } - } - return param.New(name, options...) -} - -func paramInPath(name, route string) bool { - return strings.HasSuffix(route, ":"+name) || - strings.Contains(route, ":"+name+"/") -} - -var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) - -func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) { - pcs = []*parsedComment{} - params := map[string]parsedParam{} - filters := []parsedFilter{} - imports := []parsedImport{} - - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Param") { - pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param"))) - if len(pv) < 4 { - logs.Error("Invalid @Param format. Needs at least 4 parameters") - } - p := parsedParam{} - names := strings.SplitN(pv[0], "=>", 2) - p.name = names[0] - funcParamName := p.name - if len(names) > 1 { - funcParamName = names[1] - } - p.location = pv[1] - p.datatype = pv[2] - switch len(pv) { - case 5: - p.required, _ = strconv.ParseBool(pv[3]) - case 6: - p.defValue = pv[3] - p.required, _ = strconv.ParseBool(pv[4]) - } - params[funcParamName] = p - } - } - - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Import") { - iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import"))) - if len(iv) == 0 || len(iv) > 2 { - logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters") - continue - } - - p := parsedImport{} - p.importPath = iv[0] - - if len(iv) == 2 { - p.importAlias = iv[1] - } - - imports = append(imports, p) - } - } - -filterLoop: - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Filter") { - fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter"))) - if len(fv) < 3 { - logs.Error("Invalid @Filter format. Needs at least 3 parameters") - continue filterLoop - } - - p := parsedFilter{} - p.pattern = fv[0] - posName := fv[1] - if pos, exists := routerHooks[posName]; exists { - p.pos = pos - } else { - logs.Error("Invalid @Filter pos: ", posName) - continue filterLoop - } - - p.filter = fv[2] - fvParams := fv[3:] - for _, fvParam := range fvParams { - switch fvParam { - case "true": - p.params = append(p.params, true) - case "false": - p.params = append(p.params, false) - default: - logs.Error("Invalid @Filter param: ", fvParam) - continue filterLoop - } - } - - filters = append(filters, p) - } - } - - for _, c := range lines { - var pc = &parsedComment{} - pc.params = params - pc.filters = filters - pc.imports = imports - - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@router") { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - matches := routeRegex.FindStringSubmatch(t) - if len(matches) == 3 { - pc.routerPath = matches[1] - methods := matches[2] - if methods == "" { - pc.methods = []string{"get"} - //pc.hasGet = true - } else { - pc.methods = strings.Split(methods, ",") - //pc.hasGet = strings.Contains(methods, "get") - } - pcs = append(pcs, pc) - } else { - return nil, errors.New("Router information is missing") - } - } - } - return -} - -// direct copy from bee\g_docs.go -// analysis params return []string -// @Param query form string true "The email for login" -// [query form string true "The email for login"] -func getparams(str string) []string { - var s []rune - var j int - var start bool - var r []string - var quoted int8 - for _, c := range str { - if unicode.IsSpace(c) && quoted == 0 { - if !start { - continue - } else { - start = false - j++ - r = append(r, string(s)) - s = make([]rune, 0) - continue - } - } - - start = true - if c == '"' { - quoted ^= 1 - continue - } - s = append(s, c) - } - if len(s) > 0 { - r = append(r, string(s)) - } - return r -} - -func genRouterCode(pkgRealpath string) { - os.Mkdir(getRouterDir(pkgRealpath), 0755) - logs.Info("generate router from comments") - var ( - globalinfo string - globalimport string - sortKey []string - ) - for k := range genInfoList { - sortKey = append(sortKey, k) - } - sort.Strings(sortKey) - for _, k := range sortKey { - cList := genInfoList[k] - sort.Sort(ControllerCommentsSlice(cList)) - for _, c := range cList { - allmethod := "nil" - if len(c.AllowHTTPMethods) > 0 { - allmethod = "[]string{" - for _, m := range c.AllowHTTPMethods { - allmethod += `"` + m + `",` - } - allmethod = strings.TrimRight(allmethod, ",") + "}" - } - - params := "nil" - if len(c.Params) > 0 { - params = "[]map[string]string{" - for _, p := range c.Params { - for k, v := range p { - params = params + `map[string]string{` + k + `:"` + v + `"},` - } - } - params = strings.TrimRight(params, ",") + "}" - } - - methodParams := "param.Make(" - if len(c.MethodParams) > 0 { - lines := make([]string, 0, len(c.MethodParams)) - for _, m := range c.MethodParams { - lines = append(lines, fmt.Sprint(m)) - } - methodParams += "\n " + - strings.Join(lines, ",\n ") + - ",\n " - } - methodParams += ")" - - imports := "" - if len(c.ImportComments) > 0 { - for _, i := range c.ImportComments { - var s string - if i.ImportAlias != "" { - s = fmt.Sprintf(` - %s "%s"`, i.ImportAlias, i.ImportPath) - } else { - s = fmt.Sprintf(` - "%s"`, i.ImportPath) - } - if !strings.Contains(globalimport, s) { - imports += s - } - } - } - - filters := "" - if len(c.FilterComments) > 0 { - for _, f := range c.FilterComments { - filters += fmt.Sprintf(` &beego.ControllerFilter{ - Pattern: "%s", - Pos: %s, - Filter: %s, - ReturnOnOutput: %v, - ResetParams: %v, - },`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams) - } - } - - if filters == "" { - filters = "nil" - } else { - filters = fmt.Sprintf(`[]*beego.ControllerFilter{ -%s - }`, filters) - } - - globalimport += imports - - globalinfo = globalinfo + ` - beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], - beego.ControllerComments{ - Method: "` + strings.TrimSpace(c.Method) + `", - ` + `Router: "` + c.Router + `"` + `, - AllowHTTPMethods: ` + allmethod + `, - MethodParams: ` + methodParams + `, - Filters: ` + filters + `, - Params: ` + params + `}) -` - } - } - - if globalinfo != "" { - f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) - if err != nil { - panic(err) - } - defer f.Close() - - routersDir := AppConfig.DefaultString("routersdir", "routers") - content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) - content = strings.Replace(content, "{{.routersDir}}", routersDir, -1) - content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) - f.WriteString(content) - } -} - -func compareFile(pkgRealpath string) bool { - if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) { - return true - } - if utils.FileExists(lastupdateFilename) { - content, err := ioutil.ReadFile(lastupdateFilename) - if err != nil { - return true - } - json.Unmarshal(content, &pkgLastupdate) - lastupdate, err := getpathTime(pkgRealpath) - if err != nil { - return true - } - if v, ok := pkgLastupdate[pkgRealpath]; ok { - if lastupdate <= v { - return false - } - } - } - return true -} - -func savetoFile(pkgRealpath string) { - lastupdate, err := getpathTime(pkgRealpath) - if err != nil { - return - } - pkgLastupdate[pkgRealpath] = lastupdate - d, err := json.Marshal(pkgLastupdate) - if err != nil { - return - } - ioutil.WriteFile(lastupdateFilename, d, os.ModePerm) -} - -func getpathTime(pkgRealpath string) (lastupdate int64, err error) { - fl, err := ioutil.ReadDir(pkgRealpath) - if err != nil { - return lastupdate, err - } - for _, f := range fl { - if lastupdate < f.ModTime().UnixNano() { - lastupdate = f.ModTime().UnixNano() - } - } - return lastupdate, nil -} - -func getRouterDir(pkgRealpath string) string { - dir := filepath.Dir(pkgRealpath) - for { - routersDir := AppConfig.DefaultString("routersdir", "routers") - d := filepath.Join(dir, routersDir) - if utils.FileExists(d) { - return d - } - - if r, _ := filepath.Rel(dir, AppPath); r == "." { - return d - } - // Parent dir. - dir = filepath.Dir(dir) - } -} diff --git a/plugins/apiauth/apiauth.go b/plugins/apiauth/apiauth.go deleted file mode 100644 index 10e25f3f..00000000 --- a/plugins/apiauth/apiauth.go +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package apiauth provides handlers to enable apiauth support. -// -// Simple Usage: -// import( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/plugins/apiauth" -// ) -// -// func main(){ -// // apiauth every request -// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBaiscAuth("appid","appkey")) -// beego.Run() -// } -// -// Advanced Usage: -// -// func getAppSecret(appid string) string { -// // get appsecret by appid -// // maybe store in configure, maybe in database -// } -// -// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360)) -// -// Information: -// -// In the request user should include these params in the query -// -// 1. appid -// -// appid is assigned to the application -// -// 2. signature -// -// get the signature use apiauth.Signature() -// -// when you send to server remember use url.QueryEscape() -// -// 3. timestamp: -// -// send the request time, the format is yyyy-mm-dd HH:ii:ss -// -package apiauth - -import ( - "bytes" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "fmt" - "net/url" - "sort" - "time" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" -) - -// AppIDToAppSecret is used to get appsecret throw appid -type AppIDToAppSecret func(string) string - -// APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret -func APIBasicAuth(appid, appkey string) beego.FilterFunc { - ft := func(aid string) string { - if aid == appid { - return appkey - } - return "" - } - return APISecretAuth(ft, 300) -} - -// APIBaiscAuth calls APIBasicAuth for previous callers -func APIBaiscAuth(appid, appkey string) beego.FilterFunc { - return APIBasicAuth(appid, appkey) -} - -// APISecretAuth use AppIdToAppSecret verify and -func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { - return func(ctx *context.Context) { - if ctx.Input.Query("appid") == "" { - ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("miss query param: appid") - return - } - appsecret := f(ctx.Input.Query("appid")) - if appsecret == "" { - ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("not exist this appid") - return - } - if ctx.Input.Query("signature") == "" { - ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("miss query param: signature") - return - } - if ctx.Input.Query("timestamp") == "" { - ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("miss query param: timestamp") - return - } - u, err := time.Parse("2006-01-02 15:04:05", ctx.Input.Query("timestamp")) - if err != nil { - ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("timestamp format is error, should 2006-01-02 15:04:05") - return - } - t := time.Now() - if t.Sub(u).Seconds() > float64(timeout) { - ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("timeout! the request time is long ago, please try again") - return - } - if ctx.Input.Query("signature") != - Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.URL()) { - ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("auth failed") - } - } -} - -// Signature used to generate signature with the appsecret/method/params/RequestURI -func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) { - var b bytes.Buffer - keys := make([]string, len(params)) - pa := make(map[string]string) - for k, v := range params { - pa[k] = v[0] - keys = append(keys, k) - } - - sort.Strings(keys) - - for _, key := range keys { - if key == "signature" { - continue - } - - val := pa[key] - if key != "" && val != "" { - b.WriteString(key) - b.WriteString(val) - } - } - - stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, b.String(), RequestURL) - - sha256 := sha256.New - hash := hmac.New(sha256, []byte(appsecret)) - hash.Write([]byte(stringToSign)) - return base64.StdEncoding.EncodeToString(hash.Sum(nil)) -} diff --git a/plugins/apiauth/apiauth_test.go b/plugins/apiauth/apiauth_test.go deleted file mode 100644 index 1f56cb0f..00000000 --- a/plugins/apiauth/apiauth_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package apiauth - -import ( - "net/url" - "testing" -) - -func TestSignature(t *testing.T) { - appsecret := "beego secret" - method := "GET" - RequestURL := "http://localhost/test/url" - params := make(url.Values) - params.Add("arg1", "hello") - params.Add("arg2", "beego") - - signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58=" - if Signature(appsecret, method, params, RequestURL) != signature { - t.Error("Signature error") - } -} diff --git a/plugins/auth/basic.go b/plugins/auth/basic.go deleted file mode 100644 index c478044a..00000000 --- a/plugins/auth/basic.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package auth provides handlers to enable basic auth support. -// Simple Usage: -// import( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/plugins/auth" -// ) -// -// func main(){ -// // authenticate every request -// beego.InsertFilter("*", beego.BeforeRouter,auth.Basic("username","secretpassword")) -// beego.Run() -// } -// -// -// Advanced Usage: -// -// func SecretAuth(username, password string) bool { -// return username == "astaxie" && password == "helloBeego" -// } -// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "Authorization Required") -// beego.InsertFilter("*", beego.BeforeRouter,authPlugin) -package auth - -import ( - "encoding/base64" - "net/http" - "strings" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" -) - -var defaultRealm = "Authorization Required" - -// Basic is the http basic auth -func Basic(username string, password string) beego.FilterFunc { - secrets := func(user, pass string) bool { - return user == username && pass == password - } - return NewBasicAuthenticator(secrets, defaultRealm) -} - -// NewBasicAuthenticator return the BasicAuth -func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc { - return func(ctx *context.Context) { - a := &BasicAuth{Secrets: secrets, Realm: Realm} - if username := a.CheckAuth(ctx.Request); username == "" { - a.RequireAuth(ctx.ResponseWriter, ctx.Request) - } - } -} - -// SecretProvider is the SecretProvider function -type SecretProvider func(user, pass string) bool - -// BasicAuth store the SecretProvider and Realm -type BasicAuth struct { - Secrets SecretProvider - Realm string -} - -// CheckAuth Checks the username/password combination from the request. Returns -// either an empty string (authentication failed) or the name of the -// authenticated user. -// Supports MD5 and SHA1 password entries -func (a *BasicAuth) CheckAuth(r *http.Request) string { - s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) - if len(s) != 2 || s[0] != "Basic" { - return "" - } - - b, err := base64.StdEncoding.DecodeString(s[1]) - if err != nil { - return "" - } - pair := strings.SplitN(string(b), ":", 2) - if len(pair) != 2 { - return "" - } - - if a.Secrets(pair[0], pair[1]) { - return pair[0] - } - return "" -} - -// RequireAuth http.Handler for BasicAuth which initiates the authentication process -// (or requires reauthentication). -func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) { - w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`) - w.WriteHeader(401) - w.Write([]byte("401 Unauthorized\n")) -} diff --git a/plugins/authz/authz.go b/plugins/authz/authz.go deleted file mode 100644 index 9dc0db76..00000000 --- a/plugins/authz/authz.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support. -// Simple Usage: -// import( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/plugins/authz" -// "github.com/casbin/casbin" -// ) -// -// func main(){ -// // mediate the access for every request -// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) -// beego.Run() -// } -// -// -// Advanced Usage: -// -// func main(){ -// e := casbin.NewEnforcer("authz_model.conf", "") -// e.AddRoleForUser("alice", "admin") -// e.AddPolicy(...) -// -// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e)) -// beego.Run() -// } -package authz - -import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" - "github.com/casbin/casbin" - "net/http" -) - -// NewAuthorizer returns the authorizer. -// Use a casbin enforcer as input -func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc { - return func(ctx *context.Context) { - a := &BasicAuthorizer{enforcer: e} - - if !a.CheckPermission(ctx.Request) { - a.RequirePermission(ctx.ResponseWriter) - } - } -} - -// BasicAuthorizer stores the casbin handler -type BasicAuthorizer struct { - enforcer *casbin.Enforcer -} - -// GetUserName gets the user name from the request. -// Currently, only HTTP basic authentication is supported -func (a *BasicAuthorizer) GetUserName(r *http.Request) string { - username, _, _ := r.BasicAuth() - return username -} - -// CheckPermission checks the user/method/path combination from the request. -// Returns true (permission granted) or false (permission forbidden) -func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool { - user := a.GetUserName(r) - method := r.Method - path := r.URL.Path - return a.enforcer.Enforce(user, path, method) -} - -// RequirePermission returns the 403 Forbidden to the client -func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) { - w.WriteHeader(403) - w.Write([]byte("403 Forbidden\n")) -} diff --git a/plugins/authz/authz_model.conf b/plugins/authz/authz_model.conf deleted file mode 100644 index d1b3dbd7..00000000 --- a/plugins/authz/authz_model.conf +++ /dev/null @@ -1,14 +0,0 @@ -[request_definition] -r = sub, obj, act - -[policy_definition] -p = sub, obj, act - -[role_definition] -g = _, _ - -[policy_effect] -e = some(where (p.eft == allow)) - -[matchers] -m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*") \ No newline at end of file diff --git a/plugins/authz/authz_policy.csv b/plugins/authz/authz_policy.csv deleted file mode 100644 index c062dd3e..00000000 --- a/plugins/authz/authz_policy.csv +++ /dev/null @@ -1,7 +0,0 @@ -p, alice, /dataset1/*, GET -p, alice, /dataset1/resource1, POST -p, bob, /dataset2/resource1, * -p, bob, /dataset2/resource2, GET -p, bob, /dataset2/folder1/*, POST -p, dataset1_admin, /dataset1/*, * -g, cathy, dataset1_admin \ No newline at end of file diff --git a/plugins/authz/authz_test.go b/plugins/authz/authz_test.go deleted file mode 100644 index 49aed84c..00000000 --- a/plugins/authz/authz_test.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package authz - -import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/plugins/auth" - "github.com/casbin/casbin" - "net/http" - "net/http/httptest" - "testing" -) - -func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { - r, _ := http.NewRequest(method, path, nil) - r.SetBasicAuth(user, "123") - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) - - if w.Code != code { - t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code) - } -} - -func TestBasic(t *testing.T) { - handler := beego.NewControllerRegister() - - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) - - handler.Any("*", func(ctx *context.Context) { - ctx.Output.SetStatus(200) - }) - - testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200) - testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200) - testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200) - testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403) -} - -func TestPathWildcard(t *testing.T) { - handler := beego.NewControllerRegister() - - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) - - handler.Any("*", func(ctx *context.Context) { - ctx.Output.SetStatus(200) - }) - - testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200) - testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200) - testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200) - testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200) - testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403) - testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403) - - testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403) - testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200) - testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403) - testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403) - testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200) - testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403) -} - -func TestRBAC(t *testing.T) { - handler := beego.NewControllerRegister() - - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) - e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) - - handler.Any("*", func(ctx *context.Context) { - ctx.Output.SetStatus(200) - }) - - // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role. - testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200) - testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200) - testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200) - testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) - - // delete all roles on user cathy, so cathy cannot access any resources now. - e.DeleteRolesForUser("cathy") - - testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403) - testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403) - testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) -} diff --git a/plugins/cors/cors.go b/plugins/cors/cors.go deleted file mode 100644 index 45c327ab..00000000 --- a/plugins/cors/cors.go +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package cors provides handlers to enable CORS support. -// Usage -// import ( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/plugins/cors" -// ) -// -// func main() { -// // CORS for https://foo.* origins, allowing: -// // - PUT and PATCH methods -// // - Origin header -// // - Credentials share -// beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{ -// AllowOrigins: []string{"https://*.foo.com"}, -// AllowMethods: []string{"PUT", "PATCH"}, -// AllowHeaders: []string{"Origin"}, -// ExposeHeaders: []string{"Content-Length"}, -// AllowCredentials: true, -// })) -// beego.Run() -// } -package cors - -import ( - "net/http" - "regexp" - "strconv" - "strings" - "time" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" -) - -const ( - headerAllowOrigin = "Access-Control-Allow-Origin" - headerAllowCredentials = "Access-Control-Allow-Credentials" - headerAllowHeaders = "Access-Control-Allow-Headers" - headerAllowMethods = "Access-Control-Allow-Methods" - headerExposeHeaders = "Access-Control-Expose-Headers" - headerMaxAge = "Access-Control-Max-Age" - - headerOrigin = "Origin" - headerRequestMethod = "Access-Control-Request-Method" - headerRequestHeaders = "Access-Control-Request-Headers" -) - -var ( - defaultAllowHeaders = []string{"Origin", "Accept", "Content-Type", "Authorization"} - // Regex patterns are generated from AllowOrigins. These are used and generated internally. - allowOriginPatterns = []string{} -) - -// Options represents Access Control options. -type Options struct { - // If set, all origins are allowed. - AllowAllOrigins bool - // A list of allowed origins. Wild cards and FQDNs are supported. - AllowOrigins []string - // If set, allows to share auth credentials such as cookies. - AllowCredentials bool - // A list of allowed HTTP methods. - AllowMethods []string - // A list of allowed HTTP headers. - AllowHeaders []string - // A list of exposed HTTP headers. - ExposeHeaders []string - // Max age of the CORS headers. - MaxAge time.Duration -} - -// Header converts options into CORS headers. -func (o *Options) Header(origin string) (headers map[string]string) { - headers = make(map[string]string) - // if origin is not allowed, don't extend the headers - // with CORS headers. - if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) { - return - } - - // add allow origin - if o.AllowAllOrigins { - headers[headerAllowOrigin] = "*" - } else { - headers[headerAllowOrigin] = origin - } - - // add allow credentials - headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials) - - // add allow methods - if len(o.AllowMethods) > 0 { - headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",") - } - - // add allow headers - if len(o.AllowHeaders) > 0 { - headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",") - } - - // add exposed header - if len(o.ExposeHeaders) > 0 { - headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",") - } - // add a max age header - if o.MaxAge > time.Duration(0) { - headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) - } - return -} - -// PreflightHeader converts options into CORS headers for a preflight response. -func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) { - headers = make(map[string]string) - if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) { - return - } - // verify if requested method is allowed - for _, method := range o.AllowMethods { - if method == rMethod { - headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",") - break - } - } - - // verify if requested headers are allowed - var allowed []string - for _, rHeader := range strings.Split(rHeaders, ",") { - rHeader = strings.TrimSpace(rHeader) - lookupLoop: - for _, allowedHeader := range o.AllowHeaders { - if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) { - allowed = append(allowed, rHeader) - break lookupLoop - } - } - } - - headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials) - // add allow origin - if o.AllowAllOrigins { - headers[headerAllowOrigin] = "*" - } else { - headers[headerAllowOrigin] = origin - } - - // add allowed headers - if len(allowed) > 0 { - headers[headerAllowHeaders] = strings.Join(allowed, ",") - } - - // add exposed headers - if len(o.ExposeHeaders) > 0 { - headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",") - } - // add a max age header - if o.MaxAge > time.Duration(0) { - headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) - } - return -} - -// IsOriginAllowed looks up if the origin matches one of the patterns -// generated from Options.AllowOrigins patterns. -func (o *Options) IsOriginAllowed(origin string) (allowed bool) { - for _, pattern := range allowOriginPatterns { - allowed, _ = regexp.MatchString(pattern, origin) - if allowed { - return - } - } - return -} - -// Allow enables CORS for requests those match the provided options. -func Allow(opts *Options) beego.FilterFunc { - // Allow default headers if nothing is specified. - if len(opts.AllowHeaders) == 0 { - opts.AllowHeaders = defaultAllowHeaders - } - - for _, origin := range opts.AllowOrigins { - pattern := regexp.QuoteMeta(origin) - pattern = strings.Replace(pattern, "\\*", ".*", -1) - pattern = strings.Replace(pattern, "\\?", ".", -1) - allowOriginPatterns = append(allowOriginPatterns, "^"+pattern+"$") - } - - return func(ctx *context.Context) { - var ( - origin = ctx.Input.Header(headerOrigin) - requestedMethod = ctx.Input.Header(headerRequestMethod) - requestedHeaders = ctx.Input.Header(headerRequestHeaders) - // additional headers to be added - // to the response. - headers map[string]string - ) - - if ctx.Input.Method() == "OPTIONS" && - (requestedMethod != "" || requestedHeaders != "") { - headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders) - for key, value := range headers { - ctx.Output.Header(key, value) - } - ctx.ResponseWriter.WriteHeader(http.StatusOK) - return - } - headers = opts.Header(origin) - - for key, value := range headers { - ctx.Output.Header(key, value) - } - } -} diff --git a/plugins/cors/cors_test.go b/plugins/cors/cors_test.go deleted file mode 100644 index 34039143..00000000 --- a/plugins/cors/cors_test.go +++ /dev/null @@ -1,253 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cors - -import ( - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" -) - -// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header -type HTTPHeaderGuardRecorder struct { - *httptest.ResponseRecorder - savedHeaderMap http.Header -} - -// NewRecorder return HttpHeaderGuardRecorder -func NewRecorder() *HTTPHeaderGuardRecorder { - return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil} -} - -func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) { - gr.ResponseRecorder.WriteHeader(code) - gr.savedHeaderMap = gr.ResponseRecorder.Header() -} - -func (gr *HTTPHeaderGuardRecorder) Header() http.Header { - if gr.savedHeaderMap != nil { - // headers were written. clone so we don't get updates - clone := make(http.Header) - for k, v := range gr.savedHeaderMap { - clone[k] = v - } - return clone - } - return gr.ResponseRecorder.Header() -} - -func Test_AllowAll(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - if recorder.HeaderMap.Get(headerAllowOrigin) != "*" { - t.Errorf("Allow-Origin header should be *") - } -} - -func Test_AllowRegexMatch(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"}, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - origin := "https://bar.foo.com" - r, _ := http.NewRequest("PUT", "/foo", nil) - r.Header.Add("Origin", origin) - handler.ServeHTTP(recorder, r) - - headerValue := recorder.HeaderMap.Get(headerAllowOrigin) - if headerValue != origin { - t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue) - } -} - -func Test_AllowRegexNoMatch(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowOrigins: []string{"https://*.foo.com"}, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - origin := "https://ww.foo.com.evil.com" - r, _ := http.NewRequest("PUT", "/foo", nil) - r.Header.Add("Origin", origin) - handler.ServeHTTP(recorder, r) - - headerValue := recorder.HeaderMap.Get(headerAllowOrigin) - if headerValue != "" { - t.Errorf("Allow-Origin header should not exist, found %v", headerValue) - } -} - -func Test_OtherHeaders(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowCredentials: true, - AllowMethods: []string{"PATCH", "GET"}, - AllowHeaders: []string{"Origin", "X-whatever"}, - ExposeHeaders: []string{"Content-Length", "Hello"}, - MaxAge: 5 * time.Minute, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials) - methodsVal := recorder.HeaderMap.Get(headerAllowMethods) - headersVal := recorder.HeaderMap.Get(headerAllowHeaders) - exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders) - maxAgeVal := recorder.HeaderMap.Get(headerMaxAge) - - if credentialsVal != "true" { - t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal) - } - - if methodsVal != "PATCH,GET" { - t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal) - } - - if headersVal != "Origin,X-whatever" { - t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal) - } - - if exposedHeadersVal != "Content-Length,Hello" { - t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal) - } - - if maxAgeVal != "300" { - t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal) - } -} - -func Test_DefaultAllowHeaders(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - headersVal := recorder.HeaderMap.Get(headerAllowHeaders) - if headersVal != "Origin,Accept,Content-Type,Authorization" { - t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal) - } -} - -func Test_Preflight(t *testing.T) { - recorder := NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowMethods: []string{"PUT", "PATCH"}, - AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"}, - })) - - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(200) - }) - - r, _ := http.NewRequest("OPTIONS", "/foo", nil) - r.Header.Add(headerRequestMethod, "PUT") - r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive") - handler.ServeHTTP(recorder, r) - - headers := recorder.Header() - methodsVal := headers.Get(headerAllowMethods) - headersVal := headers.Get(headerAllowHeaders) - originVal := headers.Get(headerAllowOrigin) - - if methodsVal != "PUT,PATCH" { - t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal) - } - - if !strings.Contains(headersVal, "X-whatever") { - t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal) - } - - if !strings.Contains(headersVal, "x-casesensitive") { - t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal) - } - - if originVal != "*" { - t.Errorf("Allow-Origin is expected to be *, found %v", originVal) - } - - if recorder.Code != http.StatusOK { - t.Errorf("Status code is expected to be 200, found %d", recorder.Code) - } -} - -func Benchmark_WithoutCORS(b *testing.B) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - beego.BConfig.RunMode = beego.PROD - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - b.ResetTimer() - r, _ := http.NewRequest("PUT", "/foo", nil) - for i := 0; i < b.N; i++ { - handler.ServeHTTP(recorder, r) - } -} - -func Benchmark_WithCORS(b *testing.B) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - beego.BConfig.RunMode = beego.PROD - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowCredentials: true, - AllowMethods: []string{"PATCH", "GET"}, - AllowHeaders: []string{"Origin", "X-whatever"}, - MaxAge: 5 * time.Minute, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - b.ResetTimer() - r, _ := http.NewRequest("PUT", "/foo", nil) - for i := 0; i < b.N; i++ { - handler.ServeHTTP(recorder, r) - } -} diff --git a/policy.go b/policy.go deleted file mode 100644 index 358a0539..00000000 --- a/policy.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2016 beego authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "strings" - - "github.com/astaxie/beego/context" -) - -// PolicyFunc defines a policy function which is invoked before the controller handler is executed. -// Deprecated: using pkg/, we will delete this in v2.1.0 -type PolicyFunc func(*context.Context) - -// FindPolicy Find Router info for URL -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { - var urlPath = cont.Input.URL() - if !BConfig.RouterCaseSensitive { - urlPath = strings.ToLower(urlPath) - } - httpMethod := cont.Input.Method() - isWildcard := false - // Find policy for current method - t, ok := p.policies[httpMethod] - // If not found - find policy for whole controller - if !ok { - t, ok = p.policies["*"] - isWildcard = true - } - if ok { - runObjects := t.Match(urlPath, cont) - if r, ok := runObjects.([]PolicyFunc); ok { - return r - } else if !isWildcard { - // If no policies found and we checked not for "*" method - try to find it - t, ok = p.policies["*"] - if ok { - runObjects = t.Match(urlPath, cont) - if r, ok = runObjects.([]PolicyFunc); ok { - return r - } - } - } - } - return nil -} - -func (p *ControllerRegister) addToPolicy(method, pattern string, r ...PolicyFunc) { - method = strings.ToUpper(method) - p.enablePolicy = true - if !BConfig.RouterCaseSensitive { - pattern = strings.ToLower(pattern) - } - if t, ok := p.policies[method]; ok { - t.AddRouter(pattern, r) - } else { - t := NewTree() - t.AddRouter(pattern, r) - p.policies[method] = t - } -} - -// Policy Register new policy in beego -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Policy(pattern, method string, policy ...PolicyFunc) { - BeeApp.Handlers.addToPolicy(method, pattern, policy...) -} - -// Find policies and execute if were found -func (p *ControllerRegister) execPolicy(cont *context.Context, urlPath string) (started bool) { - if !p.enablePolicy { - return false - } - // Find Policy for method - policyList := p.FindPolicy(cont) - if len(policyList) > 0 { - // Run policies - for _, runPolicy := range policyList { - runPolicy(cont) - if cont.ResponseWriter.Started { - return true - } - } - return false - } - return false -} diff --git a/router.go b/router.go deleted file mode 100644 index 1be495ab..00000000 --- a/router.go +++ /dev/null @@ -1,1085 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "errors" - "fmt" - "net/http" - "os" - "path" - "path/filepath" - "reflect" - "strconv" - "strings" - "sync" - "time" - - beecontext "github.com/astaxie/beego/context" - "github.com/astaxie/beego/context/param" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/toolbox" - "github.com/astaxie/beego/utils" -) - -// default filter execution points -const ( - BeforeStatic = iota - BeforeRouter - BeforeExec - AfterExec - FinishRouter -) - -const ( - routerTypeBeego = iota - routerTypeRESTFul - routerTypeHandler -) - -var ( - // HTTPMETHOD list the supported http methods. - // Deprecated: using pkg/, we will delete this in v2.1.0 - HTTPMETHOD = map[string]bool{ - "GET": true, - "POST": true, - "PUT": true, - "DELETE": true, - "PATCH": true, - "OPTIONS": true, - "HEAD": true, - "TRACE": true, - "CONNECT": true, - "MKCOL": true, - "COPY": true, - "MOVE": true, - "PROPFIND": true, - "PROPPATCH": true, - "LOCK": true, - "UNLOCK": true, - } - // these beego.Controller's methods shouldn't reflect to AutoRouter - exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString", - "RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJSON", "ServeJSONP", - "ServeYAML", "ServeXML", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool", - "GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession", - "DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie", - "SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml", - "GetControllerAndAction", "ServeFormatted"} - - urlPlaceholder = "{{placeholder}}" - // DefaultAccessLogFilter will skip the accesslog if return true - // Deprecated: using pkg/, we will delete this in v2.1.0 - DefaultAccessLogFilter FilterHandler = &logFilter{} -) - -// FilterHandler is an interface for -// Deprecated: using pkg/, we will delete this in v2.1.0 -type FilterHandler interface { - Filter(*beecontext.Context) bool -} - -// default log filter static file will not show -type logFilter struct { -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (l *logFilter) Filter(ctx *beecontext.Context) bool { - requestPath := path.Clean(ctx.Request.URL.Path) - if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { - return true - } - for prefix := range BConfig.WebConfig.StaticDir { - if strings.HasPrefix(requestPath, prefix) { - return true - } - } - return false -} - -// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter -// Deprecated: using pkg/, we will delete this in v2.1.0 -func ExceptMethodAppend(action string) { - exceptMethod = append(exceptMethod, action) -} - -// ControllerInfo holds information about the controller. -type ControllerInfo struct { - pattern string - controllerType reflect.Type - methods map[string]string - handler http.Handler - runFunction FilterFunc - routerType int - initialize func() ControllerInterface - methodParams []*param.MethodParam -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *ControllerInfo) GetPattern() string { - return c.pattern -} - -// ControllerRegister containers registered router rules, controller handlers and filters. -// Deprecated: using pkg/, we will delete this in v2.1.0 -type ControllerRegister struct { - routers map[string]*Tree - enablePolicy bool - policies map[string]*Tree - enableFilter bool - filters [FinishRouter + 1][]*FilterRouter - pool sync.Pool -} - -// NewControllerRegister returns a new ControllerRegister. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NewControllerRegister() *ControllerRegister { - return &ControllerRegister{ - routers: make(map[string]*Tree), - policies: make(map[string]*Tree), - pool: sync.Pool{ - New: func() interface{} { - return beecontext.NewContext() - }, - }, - } -} - -// Add controller handler and pattern rules to ControllerRegister. -// usage: -// default methods is the same name as method -// Add("/user",&UserController{}) -// Add("/api/list",&RestController{},"*:ListFood") -// Add("/api/create",&RestController{},"post:CreateFood") -// Add("/api/update",&RestController{},"put:UpdateFood") -// Add("/api/delete",&RestController{},"delete:DeleteFood") -// Add("/api",&RestController{},"get,post:ApiFunc" -// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - p.addWithMethodParams(pattern, c, nil, mappingMethods...) -} - -func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) { - reflectVal := reflect.ValueOf(c) - t := reflect.Indirect(reflectVal).Type() - methods := make(map[string]string) - if len(mappingMethods) > 0 { - semi := strings.Split(mappingMethods[0], ";") - for _, v := range semi { - colon := strings.Split(v, ":") - if len(colon) != 2 { - panic("method mapping format is invalid") - } - comma := strings.Split(colon[0], ",") - for _, m := range comma { - if m == "*" || HTTPMETHOD[strings.ToUpper(m)] { - if val := reflectVal.MethodByName(colon[1]); val.IsValid() { - methods[strings.ToUpper(m)] = colon[1] - } else { - panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) - } - } else { - panic(v + " is an invalid method mapping. Method doesn't exist " + m) - } - } - } - } - - route := &ControllerInfo{} - route.pattern = pattern - route.methods = methods - route.routerType = routerTypeBeego - route.controllerType = t - route.initialize = func() ControllerInterface { - vc := reflect.New(route.controllerType) - execController, ok := vc.Interface().(ControllerInterface) - if !ok { - panic("controller is not ControllerInterface") - } - - elemVal := reflect.ValueOf(c).Elem() - elemType := reflect.TypeOf(c).Elem() - execElem := reflect.ValueOf(execController).Elem() - - numOfFields := elemVal.NumField() - for i := 0; i < numOfFields; i++ { - fieldType := elemType.Field(i) - elemField := execElem.FieldByName(fieldType.Name) - if elemField.CanSet() { - fieldVal := elemVal.Field(i) - elemField.Set(fieldVal) - } - } - - return execController - } - - route.methodParams = methodParams - if len(methods) == 0 { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - for k := range methods { - if k == "*" { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - p.addToRouter(k, pattern, route) - } - } - } -} - -func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) { - if !BConfig.RouterCaseSensitive { - pattern = strings.ToLower(pattern) - } - if t, ok := p.routers[method]; ok { - t.AddRouter(pattern, r) - } else { - t := NewTree() - t.AddRouter(pattern, r) - p.routers[method] = t - } -} - -// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller -// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Include(cList ...ControllerInterface) { - if BConfig.RunMode == DEV { - skip := make(map[string]bool, 10) - wgopath := utils.GetGOPATHs() - go111module := os.Getenv(`GO111MODULE`) - for _, c := range cList { - reflectVal := reflect.ValueOf(c) - t := reflect.Indirect(reflectVal).Type() - // for go modules - if go111module == `on` { - pkgpath := filepath.Join(WorkPath, "..", t.PkgPath()) - if utils.FileExists(pkgpath) { - if pkgpath != "" { - if _, ok := skip[pkgpath]; !ok { - skip[pkgpath] = true - parserPkg(pkgpath, t.PkgPath()) - } - } - } - } else { - if len(wgopath) == 0 { - panic("you are in dev mode. So please set gopath") - } - pkgpath := "" - for _, wg := range wgopath { - wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath())) - if utils.FileExists(wg) { - pkgpath = wg - break - } - } - if pkgpath != "" { - if _, ok := skip[pkgpath]; !ok { - skip[pkgpath] = true - parserPkg(pkgpath, t.PkgPath()) - } - } - } - } - } - for _, c := range cList { - reflectVal := reflect.ValueOf(c) - t := reflect.Indirect(reflectVal).Type() - key := t.PkgPath() + ":" + t.Name() - if comm, ok := GlobalControllerRouter[key]; ok { - for _, a := range comm { - for _, f := range a.Filters { - p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams) - } - - p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) - } - } - } -} - -// GetContext returns a context from pool, so usually you should remember to call Reset function to clean the context -// And don't forget to give back context to pool -// example: -// ctx := p.GetContext() -// ctx.Reset(w, q) -// defer p.GiveBackContext(ctx) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) GetContext() *beecontext.Context { - return p.pool.Get().(*beecontext.Context) -} - -// GiveBackContext put the ctx into pool so that it could be reuse -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { - // clear input cached data - ctx.Input.Clear() - // clear output cached data - ctx.Output.Clear() - p.pool.Put(ctx) -} - -// Get add get method -// usage: -// Get("/", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Get(pattern string, f FilterFunc) { - p.AddMethod("get", pattern, f) -} - -// Post add post method -// usage: -// Post("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Post(pattern string, f FilterFunc) { - p.AddMethod("post", pattern, f) -} - -// Put add put method -// usage: -// Put("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Put(pattern string, f FilterFunc) { - p.AddMethod("put", pattern, f) -} - -// Delete add delete method -// usage: -// Delete("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { - p.AddMethod("delete", pattern, f) -} - -// Head add head method -// usage: -// Head("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Head(pattern string, f FilterFunc) { - p.AddMethod("head", pattern, f) -} - -// Patch add patch method -// usage: -// Patch("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { - p.AddMethod("patch", pattern, f) -} - -// Options add options method -// usage: -// Options("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Options(pattern string, f FilterFunc) { - p.AddMethod("options", pattern, f) -} - -// Any add all method -// usage: -// Any("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Any(pattern string, f FilterFunc) { - p.AddMethod("*", pattern, f) -} - -// AddMethod add http method router -// usage: -// AddMethod("get","/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { - method = strings.ToUpper(method) - if method != "*" && !HTTPMETHOD[method] { - panic("not support http method: " + method) - } - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeRESTFul - route.runFunction = f - methods := make(map[string]string) - if method == "*" { - for val := range HTTPMETHOD { - methods[val] = val - } - } else { - methods[method] = method - } - route.methods = methods - for k := range methods { - if k == "*" { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - p.addToRouter(k, pattern, route) - } - } -} - -// Handler add user defined Handler -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeHandler - route.handler = h - if len(options) > 0 { - if _, ok := options[0].(bool); ok { - pattern = path.Join(pattern, "?:all(.*)") - } - } - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } -} - -// AddAuto router to ControllerRegister. -// example beego.AddAuto(&MainContorlller{}), -// MainController has method List and Page. -// visit the url /main/list to execute List function -// /main/page to execute Page function. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) AddAuto(c ControllerInterface) { - p.AddAutoPrefix("/", c) -} - -// AddAutoPrefix Add auto router to ControllerRegister with prefix. -// example beego.AddAutoPrefix("/admin",&MainContorlller{}), -// MainController has method List and Page. -// visit the url /admin/main/list to execute List function -// /admin/main/page to execute Page function. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) { - reflectVal := reflect.ValueOf(c) - rt := reflectVal.Type() - ct := reflect.Indirect(reflectVal).Type() - controllerName := strings.TrimSuffix(ct.Name(), "Controller") - for i := 0; i < rt.NumMethod(); i++ { - if !utils.InSlice(rt.Method(i).Name, exceptMethod) { - route := &ControllerInfo{} - route.routerType = routerTypeBeego - route.methods = map[string]string{"*": rt.Method(i).Name} - route.controllerType = ct - pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") - patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*") - patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) - patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name) - route.pattern = pattern - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - p.addToRouter(m, patternInit, route) - p.addToRouter(m, patternFix, route) - p.addToRouter(m, patternFixInit, route) - } - } - } -} - -// InsertFilter Add a FilterFunc with pattern rule and action constant. -// params is for: -// 1. setting the returnOnOutput value (false allows multiple filters to execute) -// 2. determining whether or not params need to be reset. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { - mr := &FilterRouter{ - tree: NewTree(), - pattern: pattern, - filterFunc: filter, - returnOnOutput: true, - } - if !BConfig.RouterCaseSensitive { - mr.pattern = strings.ToLower(pattern) - } - - paramsLen := len(params) - if paramsLen > 0 { - mr.returnOnOutput = params[0] - } - if paramsLen > 1 { - mr.resetParams = params[1] - } - mr.tree.AddRouter(pattern, true) - return p.insertFilterRouter(pos, mr) -} - -// add Filter into -func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { - if pos < BeforeStatic || pos > FinishRouter { - return errors.New("can not find your filter position") - } - p.enableFilter = true - p.filters[pos] = append(p.filters[pos], mr) - return nil -} - -// URLFor does another controller handler in this request function. -// it can access any controller method. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { - paths := strings.Split(endpoint, ".") - if len(paths) <= 1 { - logs.Warn("urlfor endpoint must like path.controller.method") - return "" - } - if len(values)%2 != 0 { - logs.Warn("urlfor params must key-value pair") - return "" - } - params := make(map[string]string) - if len(values) > 0 { - key := "" - for k, v := range values { - if k%2 == 0 { - key = fmt.Sprint(v) - } else { - params[key] = fmt.Sprint(v) - } - } - } - controllerName := strings.Join(paths[:len(paths)-1], "/") - methodName := paths[len(paths)-1] - for m, t := range p.routers { - ok, url := p.getURL(t, "/", controllerName, methodName, params, m) - if ok { - return url - } - } - return "" -} - -func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName string, params map[string]string, httpMethod string) (bool, string) { - for _, subtree := range t.fixrouters { - u := path.Join(url, subtree.prefix) - ok, u := p.getURL(subtree, u, controllerName, methodName, params, httpMethod) - if ok { - return ok, u - } - } - if t.wildcard != nil { - u := path.Join(url, urlPlaceholder) - ok, u := p.getURL(t.wildcard, u, controllerName, methodName, params, httpMethod) - if ok { - return ok, u - } - } - for _, l := range t.leaves { - if c, ok := l.runObject.(*ControllerInfo); ok { - if c.routerType == routerTypeBeego && - strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) { - find := false - if HTTPMETHOD[strings.ToUpper(methodName)] { - if len(c.methods) == 0 { - find = true - } else if m, ok := c.methods[strings.ToUpper(methodName)]; ok && m == strings.ToUpper(methodName) { - find = true - } else if m, ok = c.methods["*"]; ok && m == methodName { - find = true - } - } - if !find { - for m, md := range c.methods { - if (m == "*" || m == httpMethod) && md == methodName { - find = true - } - } - } - if find { - if l.regexps == nil { - if len(l.wildcards) == 0 { - return true, strings.Replace(url, "/"+urlPlaceholder, "", 1) + toURL(params) - } - if len(l.wildcards) == 1 { - if v, ok := params[l.wildcards[0]]; ok { - delete(params, l.wildcards[0]) - return true, strings.Replace(url, urlPlaceholder, v, 1) + toURL(params) - } - return false, "" - } - if len(l.wildcards) == 3 && l.wildcards[0] == "." { - if p, ok := params[":path"]; ok { - if e, isok := params[":ext"]; isok { - delete(params, ":path") - delete(params, ":ext") - return true, strings.Replace(url, urlPlaceholder, p+"."+e, -1) + toURL(params) - } - } - } - canSkip := false - for _, v := range l.wildcards { - if v == ":" { - canSkip = true - continue - } - if u, ok := params[v]; ok { - delete(params, v) - url = strings.Replace(url, urlPlaceholder, u, 1) - } else { - if canSkip { - canSkip = false - continue - } - return false, "" - } - } - return true, url + toURL(params) - } - var i int - var startReg bool - regURL := "" - for _, v := range strings.Trim(l.regexps.String(), "^$") { - if v == '(' { - startReg = true - continue - } else if v == ')' { - startReg = false - if v, ok := params[l.wildcards[i]]; ok { - delete(params, l.wildcards[i]) - regURL = regURL + v - i++ - } else { - break - } - } else if !startReg { - regURL = string(append([]rune(regURL), v)) - } - } - if l.regexps.MatchString(regURL) { - ps := strings.Split(regURL, "/") - for _, p := range ps { - url = strings.Replace(url, urlPlaceholder, p, 1) - } - return true, url + toURL(params) - } - } - } - } - } - - return false, "" -} - -func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) { - var preFilterParams map[string]string - for _, filterR := range p.filters[pos] { - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true - } - if filterR.resetParams { - preFilterParams = context.Input.Params() - } - if ok := filterR.ValidRouter(urlPath, context); ok { - filterR.filterFunc(context) - if filterR.resetParams { - context.Input.ResetParams() - for k, v := range preFilterParams { - context.Input.SetParam(k, v) - } - } - } - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true - } - } - return false -} - -// Implement http.Handler interface. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - startTime := time.Now() - var ( - runRouter reflect.Type - findRouter bool - runMethod string - methodParams []*param.MethodParam - routerInfo *ControllerInfo - isRunnable bool - ) - context := p.GetContext() - - context.Reset(rw, r) - - defer p.GiveBackContext(context) - if BConfig.RecoverFunc != nil { - defer BConfig.RecoverFunc(context) - } - - context.Output.EnableGzip = BConfig.EnableGzip - - if BConfig.RunMode == DEV { - context.Output.Header("Server", BConfig.ServerName) - } - - var urlPath = r.URL.Path - - if !BConfig.RouterCaseSensitive { - urlPath = strings.ToLower(urlPath) - } - - // filter wrong http method - if !HTTPMETHOD[r.Method] { - exception("405", context) - goto Admin - } - - // filter for static file - if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) { - goto Admin - } - - serverStaticRouter(context) - - if context.ResponseWriter.Started { - findRouter = true - goto Admin - } - - if r.Method != http.MethodGet && r.Method != http.MethodHead { - if BConfig.CopyRequestBody && !context.Input.IsUpload() { - // connection will close if the incoming data are larger (RFC 7231, 6.5.11) - if r.ContentLength > BConfig.MaxMemory { - logs.Error(errors.New("payload too large")) - exception("413", context) - goto Admin - } - context.Input.CopyBody(BConfig.MaxMemory) - } - context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) - } - - // session init - if BConfig.WebConfig.Session.SessionOn { - var err error - context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) - if err != nil { - logs.Error(err) - exception("503", context) - goto Admin - } - defer func() { - if context.Input.CruSession != nil { - context.Input.CruSession.SessionRelease(rw) - } - }() - } - if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) { - goto Admin - } - // User can define RunController and RunMethod in filter - if context.Input.RunController != nil && context.Input.RunMethod != "" { - findRouter = true - runMethod = context.Input.RunMethod - runRouter = context.Input.RunController - } else { - routerInfo, findRouter = p.FindRouter(context) - } - - // if no matches to url, throw a not found exception - if !findRouter { - exception("404", context) - goto Admin - } - if splat := context.Input.Param(":splat"); splat != "" { - for k, v := range strings.Split(splat, "/") { - context.Input.SetParam(strconv.Itoa(k), v) - } - } - - if routerInfo != nil { - // store router pattern into context - context.Input.SetData("RouterPattern", routerInfo.pattern) - } - - // execute middleware filters - if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) { - goto Admin - } - - // check policies - if p.execPolicy(context, urlPath) { - goto Admin - } - - if routerInfo != nil { - if routerInfo.routerType == routerTypeRESTFul { - if _, ok := routerInfo.methods[r.Method]; ok { - isRunnable = true - routerInfo.runFunction(context) - } else { - exception("405", context) - goto Admin - } - } else if routerInfo.routerType == routerTypeHandler { - isRunnable = true - routerInfo.handler.ServeHTTP(context.ResponseWriter, context.Request) - } else { - runRouter = routerInfo.controllerType - methodParams = routerInfo.methodParams - method := r.Method - if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut { - method = http.MethodPut - } - if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete { - method = http.MethodDelete - } - if m, ok := routerInfo.methods[method]; ok { - runMethod = m - } else if m, ok = routerInfo.methods["*"]; ok { - runMethod = m - } else { - runMethod = method - } - } - } - - // also defined runRouter & runMethod from filter - if !isRunnable { - // Invoke the request handler - var execController ControllerInterface - if routerInfo != nil && routerInfo.initialize != nil { - execController = routerInfo.initialize() - } else { - vc := reflect.New(runRouter) - var ok bool - execController, ok = vc.Interface().(ControllerInterface) - if !ok { - panic("controller is not ControllerInterface") - } - } - - // call the controller init function - execController.Init(context, runRouter.Name(), runMethod, execController) - - // call prepare function - execController.Prepare() - - // if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf - if BConfig.WebConfig.EnableXSRF { - execController.XSRFToken() - if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut || - (r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) { - execController.CheckXSRFCookie() - } - } - - execController.URLMapping() - - if !context.ResponseWriter.Started { - // exec main logic - switch runMethod { - case http.MethodGet: - execController.Get() - case http.MethodPost: - execController.Post() - case http.MethodDelete: - execController.Delete() - case http.MethodPut: - execController.Put() - case http.MethodHead: - execController.Head() - case http.MethodPatch: - execController.Patch() - case http.MethodOptions: - execController.Options() - case http.MethodTrace: - execController.Trace() - default: - if !execController.HandlerFunc(runMethod) { - vc := reflect.ValueOf(execController) - method := vc.MethodByName(runMethod) - in := param.ConvertParams(methodParams, method.Type(), context) - out := method.Call(in) - - // For backward compatibility we only handle response if we had incoming methodParams - if methodParams != nil { - p.handleParamResponse(context, execController, out) - } - } - } - - // render template - if !context.ResponseWriter.Started && context.Output.Status == 0 { - if BConfig.WebConfig.AutoRender { - if err := execController.Render(); err != nil { - logs.Error(err) - } - } - } - } - - // finish all runRouter. release resource - execController.Finish() - } - - // execute middleware filters - if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) { - goto Admin - } - - if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) { - goto Admin - } - -Admin: - // admin module record QPS - - statusCode := context.ResponseWriter.Status - if statusCode == 0 { - statusCode = 200 - } - - LogAccess(context, &startTime, statusCode) - - timeDur := time.Since(startTime) - context.ResponseWriter.Elapsed = timeDur - if BConfig.Listen.EnableAdmin { - pattern := "" - if routerInfo != nil { - pattern = routerInfo.pattern - } - - if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) { - routerName := "" - if runRouter != nil { - routerName = runRouter.Name() - } - go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, routerName, timeDur) - } - } - - if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs { - match := map[bool]string{true: "match", false: "nomatch"} - devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", - context.Input.IP(), - logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(), - timeDur.String(), - match[findRouter], - logs.ColorByMethod(r.Method), r.Method, logs.ResetColor(), - r.URL.Path) - if routerInfo != nil { - devInfo += fmt.Sprintf(" r:%s", routerInfo.pattern) - } - - logs.Debug(devInfo) - } - // Call WriteHeader if status code has been set changed - if context.Output.Status != 0 { - context.ResponseWriter.WriteHeader(context.Output.Status) - } -} - -func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) { - // looping in reverse order for the case when both error and value are returned and error sets the response status code - for i := len(results) - 1; i >= 0; i-- { - result := results[i] - if result.Kind() != reflect.Interface || !result.IsNil() { - resultValue := result.Interface() - context.RenderMethodResult(resultValue) - } - } - if !context.ResponseWriter.Started && len(results) > 0 && context.Output.Status == 0 { - context.Output.SetStatus(200) - } -} - -// FindRouter Find Router info for URL -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { - var urlPath = context.Input.URL() - if !BConfig.RouterCaseSensitive { - urlPath = strings.ToLower(urlPath) - } - httpMethod := context.Input.Method() - if t, ok := p.routers[httpMethod]; ok { - runObject := t.Match(urlPath, context) - if r, ok := runObject.(*ControllerInfo); ok { - return r, true - } - } - return -} - -func toURL(params map[string]string) string { - if len(params) == 0 { - return "" - } - u := "?" - for k, v := range params { - u += k + "=" + v + "&" - } - return strings.TrimRight(u, "&") -} - -// LogAccess logging info HTTP Access -// Deprecated: using pkg/, we will delete this in v2.1.0 -func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { - // Skip logging if AccessLogs config is false - if !BConfig.Log.AccessLogs { - return - } - // Skip logging static requests unless EnableStaticLogs config is true - if !BConfig.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) { - return - } - var ( - requestTime time.Time - elapsedTime time.Duration - r = ctx.Request - ) - if startTime != nil { - requestTime = *startTime - elapsedTime = time.Since(*startTime) - } - record := &logs.AccessLogRecord{ - RemoteAddr: ctx.Input.IP(), - RequestTime: requestTime, - RequestMethod: r.Method, - Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto), - ServerProtocol: r.Proto, - Host: r.Host, - Status: statusCode, - ElapsedTime: elapsedTime, - HTTPReferrer: r.Header.Get("Referer"), - HTTPUserAgent: r.Header.Get("User-Agent"), - RemoteUser: r.Header.Get("Remote-User"), - BodyBytesSent: r.ContentLength, - } - logs.AccessLog(record, BConfig.Log.AccessLogsFormat) -} diff --git a/router_test.go b/router_test.go deleted file mode 100644 index 8ec7927a..00000000 --- a/router_test.go +++ /dev/null @@ -1,732 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "bytes" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" -) - -type TestController struct { - Controller -} - -func (tc *TestController) Get() { - tc.Data["Username"] = "astaxie" - tc.Ctx.Output.Body([]byte("ok")) -} - -func (tc *TestController) Post() { - tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Query(":name"))) -} - -func (tc *TestController) Param() { - tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Query(":name"))) -} - -func (tc *TestController) List() { - tc.Ctx.Output.Body([]byte("i am list")) -} - -func (tc *TestController) Params() { - tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param("0") + tc.Ctx.Input.Param("1") + tc.Ctx.Input.Param("2"))) -} - -func (tc *TestController) Myext() { - tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param(":ext"))) -} - -func (tc *TestController) GetURL() { - tc.Ctx.Output.Body([]byte(tc.URLFor(".Myext"))) -} - -func (tc *TestController) GetParams() { - tc.Ctx.WriteString(tc.Ctx.Input.Query(":last") + "+" + - tc.Ctx.Input.Query(":first") + "+" + tc.Ctx.Input.Query("learn")) -} - -func (tc *TestController) GetManyRouter() { - tc.Ctx.WriteString(tc.Ctx.Input.Query(":id") + tc.Ctx.Input.Query(":page")) -} - -func (tc *TestController) GetEmptyBody() { - var res []byte - tc.Ctx.Output.Body(res) -} - -type JSONController struct { - Controller -} - -func (jc *JSONController) Prepare() { - jc.Data["json"] = "prepare" - jc.ServeJSON(true) -} - -func (jc *JSONController) Get() { - jc.Data["Username"] = "astaxie" - jc.Ctx.Output.Body([]byte("ok")) -} - -func TestUrlFor(t *testing.T) { - handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") - handler.Add("/person/:last/:first", &TestController{}, "*:Param") - if a := handler.URLFor("TestController.List"); a != "/api/list" { - logs.Info(a) - t.Errorf("TestController.List must equal to /api/list") - } - if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" { - t.Errorf("TestController.Param must equal to /person/xie/asta, but get " + a) - } -} - -func TestUrlFor3(t *testing.T) { - handler := NewControllerRegister() - handler.AddAuto(&TestController{}) - if a := handler.URLFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" { - t.Errorf("TestController.Myext must equal to /test/myext, but get " + a) - } - if a := handler.URLFor("TestController.GetURL"); a != "/test/geturl" && a != "/Test/GetURL" { - t.Errorf("TestController.GetURL must equal to /test/geturl, but get " + a) - } -} - -func TestUrlFor2(t *testing.T) { - handler := NewControllerRegister() - handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List") - handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL") - handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") - handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) - if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { - logs.Info(handler.URLFor("TestController.GetURL")) - t.Errorf("TestController.List must equal to /v1/astaxie/edit") - } - - if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") != - "/v1/za/cms_12_123.html" { - logs.Info(handler.URLFor("TestController.List")) - t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html") - } - if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") != - "/v1/za_cms/ttt_12_123.html" { - logs.Info(handler.URLFor("TestController.Param")) - t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html") - } - if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11", - ":title", "aaaa", ":entid", "aaaa") != - "/1111/11/aaaa/aaaa" { - logs.Info(handler.URLFor("TestController.Get")) - t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa") - } -} - -func TestUserFunc(t *testing.T) { - r, _ := http.NewRequest("GET", "/api/list", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") - handler.ServeHTTP(w, r) - if w.Body.String() != "i am list" { - t.Errorf("user define func can't run") - } -} - -func TestPostFunc(t *testing.T) { - r, _ := http.NewRequest("POST", "/astaxie", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/:name", &TestController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != "astaxie" { - t.Errorf("post func should astaxie") - } -} - -func TestAutoFunc(t *testing.T) { - r, _ := http.NewRequest("GET", "/test/list", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.AddAuto(&TestController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != "i am list" { - t.Errorf("user define func can't run") - } -} - -func TestAutoFunc2(t *testing.T) { - r, _ := http.NewRequest("GET", "/Test/List", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.AddAuto(&TestController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != "i am list" { - t.Errorf("user define func can't run") - } -} - -func TestAutoFuncParams(t *testing.T) { - r, _ := http.NewRequest("GET", "/test/params/2009/11/12", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.AddAuto(&TestController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != "20091112" { - t.Errorf("user define func can't run") - } -} - -func TestAutoExtFunc(t *testing.T) { - r, _ := http.NewRequest("GET", "/test/myext.json", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.AddAuto(&TestController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != "json" { - t.Errorf("user define func can't run") - } -} - -func TestRouteOk(t *testing.T) { - - r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/person/:last/:first", &TestController{}, "get:GetParams") - handler.ServeHTTP(w, r) - body := w.Body.String() - if body != "anderson+thomas+kungfu" { - t.Errorf("url param set to [%s];", body) - } -} - -func TestManyRoute(t *testing.T) { - - r, _ := http.NewRequest("GET", "/beego32-12.html", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter") - handler.ServeHTTP(w, r) - - body := w.Body.String() - - if body != "3212" { - t.Errorf("url param set to [%s];", body) - } -} - -// Test for issue #1669 -func TestEmptyResponse(t *testing.T) { - - r, _ := http.NewRequest("GET", "/beego-empty.html", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody") - handler.ServeHTTP(w, r) - - if body := w.Body.String(); body != "" { - t.Error("want empty body") - } -} - -func TestNotFound(t *testing.T) { - r, _ := http.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.ServeHTTP(w, r) - - if w.Code != http.StatusNotFound { - t.Errorf("Code set to [%v]; want [%v]", w.Code, http.StatusNotFound) - } -} - -// TestStatic tests the ability to serve static -// content from the filesystem -func TestStatic(t *testing.T) { - r, _ := http.NewRequest("GET", "/static/js/jquery.js", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.ServeHTTP(w, r) - - if w.Code != 404 { - t.Errorf("handler.Static failed to serve file") - } -} - -func TestPrepare(t *testing.T) { - r, _ := http.NewRequest("GET", "/json/list", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/json/list", &JSONController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != `"prepare"` { - t.Errorf(w.Body.String() + "user define func can't run") - } -} - -func TestAutoPrefix(t *testing.T) { - r, _ := http.NewRequest("GET", "/admin/test/list", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.AddAutoPrefix("/admin", &TestController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != "i am list" { - t.Errorf("TestAutoPrefix can't run") - } -} - -func TestRouterGet(t *testing.T) { - r, _ := http.NewRequest("GET", "/user", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Get("/user", func(ctx *context.Context) { - ctx.Output.Body([]byte("Get userlist")) - }) - handler.ServeHTTP(w, r) - if w.Body.String() != "Get userlist" { - t.Errorf("TestRouterGet can't run") - } -} - -func TestRouterPost(t *testing.T) { - r, _ := http.NewRequest("POST", "/user/123", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Post("/user/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }) - handler.ServeHTTP(w, r) - if w.Body.String() != "123" { - t.Errorf("TestRouterPost can't run") - } -} - -func sayhello(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("sayhello")) -} - -func TestRouterHandler(t *testing.T) { - r, _ := http.NewRequest("POST", "/sayhi", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Handler("/sayhi", http.HandlerFunc(sayhello)) - handler.ServeHTTP(w, r) - if w.Body.String() != "sayhello" { - t.Errorf("TestRouterHandler can't run") - } -} - -func TestRouterHandlerAll(t *testing.T) { - r, _ := http.NewRequest("POST", "/sayhi/a/b/c", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Handler("/sayhi", http.HandlerFunc(sayhello), true) - handler.ServeHTTP(w, r) - if w.Body.String() != "sayhello" { - t.Errorf("TestRouterHandler can't run") - } -} - -// -// Benchmarks NewApp: -// - -func beegoFilterFunc(ctx *context.Context) { - ctx.WriteString("hello") -} - -type AdminController struct { - Controller -} - -func (a *AdminController) Get() { - a.Ctx.WriteString("hello") -} - -func TestRouterFunc(t *testing.T) { - mux := NewControllerRegister() - mux.Get("/action", beegoFilterFunc) - mux.Post("/action", beegoFilterFunc) - rw, r := testRequest("GET", "/action") - mux.ServeHTTP(rw, r) - if rw.Body.String() != "hello" { - t.Errorf("TestRouterFunc can't run") - } -} - -func BenchmarkFunc(b *testing.B) { - mux := NewControllerRegister() - mux.Get("/action", beegoFilterFunc) - rw, r := testRequest("GET", "/action") - b.ResetTimer() - for i := 0; i < b.N; i++ { - mux.ServeHTTP(rw, r) - } -} - -func BenchmarkController(b *testing.B) { - mux := NewControllerRegister() - mux.Add("/action", &AdminController{}) - rw, r := testRequest("GET", "/action") - b.ResetTimer() - for i := 0; i < b.N; i++ { - mux.ServeHTTP(rw, r) - } -} - -func testRequest(method, path string) (*httptest.ResponseRecorder, *http.Request) { - request, _ := http.NewRequest(method, path, nil) - recorder := httptest.NewRecorder() - - return recorder, request -} - -// Expectation: A Filter with the correct configuration should be created given -// specific parameters. -func TestInsertFilter(t *testing.T) { - testName := "TestInsertFilter" - - mux := NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}) - if !mux.filters[BeforeRouter][0].returnOnOutput { - t.Errorf( - "%s: passing no variadic params should set returnOnOutput to true", - testName) - } - if mux.filters[BeforeRouter][0].resetParams { - t.Errorf( - "%s: passing no variadic params should set resetParams to false", - testName) - } - - mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, false) - if mux.filters[BeforeRouter][0].returnOnOutput { - t.Errorf( - "%s: passing false as 1st variadic param should set returnOnOutput to false", - testName) - } - - mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, true, true) - if !mux.filters[BeforeRouter][0].resetParams { - t.Errorf( - "%s: passing true as 2nd variadic param should set resetParams to true", - testName) - } -} - -// Expectation: the second variadic arg should cause the execution of the filter -// to preserve the parameters from before its execution. -func TestParamResetFilter(t *testing.T) { - testName := "TestParamResetFilter" - route := "/beego/*" // splat - path := "/beego/routes/routes" - - mux := NewControllerRegister() - - mux.InsertFilter("*", BeforeExec, beegoResetParams, true, true) - - mux.Get(route, beegoHandleResetParams) - - rw, r := testRequest("GET", path) - mux.ServeHTTP(rw, r) - - // The two functions, `beegoResetParams` and `beegoHandleResetParams` add - // a response header of `Splat`. The expectation here is that that Header - // value should match what the _request's_ router set, not the filter's. - - headers := rw.Result().Header - if len(headers["Splat"]) != 1 { - t.Errorf( - "%s: There was an error in the test. Splat param not set in Header", - testName) - } - if headers["Splat"][0] != "routes/routes" { - t.Errorf( - "%s: expected `:splat` param to be [routes/routes] but it was [%s]", - testName, headers["Splat"][0]) - } -} - -// Execution point: BeforeRouter -// expectation: only BeforeRouter function is executed, notmatch output as router doesn't handle -func TestFilterBeforeRouter(t *testing.T) { - testName := "TestFilterBeforeRouter" - url := "/beforeRouter" - - mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoBeforeRouter1) - - mux.Get(url, beegoFilterFunc) - - rw, r := testRequest("GET", url) - mux.ServeHTTP(rw, r) - - if !strings.Contains(rw.Body.String(), "BeforeRouter1") { - t.Errorf(testName + " BeforeRouter did not run") - } - if strings.Contains(rw.Body.String(), "hello") { - t.Errorf(testName + " BeforeRouter did not return properly") - } -} - -// Execution point: BeforeExec -// expectation: only BeforeExec function is executed, match as router determines route only -func TestFilterBeforeExec(t *testing.T) { - testName := "TestFilterBeforeExec" - url := "/beforeExec" - - mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoBeforeExec1) - - mux.Get(url, beegoFilterFunc) - - rw, r := testRequest("GET", url) - mux.ServeHTTP(rw, r) - - if !strings.Contains(rw.Body.String(), "BeforeExec1") { - t.Errorf(testName + " BeforeExec did not run") - } - if strings.Contains(rw.Body.String(), "hello") { - t.Errorf(testName + " BeforeExec did not return properly") - } - if strings.Contains(rw.Body.String(), "BeforeRouter") { - t.Errorf(testName + " BeforeRouter ran in error") - } -} - -// Execution point: AfterExec -// expectation: only AfterExec function is executed, match as router handles -func TestFilterAfterExec(t *testing.T) { - testName := "TestFilterAfterExec" - url := "/afterExec" - - mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoAfterExec1, false) - - mux.Get(url, beegoFilterFunc) - - rw, r := testRequest("GET", url) - mux.ServeHTTP(rw, r) - - if !strings.Contains(rw.Body.String(), "AfterExec1") { - t.Errorf(testName + " AfterExec did not run") - } - if !strings.Contains(rw.Body.String(), "hello") { - t.Errorf(testName + " handler did not run properly") - } - if strings.Contains(rw.Body.String(), "BeforeRouter") { - t.Errorf(testName + " BeforeRouter ran in error") - } - if strings.Contains(rw.Body.String(), "BeforeExec") { - t.Errorf(testName + " BeforeExec ran in error") - } -} - -// Execution point: FinishRouter -// expectation: only FinishRouter function is executed, match as router handles -func TestFilterFinishRouter(t *testing.T) { - testName := "TestFilterFinishRouter" - url := "/finishRouter" - - mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoFilterNoOutput) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1) - - mux.Get(url, beegoFilterFunc) - - rw, r := testRequest("GET", url) - mux.ServeHTTP(rw, r) - - if strings.Contains(rw.Body.String(), "FinishRouter1") { - t.Errorf(testName + " FinishRouter did not run") - } - if !strings.Contains(rw.Body.String(), "hello") { - t.Errorf(testName + " handler did not run properly") - } - if strings.Contains(rw.Body.String(), "AfterExec1") { - t.Errorf(testName + " AfterExec ran in error") - } - if strings.Contains(rw.Body.String(), "BeforeRouter") { - t.Errorf(testName + " BeforeRouter ran in error") - } - if strings.Contains(rw.Body.String(), "BeforeExec") { - t.Errorf(testName + " BeforeExec ran in error") - } -} - -// Execution point: FinishRouter -// expectation: only first FinishRouter function is executed, match as router handles -func TestFilterFinishRouterMultiFirstOnly(t *testing.T) { - testName := "TestFilterFinishRouterMultiFirstOnly" - url := "/finishRouterMultiFirstOnly" - - mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter2) - - mux.Get(url, beegoFilterFunc) - - rw, r := testRequest("GET", url) - mux.ServeHTTP(rw, r) - - if !strings.Contains(rw.Body.String(), "FinishRouter1") { - t.Errorf(testName + " FinishRouter1 did not run") - } - if !strings.Contains(rw.Body.String(), "hello") { - t.Errorf(testName + " handler did not run properly") - } - // not expected in body - if strings.Contains(rw.Body.String(), "FinishRouter2") { - t.Errorf(testName + " FinishRouter2 did run") - } -} - -// Execution point: FinishRouter -// expectation: both FinishRouter functions execute, match as router handles -func TestFilterFinishRouterMulti(t *testing.T) { - testName := "TestFilterFinishRouterMulti" - url := "/finishRouterMulti" - - mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, false) - - mux.Get(url, beegoFilterFunc) - - rw, r := testRequest("GET", url) - mux.ServeHTTP(rw, r) - - if !strings.Contains(rw.Body.String(), "FinishRouter1") { - t.Errorf(testName + " FinishRouter1 did not run") - } - if !strings.Contains(rw.Body.String(), "hello") { - t.Errorf(testName + " handler did not run properly") - } - if !strings.Contains(rw.Body.String(), "FinishRouter2") { - t.Errorf(testName + " FinishRouter2 did not run properly") - } -} - -func beegoFilterNoOutput(ctx *context.Context) { -} - -func beegoBeforeRouter1(ctx *context.Context) { - ctx.WriteString("|BeforeRouter1") -} - -func beegoBeforeExec1(ctx *context.Context) { - ctx.WriteString("|BeforeExec1") -} - -func beegoAfterExec1(ctx *context.Context) { - ctx.WriteString("|AfterExec1") -} - -func beegoFinishRouter1(ctx *context.Context) { - ctx.WriteString("|FinishRouter1") -} - -func beegoFinishRouter2(ctx *context.Context) { - ctx.WriteString("|FinishRouter2") -} - -func beegoResetParams(ctx *context.Context) { - ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat")) -} - -func beegoHandleResetParams(ctx *context.Context) { - ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat")) -} - -// YAML -type YAMLController struct { - Controller -} - -func (jc *YAMLController) Prepare() { - jc.Data["yaml"] = "prepare" - jc.ServeYAML() -} - -func (jc *YAMLController) Get() { - jc.Data["Username"] = "astaxie" - jc.Ctx.Output.Body([]byte("ok")) -} - -func TestYAMLPrepare(t *testing.T) { - r, _ := http.NewRequest("GET", "/yaml/list", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/yaml/list", &YAMLController{}) - handler.ServeHTTP(w, r) - if strings.TrimSpace(w.Body.String()) != "prepare" { - t.Errorf(w.Body.String()) - } -} - -func TestRouterEntityTooLargeCopyBody(t *testing.T) { - _MaxMemory := BConfig.MaxMemory - _CopyRequestBody := BConfig.CopyRequestBody - BConfig.CopyRequestBody = true - BConfig.MaxMemory = 20 - - b := bytes.NewBuffer([]byte("barbarbarbarbarbarbarbarbarbar")) - r, _ := http.NewRequest("POST", "/user/123", b) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Post("/user/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }) - handler.ServeHTTP(w, r) - - BConfig.CopyRequestBody = _CopyRequestBody - BConfig.MaxMemory = _MaxMemory - - if w.Code != http.StatusRequestEntityTooLarge { - t.Errorf("TestRouterRequestEntityTooLarge can't run") - } -} diff --git a/session/README.md b/session/README.md deleted file mode 100644 index 6d0a297e..00000000 --- a/session/README.md +++ /dev/null @@ -1,114 +0,0 @@ -session -============== - -session is a Go session manager. It can use many session providers. Just like the `database/sql` and `database/sql/driver`. - -## How to install? - - go get github.com/astaxie/beego/session - - -## What providers are supported? - -As of now this session manager support memory, file, Redis and MySQL. - - -## How to use it? - -First you must import it - - import ( - "github.com/astaxie/beego/session" - ) - -Then in you web app init the global session manager - - var globalSessions *session.Manager - -* Use **memory** as provider: - - func init() { - globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`) - go globalSessions.GC() - } - -* Use **file** as provider, the last param is the path where you want file to be stored: - - func init() { - globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"./tmp"}`) - go globalSessions.GC() - } - -* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password: - - func init() { - globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:6379,100,astaxie"}`) - go globalSessions.GC() - } - -* Use **MySQL** as provider, the last param is the DSN, learn more from [mysql](https://github.com/go-sql-driver/mysql#dsn-data-source-name): - - func init() { - globalSessions, _ = session.NewManager( - "mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"username:password@protocol(address)/dbname?param=value"}`) - go globalSessions.GC() - } - -* Use **Cookie** as provider: - - func init() { - globalSessions, _ = session.NewManager( - "cookie", `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`) - go globalSessions.GC() - } - - -Finally in the handlerfunc you can use it like this - - func login(w http.ResponseWriter, r *http.Request) { - sess := globalSessions.SessionStart(w, r) - defer sess.SessionRelease(w) - username := sess.Get("username") - fmt.Println(username) - if r.Method == "GET" { - t, _ := template.ParseFiles("login.gtpl") - t.Execute(w, nil) - } else { - fmt.Println("username:", r.Form["username"]) - sess.Set("username", r.Form["username"]) - fmt.Println("password:", r.Form["password"]) - } - } - - -## How to write own provider? - -When you develop a web app, maybe you want to write own provider because you must meet the requirements. - -Writing a provider is easy. You only need to define two struct types -(Session and Provider), which satisfy the interface definition. -Maybe you will find the **memory** provider is a good example. - - type SessionStore interface { - Set(key, value interface{}) error //set session value - Get(key interface{}) interface{} //get session value - Delete(key interface{}) error //delete session value - SessionID() string //back current sessionID - SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data - Flush() error //delete all data - } - - type Provider interface { - SessionInit(gclifetime int64, config string) error - SessionRead(sid string) (SessionStore, error) - SessionExist(sid string) bool - SessionRegenerate(oldsid, sid string) (SessionStore, error) - SessionDestroy(sid string) error - SessionAll() int //get all active session - SessionGC() - } - - -## LICENSE - -BSD License http://creativecommons.org/licenses/BSD/ diff --git a/session/couchbase/sess_couchbase.go b/session/couchbase/sess_couchbase.go deleted file mode 100644 index 707d042c..00000000 --- a/session/couchbase/sess_couchbase.go +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package couchbase for session provider -// -// depend on github.com/couchbaselabs/go-couchbasee -// -// go install github.com/couchbaselabs/go-couchbase -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/couchbase" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("couchbase", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"http://host:port/, Pool, Bucket"}``) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package couchbase - -import ( - "net/http" - "strings" - "sync" - - couchbase "github.com/couchbase/go-couchbase" - - "github.com/astaxie/beego/session" -) - -var couchbpder = &Provider{} - -// SessionStore store each session -type SessionStore struct { - b *couchbase.Bucket - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Provider couchabse provided -type Provider struct { - maxlifetime int64 - savePath string - pool string - bucket string - b *couchbase.Bucket -} - -// Set value to couchabse session -func (cs *SessionStore) Set(key, value interface{}) error { - cs.lock.Lock() - defer cs.lock.Unlock() - cs.values[key] = value - return nil -} - -// Get value from couchabse session -func (cs *SessionStore) Get(key interface{}) interface{} { - cs.lock.RLock() - defer cs.lock.RUnlock() - if v, ok := cs.values[key]; ok { - return v - } - return nil -} - -// Delete value in couchbase session by given key -func (cs *SessionStore) Delete(key interface{}) error { - cs.lock.Lock() - defer cs.lock.Unlock() - delete(cs.values, key) - return nil -} - -// Flush Clean all values in couchbase session -func (cs *SessionStore) Flush() error { - cs.lock.Lock() - defer cs.lock.Unlock() - cs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID Get couchbase session store id -func (cs *SessionStore) SessionID() string { - return cs.sid -} - -// SessionRelease Write couchbase session with Gob string -func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { - defer cs.b.Close() - - bo, err := session.EncodeGob(cs.values) - if err != nil { - return - } - - cs.b.Set(cs.sid, int(cs.maxlifetime), bo) -} - -func (cp *Provider) getBucket() *couchbase.Bucket { - c, err := couchbase.Connect(cp.savePath) - if err != nil { - return nil - } - - pool, err := c.GetPool(cp.pool) - if err != nil { - return nil - } - - bucket, err := pool.GetBucket(cp.bucket) - if err != nil { - return nil - } - - return bucket -} - -// SessionInit init couchbase session -// savepath like couchbase server REST/JSON URL -// e.g. http://host:port/, Pool, Bucket -func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { - cp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) > 0 { - cp.savePath = configs[0] - } - if len(configs) > 1 { - cp.pool = configs[1] - } - if len(configs) > 2 { - cp.bucket = configs[2] - } - - return nil -} - -// SessionRead read couchbase session by sid -func (cp *Provider) SessionRead(sid string) (session.Store, error) { - cp.b = cp.getBucket() - - var ( - kv map[interface{}]interface{} - err error - doc []byte - ) - - err = cp.b.Get(sid, &doc) - if err != nil { - return nil, err - } else if doc == nil { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(doc) - if err != nil { - return nil, err - } - } - - cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} - return cs, nil -} - -// SessionExist Check couchbase session exist. -// it checkes sid exist or not. -func (cp *Provider) SessionExist(sid string) bool { - cp.b = cp.getBucket() - defer cp.b.Close() - - var doc []byte - - if err := cp.b.Get(sid, &doc); err != nil || doc == nil { - return false - } - return true -} - -// SessionRegenerate remove oldsid and use sid to generate new session -func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - cp.b = cp.getBucket() - - var doc []byte - if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil { - cp.b.Set(sid, int(cp.maxlifetime), "") - } else { - err := cp.b.Delete(oldsid) - if err != nil { - return nil, err - } - _, _ = cp.b.Add(sid, int(cp.maxlifetime), doc) - } - - err := cp.b.Get(sid, &doc) - if err != nil { - return nil, err - } - var kv map[interface{}]interface{} - if doc == nil { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(doc) - if err != nil { - return nil, err - } - } - - cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} - return cs, nil -} - -// SessionDestroy Remove bucket in this couchbase -func (cp *Provider) SessionDestroy(sid string) error { - cp.b = cp.getBucket() - defer cp.b.Close() - - cp.b.Delete(sid) - return nil -} - -// SessionGC Recycle -func (cp *Provider) SessionGC() { -} - -// SessionAll return all active session -func (cp *Provider) SessionAll() int { - return 0 -} - -func init() { - session.Register("couchbase", couchbpder) -} diff --git a/session/ledis/ledis_session.go b/session/ledis/ledis_session.go deleted file mode 100644 index ee81df67..00000000 --- a/session/ledis/ledis_session.go +++ /dev/null @@ -1,173 +0,0 @@ -// Package ledis provide session Provider -package ledis - -import ( - "net/http" - "strconv" - "strings" - "sync" - - "github.com/ledisdb/ledisdb/config" - "github.com/ledisdb/ledisdb/ledis" - - "github.com/astaxie/beego/session" -) - -var ( - ledispder = &Provider{} - c *ledis.DB -) - -// SessionStore ledis session store -type SessionStore struct { - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Set value in ledis session -func (ls *SessionStore) Set(key, value interface{}) error { - ls.lock.Lock() - defer ls.lock.Unlock() - ls.values[key] = value - return nil -} - -// Get value in ledis session -func (ls *SessionStore) Get(key interface{}) interface{} { - ls.lock.RLock() - defer ls.lock.RUnlock() - if v, ok := ls.values[key]; ok { - return v - } - return nil -} - -// Delete value in ledis session -func (ls *SessionStore) Delete(key interface{}) error { - ls.lock.Lock() - defer ls.lock.Unlock() - delete(ls.values, key) - return nil -} - -// Flush clear all values in ledis session -func (ls *SessionStore) Flush() error { - ls.lock.Lock() - defer ls.lock.Unlock() - ls.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get ledis session id -func (ls *SessionStore) SessionID() string { - return ls.sid -} - -// SessionRelease save session values to ledis -func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(ls.values) - if err != nil { - return - } - c.Set([]byte(ls.sid), b) - c.Expire([]byte(ls.sid), ls.maxlifetime) -} - -// Provider ledis session provider -type Provider struct { - maxlifetime int64 - savePath string - db int -} - -// SessionInit init ledis session -// savepath like ledis server saveDataPath,pool size -// e.g. 127.0.0.1:6379,100,astaxie -func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { - var err error - lp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) == 1 { - lp.savePath = configs[0] - } else if len(configs) == 2 { - lp.savePath = configs[0] - lp.db, err = strconv.Atoi(configs[1]) - if err != nil { - return err - } - } - cfg := new(config.Config) - cfg.DataDir = lp.savePath - - var ledisInstance *ledis.Ledis - ledisInstance, err = ledis.Open(cfg) - if err != nil { - return err - } - c, err = ledisInstance.Select(lp.db) - return err -} - -// SessionRead read ledis session by sid -func (lp *Provider) SessionRead(sid string) (session.Store, error) { - var ( - kv map[interface{}]interface{} - err error - ) - - kvs, _ := c.Get([]byte(sid)) - - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - if kv, err = session.DecodeGob(kvs); err != nil { - return nil, err - } - } - - ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} - return ls, nil -} - -// SessionExist check ledis session exist by sid -func (lp *Provider) SessionExist(sid string) bool { - count, _ := c.Exists([]byte(sid)) - return count != 0 -} - -// SessionRegenerate generate new sid for ledis session -func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - count, _ := c.Exists([]byte(sid)) - if count == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error - // the existed value will be 0 - c.Set([]byte(sid), []byte("")) - c.Expire([]byte(sid), lp.maxlifetime) - } else { - data, _ := c.Get([]byte(oldsid)) - c.Set([]byte(sid), data) - c.Expire([]byte(sid), lp.maxlifetime) - } - return lp.SessionRead(sid) -} - -// SessionDestroy delete ledis session by id -func (lp *Provider) SessionDestroy(sid string) error { - c.Del([]byte(sid)) - return nil -} - -// SessionGC Impelment method, no used. -func (lp *Provider) SessionGC() { -} - -// SessionAll return all active session -func (lp *Provider) SessionAll() int { - return 0 -} -func init() { - session.Register("ledis", ledispder) -} diff --git a/session/memcache/sess_memcache.go b/session/memcache/sess_memcache.go deleted file mode 100644 index 85a2d815..00000000 --- a/session/memcache/sess_memcache.go +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package memcache for session provider -// -// depend on github.com/bradfitz/gomemcache/memcache -// -// go install github.com/bradfitz/gomemcache/memcache -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/memcache" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("memcache", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:11211"}``) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package memcache - -import ( - "net/http" - "strings" - "sync" - - "github.com/astaxie/beego/session" - - "github.com/bradfitz/gomemcache/memcache" -) - -var mempder = &MemProvider{} -var client *memcache.Client - -// SessionStore memcache session store -type SessionStore struct { - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Set value in memcache session -func (rs *SessionStore) Set(key, value interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values[key] = value - return nil -} - -// Get value in memcache session -func (rs *SessionStore) Get(key interface{}) interface{} { - rs.lock.RLock() - defer rs.lock.RUnlock() - if v, ok := rs.values[key]; ok { - return v - } - return nil -} - -// Delete value in memcache session -func (rs *SessionStore) Delete(key interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - delete(rs.values, key) - return nil -} - -// Flush clear all values in memcache session -func (rs *SessionStore) Flush() error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get memcache session id -func (rs *SessionStore) SessionID() string { - return rs.sid -} - -// SessionRelease save session values to memcache -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(rs.values) - if err != nil { - return - } - item := memcache.Item{Key: rs.sid, Value: b, Expiration: int32(rs.maxlifetime)} - client.Set(&item) -} - -// MemProvider memcache session provider -type MemProvider struct { - maxlifetime int64 - conninfo []string - poolsize int - password string -} - -// SessionInit init memcache session -// savepath like -// e.g. 127.0.0.1:9090 -func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { - rp.maxlifetime = maxlifetime - rp.conninfo = strings.Split(savePath, ";") - client = memcache.New(rp.conninfo...) - return nil -} - -// SessionRead read memcache session by sid -func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { - if client == nil { - if err := rp.connectInit(); err != nil { - return nil, err - } - } - item, err := client.Get(sid) - if err != nil { - if err == memcache.ErrCacheMiss { - rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime} - return rs, nil - } - return nil, err - } - var kv map[interface{}]interface{} - if len(item.Value) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(item.Value) - if err != nil { - return nil, err - } - } - rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime} - return rs, nil -} - -// SessionExist check memcache session exist by sid -func (rp *MemProvider) SessionExist(sid string) bool { - if client == nil { - if err := rp.connectInit(); err != nil { - return false - } - } - if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { - return false - } - return true -} - -// SessionRegenerate generate new sid for memcache session -func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - if client == nil { - if err := rp.connectInit(); err != nil { - return nil, err - } - } - var contain []byte - if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error - // the existed value will be 0 - item.Key = sid - item.Value = []byte("") - item.Expiration = int32(rp.maxlifetime) - client.Set(item) - } else { - client.Delete(oldsid) - item.Key = sid - item.Expiration = int32(rp.maxlifetime) - client.Set(item) - contain = item.Value - } - - var kv map[interface{}]interface{} - if len(contain) == 0 { - kv = make(map[interface{}]interface{}) - } else { - var err error - kv, err = session.DecodeGob(contain) - if err != nil { - return nil, err - } - } - - rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime} - return rs, nil -} - -// SessionDestroy delete memcache session by id -func (rp *MemProvider) SessionDestroy(sid string) error { - if client == nil { - if err := rp.connectInit(); err != nil { - return err - } - } - - return client.Delete(sid) -} - -func (rp *MemProvider) connectInit() error { - client = memcache.New(rp.conninfo...) - return nil -} - -// SessionGC Impelment method, no used. -func (rp *MemProvider) SessionGC() { -} - -// SessionAll return all activeSession -func (rp *MemProvider) SessionAll() int { - return 0 -} - -func init() { - session.Register("memcache", mempder) -} diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go deleted file mode 100644 index 301353ab..00000000 --- a/session/mysql/sess_mysql.go +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package mysql for session provider -// -// depends on github.com/go-sql-driver/mysql: -// -// go install github.com/go-sql-driver/mysql -// -// mysql session support need create table as sql: -// CREATE TABLE `session` ( -// `session_key` char(64) NOT NULL, -// `session_data` blob, -// `session_expiry` int(11) unsigned NOT NULL, -// PRIMARY KEY (`session_key`) -// ) ENGINE=MyISAM DEFAULT CHARSET=utf8; -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/mysql" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("mysql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]"}``) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package mysql - -import ( - "database/sql" - "net/http" - "sync" - "time" - - "github.com/astaxie/beego/session" - // import mysql driver - _ "github.com/go-sql-driver/mysql" -) - -var ( - // TableName store the session in MySQL - TableName = "session" - mysqlpder = &Provider{} -) - -// SessionStore mysql session store -type SessionStore struct { - c *sql.DB - sid string - lock sync.RWMutex - values map[interface{}]interface{} -} - -// Set value in mysql session. -// it is temp value in map. -func (st *SessionStore) Set(key, value interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - st.values[key] = value - return nil -} - -// Get value from mysql session -func (st *SessionStore) Get(key interface{}) interface{} { - st.lock.RLock() - defer st.lock.RUnlock() - if v, ok := st.values[key]; ok { - return v - } - return nil -} - -// Delete value in mysql session -func (st *SessionStore) Delete(key interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - delete(st.values, key) - return nil -} - -// Flush clear all values in mysql session -func (st *SessionStore) Flush() error { - st.lock.Lock() - defer st.lock.Unlock() - st.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get session id of this mysql session store -func (st *SessionStore) SessionID() string { - return st.sid -} - -// SessionRelease save mysql session values to database. -// must call this method to save values to database. -func (st *SessionStore) SessionRelease(w http.ResponseWriter) { - defer st.c.Close() - b, err := session.EncodeGob(st.values) - if err != nil { - return - } - st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?", - b, time.Now().Unix(), st.sid) -} - -// Provider mysql session provider -type Provider struct { - maxlifetime int64 - savePath string -} - -// connect to mysql -func (mp *Provider) connectInit() *sql.DB { - db, e := sql.Open("mysql", mp.savePath) - if e != nil { - return nil - } - return db -} - -// SessionInit init mysql session. -// savepath is the connection string of mysql. -func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { - mp.maxlifetime = maxlifetime - mp.savePath = savePath - return nil -} - -// SessionRead get mysql session by sid -func (mp *Provider) SessionRead(sid string) (session.Store, error) { - c := mp.connectInit() - row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) - var sessiondata []byte - err := row.Scan(&sessiondata) - if err == sql.ErrNoRows { - c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", - sid, "", time.Now().Unix()) - } - var kv map[interface{}]interface{} - if len(sessiondata) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(sessiondata) - if err != nil { - return nil, err - } - } - rs := &SessionStore{c: c, sid: sid, values: kv} - return rs, nil -} - -// SessionExist check mysql session exist -func (mp *Provider) SessionExist(sid string) bool { - c := mp.connectInit() - defer c.Close() - row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) - var sessiondata []byte - err := row.Scan(&sessiondata) - return err != sql.ErrNoRows -} - -// SessionRegenerate generate new sid for mysql session -func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - c := mp.connectInit() - row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid) - var sessiondata []byte - err := row.Scan(&sessiondata) - if err == sql.ErrNoRows { - c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix()) - } - c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid) - var kv map[interface{}]interface{} - if len(sessiondata) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(sessiondata) - if err != nil { - return nil, err - } - } - rs := &SessionStore{c: c, sid: sid, values: kv} - return rs, nil -} - -// SessionDestroy delete mysql session by sid -func (mp *Provider) SessionDestroy(sid string) error { - c := mp.connectInit() - c.Exec("DELETE FROM "+TableName+" where session_key=?", sid) - c.Close() - return nil -} - -// SessionGC delete expired values in mysql session -func (mp *Provider) SessionGC() { - c := mp.connectInit() - c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) - c.Close() -} - -// SessionAll count values in mysql session -func (mp *Provider) SessionAll() int { - c := mp.connectInit() - defer c.Close() - var total int - err := c.QueryRow("SELECT count(*) as num from " + TableName).Scan(&total) - if err != nil { - return 0 - } - return total -} - -func init() { - session.Register("mysql", mysqlpder) -} diff --git a/session/postgres/sess_postgresql.go b/session/postgres/sess_postgresql.go deleted file mode 100644 index 0b8b9645..00000000 --- a/session/postgres/sess_postgresql.go +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package postgres for session provider -// -// depends on github.com/lib/pq: -// -// go install github.com/lib/pq -// -// -// needs this table in your database: -// -// CREATE TABLE session ( -// session_key char(64) NOT NULL, -// session_data bytea, -// session_expiry timestamp NOT NULL, -// CONSTRAINT session_key PRIMARY KEY(session_key) -// ); -// -// will be activated with these settings in app.conf: -// -// SessionOn = true -// SessionProvider = postgresql -// SessionSavePath = "user=a password=b dbname=c sslmode=disable" -// SessionName = session -// -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/postgresql" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("postgresql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"user=pqgotest dbname=pqgotest sslmode=verify-full"}``) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package postgres - -import ( - "database/sql" - "net/http" - "sync" - "time" - - "github.com/astaxie/beego/session" - // import postgresql Driver - _ "github.com/lib/pq" -) - -var postgresqlpder = &Provider{} - -// SessionStore postgresql session store -type SessionStore struct { - c *sql.DB - sid string - lock sync.RWMutex - values map[interface{}]interface{} -} - -// Set value in postgresql session. -// it is temp value in map. -func (st *SessionStore) Set(key, value interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - st.values[key] = value - return nil -} - -// Get value from postgresql session -func (st *SessionStore) Get(key interface{}) interface{} { - st.lock.RLock() - defer st.lock.RUnlock() - if v, ok := st.values[key]; ok { - return v - } - return nil -} - -// Delete value in postgresql session -func (st *SessionStore) Delete(key interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - delete(st.values, key) - return nil -} - -// Flush clear all values in postgresql session -func (st *SessionStore) Flush() error { - st.lock.Lock() - defer st.lock.Unlock() - st.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get session id of this postgresql session store -func (st *SessionStore) SessionID() string { - return st.sid -} - -// SessionRelease save postgresql session values to database. -// must call this method to save values to database. -func (st *SessionStore) SessionRelease(w http.ResponseWriter) { - defer st.c.Close() - b, err := session.EncodeGob(st.values) - if err != nil { - return - } - st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3", - b, time.Now().Format(time.RFC3339), st.sid) - -} - -// Provider postgresql session provider -type Provider struct { - maxlifetime int64 - savePath string -} - -// connect to postgresql -func (mp *Provider) connectInit() *sql.DB { - db, e := sql.Open("postgres", mp.savePath) - if e != nil { - return nil - } - return db -} - -// SessionInit init postgresql session. -// savepath is the connection string of postgresql. -func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { - mp.maxlifetime = maxlifetime - mp.savePath = savePath - return nil -} - -// SessionRead get postgresql session by sid -func (mp *Provider) SessionRead(sid string) (session.Store, error) { - c := mp.connectInit() - row := c.QueryRow("select session_data from session where session_key=$1", sid) - var sessiondata []byte - err := row.Scan(&sessiondata) - if err == sql.ErrNoRows { - _, err = c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)", - sid, "", time.Now().Format(time.RFC3339)) - - if err != nil { - return nil, err - } - } else if err != nil { - return nil, err - } - - var kv map[interface{}]interface{} - if len(sessiondata) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(sessiondata) - if err != nil { - return nil, err - } - } - rs := &SessionStore{c: c, sid: sid, values: kv} - return rs, nil -} - -// SessionExist check postgresql session exist -func (mp *Provider) SessionExist(sid string) bool { - c := mp.connectInit() - defer c.Close() - row := c.QueryRow("select session_data from session where session_key=$1", sid) - var sessiondata []byte - err := row.Scan(&sessiondata) - return err != sql.ErrNoRows -} - -// SessionRegenerate generate new sid for postgresql session -func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - c := mp.connectInit() - row := c.QueryRow("select session_data from session where session_key=$1", oldsid) - var sessiondata []byte - err := row.Scan(&sessiondata) - if err == sql.ErrNoRows { - c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)", - oldsid, "", time.Now().Format(time.RFC3339)) - } - c.Exec("update session set session_key=$1 where session_key=$2", sid, oldsid) - var kv map[interface{}]interface{} - if len(sessiondata) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(sessiondata) - if err != nil { - return nil, err - } - } - rs := &SessionStore{c: c, sid: sid, values: kv} - return rs, nil -} - -// SessionDestroy delete postgresql session by sid -func (mp *Provider) SessionDestroy(sid string) error { - c := mp.connectInit() - c.Exec("DELETE FROM session where session_key=$1", sid) - c.Close() - return nil -} - -// SessionGC delete expired values in postgresql session -func (mp *Provider) SessionGC() { - c := mp.connectInit() - c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime) - c.Close() -} - -// SessionAll count values in postgresql session -func (mp *Provider) SessionAll() int { - c := mp.connectInit() - defer c.Close() - var total int - err := c.QueryRow("SELECT count(*) as num from session").Scan(&total) - if err != nil { - return 0 - } - return total -} - -func init() { - session.Register("postgresql", postgresqlpder) -} diff --git a/session/redis/sess_redis.go b/session/redis/sess_redis.go deleted file mode 100644 index 5c382d61..00000000 --- a/session/redis/sess_redis.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package redis for session provider -// -// depend on github.com/gomodule/redigo/redis -// -// go install github.com/gomodule/redigo/redis -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/redis" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package redis - -import ( - "net/http" - "strconv" - "strings" - "sync" - "time" - - "github.com/astaxie/beego/session" - - "github.com/gomodule/redigo/redis" -) - -var redispder = &Provider{} - -// MaxPoolSize redis max pool size -var MaxPoolSize = 100 - -// SessionStore redis session store -type SessionStore struct { - p *redis.Pool - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Set value in redis session -func (rs *SessionStore) Set(key, value interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values[key] = value - return nil -} - -// Get value in redis session -func (rs *SessionStore) Get(key interface{}) interface{} { - rs.lock.RLock() - defer rs.lock.RUnlock() - if v, ok := rs.values[key]; ok { - return v - } - return nil -} - -// Delete value in redis session -func (rs *SessionStore) Delete(key interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - delete(rs.values, key) - return nil -} - -// Flush clear all values in redis session -func (rs *SessionStore) Flush() error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get redis session id -func (rs *SessionStore) SessionID() string { - return rs.sid -} - -// SessionRelease save session values to redis -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(rs.values) - if err != nil { - return - } - c := rs.p.Get() - defer c.Close() - c.Do("SETEX", rs.sid, rs.maxlifetime, string(b)) -} - -// Provider redis session provider -type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - poollist *redis.Pool -} - -// SessionInit init redis session -// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second -// e.g. 127.0.0.1:6379,100,astaxie,0,30 -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { - rp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) > 0 { - rp.savePath = configs[0] - } - if len(configs) > 1 { - poolsize, err := strconv.Atoi(configs[1]) - if err != nil || poolsize < 0 { - rp.poolsize = MaxPoolSize - } else { - rp.poolsize = poolsize - } - } else { - rp.poolsize = MaxPoolSize - } - if len(configs) > 2 { - rp.password = configs[2] - } - if len(configs) > 3 { - dbnum, err := strconv.Atoi(configs[3]) - if err != nil || dbnum < 0 { - rp.dbNum = 0 - } else { - rp.dbNum = dbnum - } - } else { - rp.dbNum = 0 - } - var idleTimeout time.Duration = 0 - if len(configs) > 4 { - timeout, err := strconv.Atoi(configs[4]) - if err == nil && timeout > 0 { - idleTimeout = time.Duration(timeout) * time.Second - } - } - rp.poollist = &redis.Pool{ - Dial: func() (redis.Conn, error) { - c, err := redis.Dial("tcp", rp.savePath) - if err != nil { - return nil, err - } - if rp.password != "" { - if _, err = c.Do("AUTH", rp.password); err != nil { - c.Close() - return nil, err - } - } - // some redis proxy such as twemproxy is not support select command - if rp.dbNum > 0 { - _, err = c.Do("SELECT", rp.dbNum) - if err != nil { - c.Close() - return nil, err - } - } - return c, err - }, - MaxIdle: rp.poolsize, - } - - rp.poollist.IdleTimeout = idleTimeout - - return rp.poollist.Get().Err() -} - -// SessionRead read redis session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { - c := rp.poollist.Get() - defer c.Close() - - var kv map[interface{}]interface{} - - kvs, err := redis.String(c.Do("GET", sid)) - if err != nil && err != redis.ErrNil { - return nil, err - } - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - if kv, err = session.DecodeGob([]byte(kvs)); err != nil { - return nil, err - } - } - - rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} - return rs, nil -} - -// SessionExist check redis session exist by sid -func (rp *Provider) SessionExist(sid string) bool { - c := rp.poollist.Get() - defer c.Close() - - if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 { - return false - } - return true -} - -// SessionRegenerate generate new sid for redis session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - c := rp.poollist.Get() - defer c.Close() - - if existed, _ := redis.Int(c.Do("EXISTS", oldsid)); existed == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error - // the existed value will be 0 - c.Do("SET", sid, "", "EX", rp.maxlifetime) - } else { - c.Do("RENAME", oldsid, sid) - c.Do("EXPIRE", sid, rp.maxlifetime) - } - return rp.SessionRead(sid) -} - -// SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { - c := rp.poollist.Get() - defer c.Close() - - c.Do("DEL", sid) - return nil -} - -// SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { -} - -// SessionAll return all activeSession -func (rp *Provider) SessionAll() int { - return 0 -} - -func init() { - session.Register("redis", redispder) -} diff --git a/session/redis_cluster/redis_cluster.go b/session/redis_cluster/redis_cluster.go deleted file mode 100644 index 262fa2e3..00000000 --- a/session/redis_cluster/redis_cluster.go +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package redis for session provider -// -// depend on github.com/go-redis/redis -// -// go install github.com/go-redis/redis -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/redis_cluster" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package redis_cluster - -import ( - "github.com/astaxie/beego/session" - rediss "github.com/go-redis/redis" - "net/http" - "strconv" - "strings" - "sync" - "time" -) - -var redispder = &Provider{} - -// MaxPoolSize redis_cluster max pool size -var MaxPoolSize = 1000 - -// SessionStore redis_cluster session store -type SessionStore struct { - p *rediss.ClusterClient - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Set value in redis_cluster session -func (rs *SessionStore) Set(key, value interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values[key] = value - return nil -} - -// Get value in redis_cluster session -func (rs *SessionStore) Get(key interface{}) interface{} { - rs.lock.RLock() - defer rs.lock.RUnlock() - if v, ok := rs.values[key]; ok { - return v - } - return nil -} - -// Delete value in redis_cluster session -func (rs *SessionStore) Delete(key interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - delete(rs.values, key) - return nil -} - -// Flush clear all values in redis_cluster session -func (rs *SessionStore) Flush() error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get redis_cluster session id -func (rs *SessionStore) SessionID() string { - return rs.sid -} - -// SessionRelease save session values to redis_cluster -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(rs.values) - if err != nil { - return - } - c := rs.p - c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) -} - -// Provider redis_cluster session provider -type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - poollist *rediss.ClusterClient -} - -// SessionInit init redis_cluster session -// savepath like redis server addr,pool size,password,dbnum -// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { - rp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) > 0 { - rp.savePath = configs[0] - } - if len(configs) > 1 { - poolsize, err := strconv.Atoi(configs[1]) - if err != nil || poolsize < 0 { - rp.poolsize = MaxPoolSize - } else { - rp.poolsize = poolsize - } - } else { - rp.poolsize = MaxPoolSize - } - if len(configs) > 2 { - rp.password = configs[2] - } - if len(configs) > 3 { - dbnum, err := strconv.Atoi(configs[3]) - if err != nil || dbnum < 0 { - rp.dbNum = 0 - } else { - rp.dbNum = dbnum - } - } else { - rp.dbNum = 0 - } - - rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ - Addrs: strings.Split(rp.savePath, ";"), - Password: rp.password, - PoolSize: rp.poolsize, - }) - return rp.poollist.Ping().Err() -} - -// SessionRead read redis_cluster session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { - var kv map[interface{}]interface{} - kvs, err := rp.poollist.Get(sid).Result() - if err != nil && err != rediss.Nil { - return nil, err - } - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - if kv, err = session.DecodeGob([]byte(kvs)); err != nil { - return nil, err - } - } - - rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} - return rs, nil -} - -// SessionExist check redis_cluster session exist by sid -func (rp *Provider) SessionExist(sid string) bool { - c := rp.poollist - if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { - return false - } - return true -} - -// SessionRegenerate generate new sid for redis_cluster session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - c := rp.poollist - - if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error - // the existed value will be 0 - c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second) - } else { - c.Rename(oldsid, sid) - c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) - } - return rp.SessionRead(sid) -} - -// SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { - c := rp.poollist - c.Del(sid) - return nil -} - -// SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { -} - -// SessionAll return all activeSession -func (rp *Provider) SessionAll() int { - return 0 -} - -func init() { - session.Register("redis_cluster", redispder) -} diff --git a/session/redis_sentinel/sess_redis_sentinel.go b/session/redis_sentinel/sess_redis_sentinel.go deleted file mode 100644 index 6ecb2977..00000000 --- a/session/redis_sentinel/sess_redis_sentinel.go +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package redis for session provider -// -// depend on github.com/go-redis/redis -// -// go install github.com/go-redis/redis -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/redis_sentinel" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``) -// go globalSessions.GC() -// } -// -// more detail about params: please check the notes on the function SessionInit in this package -package redis_sentinel - -import ( - "github.com/astaxie/beego/session" - "github.com/go-redis/redis" - "net/http" - "strconv" - "strings" - "sync" - "time" -) - -var redispder = &Provider{} - -// DefaultPoolSize redis_sentinel default pool size -var DefaultPoolSize = 100 - -// SessionStore redis_sentinel session store -type SessionStore struct { - p *redis.Client - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Set value in redis_sentinel session -func (rs *SessionStore) Set(key, value interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values[key] = value - return nil -} - -// Get value in redis_sentinel session -func (rs *SessionStore) Get(key interface{}) interface{} { - rs.lock.RLock() - defer rs.lock.RUnlock() - if v, ok := rs.values[key]; ok { - return v - } - return nil -} - -// Delete value in redis_sentinel session -func (rs *SessionStore) Delete(key interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - delete(rs.values, key) - return nil -} - -// Flush clear all values in redis_sentinel session -func (rs *SessionStore) Flush() error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get redis_sentinel session id -func (rs *SessionStore) SessionID() string { - return rs.sid -} - -// SessionRelease save session values to redis_sentinel -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(rs.values) - if err != nil { - return - } - c := rs.p - c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) -} - -// Provider redis_sentinel session provider -type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - poollist *redis.Client - masterName string -} - -// SessionInit init redis_sentinel session -// savepath like redis sentinel addr,pool size,password,dbnum,masterName -// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { - rp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) > 0 { - rp.savePath = configs[0] - } - if len(configs) > 1 { - poolsize, err := strconv.Atoi(configs[1]) - if err != nil || poolsize < 0 { - rp.poolsize = DefaultPoolSize - } else { - rp.poolsize = poolsize - } - } else { - rp.poolsize = DefaultPoolSize - } - if len(configs) > 2 { - rp.password = configs[2] - } - if len(configs) > 3 { - dbnum, err := strconv.Atoi(configs[3]) - if err != nil || dbnum < 0 { - rp.dbNum = 0 - } else { - rp.dbNum = dbnum - } - } else { - rp.dbNum = 0 - } - if len(configs) > 4 { - if configs[4] != "" { - rp.masterName = configs[4] - } else { - rp.masterName = "mymaster" - } - } else { - rp.masterName = "mymaster" - } - - rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ - SentinelAddrs: strings.Split(rp.savePath, ";"), - Password: rp.password, - PoolSize: rp.poolsize, - DB: rp.dbNum, - MasterName: rp.masterName, - }) - - return rp.poollist.Ping().Err() -} - -// SessionRead read redis_sentinel session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { - var kv map[interface{}]interface{} - kvs, err := rp.poollist.Get(sid).Result() - if err != nil && err != redis.Nil { - return nil, err - } - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - if kv, err = session.DecodeGob([]byte(kvs)); err != nil { - return nil, err - } - } - - rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} - return rs, nil -} - -// SessionExist check redis_sentinel session exist by sid -func (rp *Provider) SessionExist(sid string) bool { - c := rp.poollist - if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { - return false - } - return true -} - -// SessionRegenerate generate new sid for redis_sentinel session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - c := rp.poollist - - if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error - // the existed value will be 0 - c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second) - } else { - c.Rename(oldsid, sid) - c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) - } - return rp.SessionRead(sid) -} - -// SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { - c := rp.poollist - c.Del(sid) - return nil -} - -// SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { -} - -// SessionAll return all activeSession -func (rp *Provider) SessionAll() int { - return 0 -} - -func init() { - session.Register("redis_sentinel", redispder) -} diff --git a/session/redis_sentinel/sess_redis_sentinel_test.go b/session/redis_sentinel/sess_redis_sentinel_test.go deleted file mode 100644 index fd4155c6..00000000 --- a/session/redis_sentinel/sess_redis_sentinel_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package redis_sentinel - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/astaxie/beego/session" -) - -func TestRedisSentinel(t *testing.T) { - sessionConfig := &session.ManagerConfig{ - CookieName: "gosessionid", - EnableSetCookie: true, - Gclifetime: 3600, - Maxlifetime: 3600, - Secure: false, - CookieLifeTime: 3600, - ProviderConfig: "127.0.0.1:6379,100,,0,master", - } - globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) - if e != nil { - t.Log(e) - return - } - //todo test if e==nil - go globalSessions.GC() - - r, _ := http.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - - sess, err := globalSessions.SessionStart(w, r) - if err != nil { - t.Fatal("session start failed:", err) - } - defer sess.SessionRelease(w) - - // SET AND GET - err = sess.Set("username", "astaxie") - if err != nil { - t.Fatal("set username failed:", err) - } - username := sess.Get("username") - if username != "astaxie" { - t.Fatal("get username failed") - } - - // DELETE - err = sess.Delete("username") - if err != nil { - t.Fatal("delete username failed:", err) - } - username = sess.Get("username") - if username != nil { - t.Fatal("delete username failed") - } - - // FLUSH - err = sess.Set("username", "astaxie") - if err != nil { - t.Fatal("set failed:", err) - } - err = sess.Set("password", "1qaz2wsx") - if err != nil { - t.Fatal("set failed:", err) - } - username = sess.Get("username") - if username != "astaxie" { - t.Fatal("get username failed") - } - password := sess.Get("password") - if password != "1qaz2wsx" { - t.Fatal("get password failed") - } - err = sess.Flush() - if err != nil { - t.Fatal("flush failed:", err) - } - username = sess.Get("username") - if username != nil { - t.Fatal("flush failed") - } - password = sess.Get("password") - if password != nil { - t.Fatal("flush failed") - } - - sess.SessionRelease(w) - -} diff --git a/session/sess_cookie.go b/session/sess_cookie.go deleted file mode 100644 index 6ad5debc..00000000 --- a/session/sess_cookie.go +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "crypto/aes" - "crypto/cipher" - "encoding/json" - "net/http" - "net/url" - "sync" -) - -var cookiepder = &CookieProvider{} - -// CookieSessionStore Cookie SessionStore -type CookieSessionStore struct { - sid string - values map[interface{}]interface{} // session data - lock sync.RWMutex -} - -// Set value to cookie session. -// the value are encoded as gob with hash block string. -func (st *CookieSessionStore) Set(key, value interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - st.values[key] = value - return nil -} - -// Get value from cookie session -func (st *CookieSessionStore) Get(key interface{}) interface{} { - st.lock.RLock() - defer st.lock.RUnlock() - if v, ok := st.values[key]; ok { - return v - } - return nil -} - -// Delete value in cookie session -func (st *CookieSessionStore) Delete(key interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - delete(st.values, key) - return nil -} - -// Flush Clean all values in cookie session -func (st *CookieSessionStore) Flush() error { - st.lock.Lock() - defer st.lock.Unlock() - st.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID Return id of this cookie session -func (st *CookieSessionStore) SessionID() string { - return st.sid -} - -// SessionRelease Write cookie session to http response cookie -func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { - st.lock.Lock() - encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values) - st.lock.Unlock() - if err == nil { - cookie := &http.Cookie{Name: cookiepder.config.CookieName, - Value: url.QueryEscape(encodedCookie), - Path: "/", - HttpOnly: true, - Secure: cookiepder.config.Secure, - MaxAge: cookiepder.config.Maxage} - http.SetCookie(w, cookie) - } -} - -type cookieConfig struct { - SecurityKey string `json:"securityKey"` - BlockKey string `json:"blockKey"` - SecurityName string `json:"securityName"` - CookieName string `json:"cookieName"` - Secure bool `json:"secure"` - Maxage int `json:"maxage"` -} - -// CookieProvider Cookie session provider -type CookieProvider struct { - maxlifetime int64 - config *cookieConfig - block cipher.Block -} - -// SessionInit Init cookie session provider with max lifetime and config json. -// maxlifetime is ignored. -// json config: -// securityKey - hash string -// blockKey - gob encode hash string. it's saved as aes crypto. -// securityName - recognized name in encoded cookie string -// cookieName - cookie name -// maxage - cookie max life time. -func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { - pder.config = &cookieConfig{} - err := json.Unmarshal([]byte(config), pder.config) - if err != nil { - return err - } - if pder.config.BlockKey == "" { - pder.config.BlockKey = string(generateRandomKey(16)) - } - if pder.config.SecurityName == "" { - pder.config.SecurityName = string(generateRandomKey(20)) - } - pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey)) - if err != nil { - return err - } - pder.maxlifetime = maxlifetime - return nil -} - -// SessionRead Get SessionStore in cooke. -// decode cooke string to map and put into SessionStore with sid. -func (pder *CookieProvider) SessionRead(sid string) (Store, error) { - maps, _ := decodeCookie(pder.block, - pder.config.SecurityKey, - pder.config.SecurityName, - sid, pder.maxlifetime) - if maps == nil { - maps = make(map[interface{}]interface{}) - } - rs := &CookieSessionStore{sid: sid, values: maps} - return rs, nil -} - -// SessionExist Cookie session is always existed -func (pder *CookieProvider) SessionExist(sid string) bool { - return true -} - -// SessionRegenerate Implement method, no used. -func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { - return nil, nil -} - -// SessionDestroy Implement method, no used. -func (pder *CookieProvider) SessionDestroy(sid string) error { - return nil -} - -// SessionGC Implement method, no used. -func (pder *CookieProvider) SessionGC() { -} - -// SessionAll Implement method, return 0. -func (pder *CookieProvider) SessionAll() int { - return 0 -} - -// SessionUpdate Implement method, no used. -func (pder *CookieProvider) SessionUpdate(sid string) error { - return nil -} - -func init() { - Register("cookie", cookiepder) -} diff --git a/session/sess_cookie_test.go b/session/sess_cookie_test.go deleted file mode 100644 index b6726005..00000000 --- a/session/sess_cookie_test.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -func TestCookie(t *testing.T) { - config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` - conf := new(ManagerConfig) - if err := json.Unmarshal([]byte(config), conf); err != nil { - t.Fatal("json decode error", err) - } - globalSessions, err := NewManager("cookie", conf) - if err != nil { - t.Fatal("init cookie session err", err) - } - r, _ := http.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - sess, err := globalSessions.SessionStart(w, r) - if err != nil { - t.Fatal("set error,", err) - } - err = sess.Set("username", "astaxie") - if err != nil { - t.Fatal("set error,", err) - } - if username := sess.Get("username"); username != "astaxie" { - t.Fatal("get username error") - } - sess.SessionRelease(w) - if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { - t.Fatal("setcookie error") - } else { - parts := strings.Split(strings.TrimSpace(cookiestr), ";") - for k, v := range parts { - nameval := strings.Split(v, "=") - if k == 0 && nameval[0] != "gosessionid" { - t.Fatal("error") - } - } - } -} - -func TestDestorySessionCookie(t *testing.T) { - config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` - conf := new(ManagerConfig) - if err := json.Unmarshal([]byte(config), conf); err != nil { - t.Fatal("json decode error", err) - } - globalSessions, err := NewManager("cookie", conf) - if err != nil { - t.Fatal("init cookie session err", err) - } - - r, _ := http.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - session, err := globalSessions.SessionStart(w, r) - if err != nil { - t.Fatal("session start err,", err) - } - - // request again ,will get same sesssion id . - r1, _ := http.NewRequest("GET", "/", nil) - r1.Header.Set("Cookie", w.Header().Get("Set-Cookie")) - w = httptest.NewRecorder() - newSession, err := globalSessions.SessionStart(w, r1) - if err != nil { - t.Fatal("session start err,", err) - } - if newSession.SessionID() != session.SessionID() { - t.Fatal("get cookie session id is not the same again.") - } - - // After destroy session , will get a new session id . - globalSessions.SessionDestroy(w, r1) - r2, _ := http.NewRequest("GET", "/", nil) - r2.Header.Set("Cookie", w.Header().Get("Set-Cookie")) - - w = httptest.NewRecorder() - newSession, err = globalSessions.SessionStart(w, r2) - if err != nil { - t.Fatal("session start error") - } - if newSession.SessionID() == session.SessionID() { - t.Fatal("after destroy session and reqeust again ,get cookie session id is same.") - } -} diff --git a/session/sess_file.go b/session/sess_file.go deleted file mode 100644 index 47ad54a7..00000000 --- a/session/sess_file.go +++ /dev/null @@ -1,315 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "errors" - "fmt" - "io/ioutil" - "net/http" - "os" - "path" - "path/filepath" - "strings" - "sync" - "time" -) - -var ( - filepder = &FileProvider{} - gcmaxlifetime int64 -) - -// FileSessionStore File session store -type FileSessionStore struct { - sid string - lock sync.RWMutex - values map[interface{}]interface{} -} - -// Set value to file session -func (fs *FileSessionStore) Set(key, value interface{}) error { - fs.lock.Lock() - defer fs.lock.Unlock() - fs.values[key] = value - return nil -} - -// Get value from file session -func (fs *FileSessionStore) Get(key interface{}) interface{} { - fs.lock.RLock() - defer fs.lock.RUnlock() - if v, ok := fs.values[key]; ok { - return v - } - return nil -} - -// Delete value in file session by given key -func (fs *FileSessionStore) Delete(key interface{}) error { - fs.lock.Lock() - defer fs.lock.Unlock() - delete(fs.values, key) - return nil -} - -// Flush Clean all values in file session -func (fs *FileSessionStore) Flush() error { - fs.lock.Lock() - defer fs.lock.Unlock() - fs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID Get file session store id -func (fs *FileSessionStore) SessionID() string { - return fs.sid -} - -// SessionRelease Write file session to local file with Gob string -func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { - filepder.lock.Lock() - defer filepder.lock.Unlock() - b, err := EncodeGob(fs.values) - if err != nil { - SLogger.Println(err) - return - } - _, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) - var f *os.File - if err == nil { - f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777) - if err != nil { - SLogger.Println(err) - return - } - } else if os.IsNotExist(err) { - f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) - if err != nil { - SLogger.Println(err) - return - } - } else { - return - } - f.Truncate(0) - f.Seek(0, 0) - f.Write(b) - f.Close() -} - -// FileProvider File session provider -type FileProvider struct { - lock sync.RWMutex - maxlifetime int64 - savePath string -} - -// SessionInit Init file session provider. -// savePath sets the session files path. -func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { - fp.maxlifetime = maxlifetime - fp.savePath = savePath - return nil -} - -// SessionRead Read file session by sid. -// if file is not exist, create it. -// the file path is generated from sid string. -func (fp *FileProvider) SessionRead(sid string) (Store, error) { - invalidChars := "./" - if strings.ContainsAny(sid, invalidChars) { - return nil, errors.New("the sid shouldn't have following characters: " + invalidChars) - } - if len(sid) < 2 { - return nil, errors.New("length of the sid is less than 2") - } - filepder.lock.Lock() - defer filepder.lock.Unlock() - - err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0755) - if err != nil { - SLogger.Println(err.Error()) - } - _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - var f *os.File - if err == nil { - f, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_RDWR, 0777) - } else if os.IsNotExist(err) { - f, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - } else { - return nil, err - } - - defer f.Close() - - os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now()) - var kv map[interface{}]interface{} - b, err := ioutil.ReadAll(f) - if err != nil { - return nil, err - } - if len(b) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = DecodeGob(b) - if err != nil { - return nil, err - } - } - - ss := &FileSessionStore{sid: sid, values: kv} - return ss, nil -} - -// SessionExist Check file session exist. -// it checks the file named from sid exist or not. -func (fp *FileProvider) SessionExist(sid string) bool { - filepder.lock.Lock() - defer filepder.lock.Unlock() - - if len(sid) < 2 { - SLogger.Println("min length of session id is 2", sid) - return false - } - - _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - return err == nil -} - -// SessionDestroy Remove all files in this save path -func (fp *FileProvider) SessionDestroy(sid string) error { - filepder.lock.Lock() - defer filepder.lock.Unlock() - os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - return nil -} - -// SessionGC Recycle files in save path -func (fp *FileProvider) SessionGC() { - filepder.lock.Lock() - defer filepder.lock.Unlock() - - gcmaxlifetime = fp.maxlifetime - filepath.Walk(fp.savePath, gcpath) -} - -// SessionAll Get active file session number. -// it walks save path to count files. -func (fp *FileProvider) SessionAll() int { - a := &activeSession{} - err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error { - return a.visit(path, f, err) - }) - if err != nil { - SLogger.Printf("filepath.Walk() returned %v\n", err) - return 0 - } - return a.total -} - -// SessionRegenerate Generate new sid for file session. -// it delete old file and create new file named from new sid. -func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { - filepder.lock.Lock() - defer filepder.lock.Unlock() - - oldPath := path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])) - oldSidFile := path.Join(oldPath, oldsid) - newPath := path.Join(fp.savePath, string(sid[0]), string(sid[1])) - newSidFile := path.Join(newPath, sid) - - // new sid file is exist - _, err := os.Stat(newSidFile) - if err == nil { - return nil, fmt.Errorf("newsid %s exist", newSidFile) - } - - err = os.MkdirAll(newPath, 0755) - if err != nil { - SLogger.Println(err.Error()) - } - - // if old sid file exist - // 1.read and parse file content - // 2.write content to new sid file - // 3.remove old sid file, change new sid file atime and ctime - // 4.return FileSessionStore - _, err = os.Stat(oldSidFile) - if err == nil { - b, err := ioutil.ReadFile(oldSidFile) - if err != nil { - return nil, err - } - - var kv map[interface{}]interface{} - if len(b) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = DecodeGob(b) - if err != nil { - return nil, err - } - } - - ioutil.WriteFile(newSidFile, b, 0777) - os.Remove(oldSidFile) - os.Chtimes(newSidFile, time.Now(), time.Now()) - ss := &FileSessionStore{sid: sid, values: kv} - return ss, nil - } - - // if old sid file not exist, just create new sid file and return - newf, err := os.Create(newSidFile) - if err != nil { - return nil, err - } - newf.Close() - ss := &FileSessionStore{sid: sid, values: make(map[interface{}]interface{})} - return ss, nil -} - -// remove file in save path if expired -func gcpath(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - if info.IsDir() { - return nil - } - if (info.ModTime().Unix() + gcmaxlifetime) < time.Now().Unix() { - os.Remove(path) - } - return nil -} - -type activeSession struct { - total int -} - -func (as *activeSession) visit(paths string, f os.FileInfo, err error) error { - if err != nil { - return err - } - if f.IsDir() { - return nil - } - as.total = as.total + 1 - return nil -} - -func init() { - Register("file", filepder) -} diff --git a/session/sess_file_test.go b/session/sess_file_test.go deleted file mode 100644 index 021c43fc..00000000 --- a/session/sess_file_test.go +++ /dev/null @@ -1,386 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "fmt" - "os" - "sync" - "testing" - "time" -) - -const sid = "Session_id" -const sidNew = "Session_id_new" -const sessionPath = "./_session_runtime" - -var ( - mutex sync.Mutex -) - -func TestFileProvider_SessionInit(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - if fp.maxlifetime != 180 { - t.Error() - } - - if fp.savePath != sessionPath { - t.Error() - } -} - -func TestFileProvider_SessionExist(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - if fp.SessionExist(sid) { - t.Error() - } - - _, err := fp.SessionRead(sid) - if err != nil { - t.Error(err) - } - - if !fp.SessionExist(sid) { - t.Error() - } -} - -func TestFileProvider_SessionExist2(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - if fp.SessionExist(sid) { - t.Error() - } - - if fp.SessionExist("") { - t.Error() - } - - if fp.SessionExist("1") { - t.Error() - } -} - -func TestFileProvider_SessionRead(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - s, err := fp.SessionRead(sid) - if err != nil { - t.Error(err) - } - - _ = s.Set("sessionValue", 18975) - v := s.Get("sessionValue") - - if v.(int) != 18975 { - t.Error() - } -} - -func TestFileProvider_SessionRead1(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - _, err := fp.SessionRead("") - if err == nil { - t.Error(err) - } - - _, err = fp.SessionRead("1") - if err == nil { - t.Error(err) - } -} - -func TestFileProvider_SessionAll(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - sessionCount := 546 - - for i := 1; i <= sessionCount; i++ { - _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) - if err != nil { - t.Error(err) - } - } - - if fp.SessionAll() != sessionCount { - t.Error() - } -} - -func TestFileProvider_SessionRegenerate(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - _, err := fp.SessionRead(sid) - if err != nil { - t.Error(err) - } - - if !fp.SessionExist(sid) { - t.Error() - } - - _, err = fp.SessionRegenerate(sid, sidNew) - if err != nil { - t.Error(err) - } - - if fp.SessionExist(sid) { - t.Error() - } - - if !fp.SessionExist(sidNew) { - t.Error() - } -} - -func TestFileProvider_SessionDestroy(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - _, err := fp.SessionRead(sid) - if err != nil { - t.Error(err) - } - - if !fp.SessionExist(sid) { - t.Error() - } - - err = fp.SessionDestroy(sid) - if err != nil { - t.Error(err) - } - - if fp.SessionExist(sid) { - t.Error() - } -} - -func TestFileProvider_SessionGC(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(1, sessionPath) - - sessionCount := 412 - - for i := 1; i <= sessionCount; i++ { - _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) - if err != nil { - t.Error(err) - } - } - - time.Sleep(2 * time.Second) - - fp.SessionGC() - if fp.SessionAll() != 0 { - t.Error() - } -} - -func TestFileSessionStore_Set(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - sessionCount := 100 - s, _ := fp.SessionRead(sid) - for i := 1; i <= sessionCount; i++ { - err := s.Set(i, i) - if err != nil { - t.Error(err) - } - } -} - -func TestFileSessionStore_Get(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - sessionCount := 100 - s, _ := fp.SessionRead(sid) - for i := 1; i <= sessionCount; i++ { - _ = s.Set(i, i) - - v := s.Get(i) - if v.(int) != i { - t.Error() - } - } -} - -func TestFileSessionStore_Delete(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - s, _ := fp.SessionRead(sid) - s.Set("1", 1) - - if s.Get("1") == nil { - t.Error() - } - - s.Delete("1") - - if s.Get("1") != nil { - t.Error() - } -} - -func TestFileSessionStore_Flush(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - sessionCount := 100 - s, _ := fp.SessionRead(sid) - for i := 1; i <= sessionCount; i++ { - _ = s.Set(i, i) - } - - _ = s.Flush() - - for i := 1; i <= sessionCount; i++ { - if s.Get(i) != nil { - t.Error() - } - } -} - -func TestFileSessionStore_SessionID(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - sessionCount := 85 - - for i := 1; i <= sessionCount; i++ { - s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) - if err != nil { - t.Error(err) - } - if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) { - t.Error(err) - } - } -} - -func TestFileSessionStore_SessionRelease(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - filepder.savePath = sessionPath - sessionCount := 85 - - for i := 1; i <= sessionCount; i++ { - s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) - if err != nil { - t.Error(err) - } - - s.Set(i, i) - s.SessionRelease(nil) - } - - for i := 1; i <= sessionCount; i++ { - s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) - if err != nil { - t.Error(err) - } - - if s.Get(i).(int) != i { - t.Error() - } - } -} diff --git a/session/sess_mem.go b/session/sess_mem.go deleted file mode 100644 index 64d8b056..00000000 --- a/session/sess_mem.go +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "container/list" - "net/http" - "sync" - "time" -) - -var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} - -// MemSessionStore memory session store. -// it saved sessions in a map in memory. -type MemSessionStore struct { - sid string //session id - timeAccessed time.Time //last access time - value map[interface{}]interface{} //session store - lock sync.RWMutex -} - -// Set value to memory session -func (st *MemSessionStore) Set(key, value interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - st.value[key] = value - return nil -} - -// Get value from memory session by key -func (st *MemSessionStore) Get(key interface{}) interface{} { - st.lock.RLock() - defer st.lock.RUnlock() - if v, ok := st.value[key]; ok { - return v - } - return nil -} - -// Delete in memory session by key -func (st *MemSessionStore) Delete(key interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - delete(st.value, key) - return nil -} - -// Flush clear all values in memory session -func (st *MemSessionStore) Flush() error { - st.lock.Lock() - defer st.lock.Unlock() - st.value = make(map[interface{}]interface{}) - return nil -} - -// SessionID get this id of memory session store -func (st *MemSessionStore) SessionID() string { - return st.sid -} - -// SessionRelease Implement method, no used. -func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { -} - -// MemProvider Implement the provider interface -type MemProvider struct { - lock sync.RWMutex // locker - sessions map[string]*list.Element // map in memory - list *list.List // for gc - maxlifetime int64 - savePath string -} - -// SessionInit init memory session -func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { - pder.maxlifetime = maxlifetime - pder.savePath = savePath - return nil -} - -// SessionRead get memory session store by sid -func (pder *MemProvider) SessionRead(sid string) (Store, error) { - pder.lock.RLock() - if element, ok := pder.sessions[sid]; ok { - go pder.SessionUpdate(sid) - pder.lock.RUnlock() - return element.Value.(*MemSessionStore), nil - } - pder.lock.RUnlock() - pder.lock.Lock() - newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} - element := pder.list.PushFront(newsess) - pder.sessions[sid] = element - pder.lock.Unlock() - return newsess, nil -} - -// SessionExist check session store exist in memory session by sid -func (pder *MemProvider) SessionExist(sid string) bool { - pder.lock.RLock() - defer pder.lock.RUnlock() - if _, ok := pder.sessions[sid]; ok { - return true - } - return false -} - -// SessionRegenerate generate new sid for session store in memory session -func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { - pder.lock.RLock() - if element, ok := pder.sessions[oldsid]; ok { - go pder.SessionUpdate(oldsid) - pder.lock.RUnlock() - pder.lock.Lock() - element.Value.(*MemSessionStore).sid = sid - pder.sessions[sid] = element - delete(pder.sessions, oldsid) - pder.lock.Unlock() - return element.Value.(*MemSessionStore), nil - } - pder.lock.RUnlock() - pder.lock.Lock() - newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} - element := pder.list.PushFront(newsess) - pder.sessions[sid] = element - pder.lock.Unlock() - return newsess, nil -} - -// SessionDestroy delete session store in memory session by id -func (pder *MemProvider) SessionDestroy(sid string) error { - pder.lock.Lock() - defer pder.lock.Unlock() - if element, ok := pder.sessions[sid]; ok { - delete(pder.sessions, sid) - pder.list.Remove(element) - return nil - } - return nil -} - -// SessionGC clean expired session stores in memory session -func (pder *MemProvider) SessionGC() { - pder.lock.RLock() - for { - element := pder.list.Back() - if element == nil { - break - } - if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() { - pder.lock.RUnlock() - pder.lock.Lock() - pder.list.Remove(element) - delete(pder.sessions, element.Value.(*MemSessionStore).sid) - pder.lock.Unlock() - pder.lock.RLock() - } else { - break - } - } - pder.lock.RUnlock() -} - -// SessionAll get count number of memory session -func (pder *MemProvider) SessionAll() int { - return pder.list.Len() -} - -// SessionUpdate expand time of session store by id in memory session -func (pder *MemProvider) SessionUpdate(sid string) error { - pder.lock.Lock() - defer pder.lock.Unlock() - if element, ok := pder.sessions[sid]; ok { - element.Value.(*MemSessionStore).timeAccessed = time.Now() - pder.list.MoveToFront(element) - return nil - } - return nil -} - -func init() { - Register("memory", mempder) -} diff --git a/session/sess_mem_test.go b/session/sess_mem_test.go deleted file mode 100644 index 2e8934b8..00000000 --- a/session/sess_mem_test.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -func TestMem(t *testing.T) { - config := `{"cookieName":"gosessionid","gclifetime":10, "enableSetCookie":true}` - conf := new(ManagerConfig) - if err := json.Unmarshal([]byte(config), conf); err != nil { - t.Fatal("json decode error", err) - } - globalSessions, _ := NewManager("memory", conf) - go globalSessions.GC() - r, _ := http.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - sess, err := globalSessions.SessionStart(w, r) - if err != nil { - t.Fatal("set error,", err) - } - defer sess.SessionRelease(w) - err = sess.Set("username", "astaxie") - if err != nil { - t.Fatal("set error,", err) - } - if username := sess.Get("username"); username != "astaxie" { - t.Fatal("get username error") - } - if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { - t.Fatal("setcookie error") - } else { - parts := strings.Split(strings.TrimSpace(cookiestr), ";") - for k, v := range parts { - nameval := strings.Split(v, "=") - if k == 0 && nameval[0] != "gosessionid" { - t.Fatal("error") - } - } - } -} diff --git a/session/sess_test.go b/session/sess_test.go deleted file mode 100644 index 906abec2..00000000 --- a/session/sess_test.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "crypto/aes" - "encoding/json" - "testing" -) - -func Test_gob(t *testing.T) { - a := make(map[interface{}]interface{}) - a["username"] = "astaxie" - a[12] = 234 - a["user"] = User{"asta", "xie"} - b, err := EncodeGob(a) - if err != nil { - t.Error(err) - } - c, err := DecodeGob(b) - if err != nil { - t.Error(err) - } - if len(c) == 0 { - t.Error("decodeGob empty") - } - if c["username"] != "astaxie" { - t.Error("decode string error") - } - if c[12] != 234 { - t.Error("decode int error") - } - if c["user"].(User).Username != "asta" { - t.Error("decode struct error") - } -} - -type User struct { - Username string - NickName string -} - -func TestGenerate(t *testing.T) { - str := generateRandomKey(20) - if len(str) != 20 { - t.Fatal("generate length is not equal to 20") - } -} - -func TestCookieEncodeDecode(t *testing.T) { - hashKey := "testhashKey" - blockkey := generateRandomKey(16) - block, err := aes.NewCipher(blockkey) - if err != nil { - t.Fatal("NewCipher:", err) - } - securityName := string(generateRandomKey(20)) - val := make(map[interface{}]interface{}) - val["name"] = "astaxie" - val["gender"] = "male" - str, err := encodeCookie(block, hashKey, securityName, val) - if err != nil { - t.Fatal("encodeCookie:", err) - } - dst, err := decodeCookie(block, hashKey, securityName, str, 3600) - if err != nil { - t.Fatal("decodeCookie", err) - } - if dst["name"] != "astaxie" { - t.Fatal("dst get map error") - } - if dst["gender"] != "male" { - t.Fatal("dst get map error") - } -} - -func TestParseConfig(t *testing.T) { - s := `{"cookieName":"gosessionid","gclifetime":3600}` - cf := new(ManagerConfig) - cf.EnableSetCookie = true - err := json.Unmarshal([]byte(s), cf) - if err != nil { - t.Fatal("parse json error,", err) - } - if cf.CookieName != "gosessionid" { - t.Fatal("parseconfig get cookiename error") - } - if cf.Gclifetime != 3600 { - t.Fatal("parseconfig get gclifetime error") - } - - cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` - cf2 := new(ManagerConfig) - cf2.EnableSetCookie = true - err = json.Unmarshal([]byte(cc), cf2) - if err != nil { - t.Fatal("parse json error,", err) - } - if cf2.CookieName != "gosessionid" { - t.Fatal("parseconfig get cookiename error") - } - if cf2.Gclifetime != 3600 { - t.Fatal("parseconfig get gclifetime error") - } - if cf2.EnableSetCookie { - t.Fatal("parseconfig get enableSetCookie error") - } - cconfig := new(cookieConfig) - err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig) - if err != nil { - t.Fatal("parse ProviderConfig err,", err) - } - if cconfig.CookieName != "gosessionid" { - t.Fatal("ProviderConfig get cookieName error") - } - if cconfig.SecurityKey != "beegocookiehashkey" { - t.Fatal("ProviderConfig get securityKey error") - } -} diff --git a/session/sess_utils.go b/session/sess_utils.go deleted file mode 100644 index 20915bb6..00000000 --- a/session/sess_utils.go +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "bytes" - "crypto/cipher" - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "crypto/subtle" - "encoding/base64" - "encoding/gob" - "errors" - "fmt" - "io" - "strconv" - "time" - - "github.com/astaxie/beego/utils" -) - -func init() { - gob.Register([]interface{}{}) - gob.Register(map[int]interface{}{}) - gob.Register(map[string]interface{}{}) - gob.Register(map[interface{}]interface{}{}) - gob.Register(map[string]string{}) - gob.Register(map[int]string{}) - gob.Register(map[int]int{}) - gob.Register(map[int]int64{}) -} - -// EncodeGob encode the obj to gob -func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) { - for _, v := range obj { - gob.Register(v) - } - buf := bytes.NewBuffer(nil) - enc := gob.NewEncoder(buf) - err := enc.Encode(obj) - if err != nil { - return []byte(""), err - } - return buf.Bytes(), nil -} - -// DecodeGob decode data to map -func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) { - buf := bytes.NewBuffer(encoded) - dec := gob.NewDecoder(buf) - var out map[interface{}]interface{} - err := dec.Decode(&out) - if err != nil { - return nil, err - } - return out, nil -} - -// generateRandomKey creates a random key with the given strength. -func generateRandomKey(strength int) []byte { - k := make([]byte, strength) - if n, err := io.ReadFull(rand.Reader, k); n != strength || err != nil { - return utils.RandomCreateBytes(strength) - } - return k -} - -// Encryption ----------------------------------------------------------------- - -// encrypt encrypts a value using the given block in counter mode. -// -// A random initialization vector (http://goo.gl/zF67k) with the length of the -// block size is prepended to the resulting ciphertext. -func encrypt(block cipher.Block, value []byte) ([]byte, error) { - iv := generateRandomKey(block.BlockSize()) - if iv == nil { - return nil, errors.New("encrypt: failed to generate random iv") - } - // Encrypt it. - stream := cipher.NewCTR(block, iv) - stream.XORKeyStream(value, value) - // Return iv + ciphertext. - return append(iv, value...), nil -} - -// decrypt decrypts a value using the given block in counter mode. -// -// The value to be decrypted must be prepended by a initialization vector -// (http://goo.gl/zF67k) with the length of the block size. -func decrypt(block cipher.Block, value []byte) ([]byte, error) { - size := block.BlockSize() - if len(value) > size { - // Extract iv. - iv := value[:size] - // Extract ciphertext. - value = value[size:] - // Decrypt it. - stream := cipher.NewCTR(block, iv) - stream.XORKeyStream(value, value) - return value, nil - } - return nil, errors.New("decrypt: the value could not be decrypted") -} - -func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) { - var err error - var b []byte - // 1. EncodeGob. - if b, err = EncodeGob(value); err != nil { - return "", err - } - // 2. Encrypt (optional). - if b, err = encrypt(block, b); err != nil { - return "", err - } - b = encode(b) - // 3. Create MAC for "name|date|value". Extra pipe to be used later. - b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b)) - h := hmac.New(sha256.New, []byte(hashKey)) - h.Write(b) - sig := h.Sum(nil) - // Append mac, remove name. - b = append(b, sig...)[len(name)+1:] - // 4. Encode to base64. - b = encode(b) - // Done. - return string(b), nil -} - -func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) { - // 1. Decode from base64. - b, err := decode([]byte(value)) - if err != nil { - return nil, err - } - // 2. Verify MAC. Value is "date|value|mac". - parts := bytes.SplitN(b, []byte("|"), 3) - if len(parts) != 3 { - return nil, errors.New("Decode: invalid value format") - } - - b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...) - h := hmac.New(sha256.New, []byte(hashKey)) - h.Write(b) - sig := h.Sum(nil) - if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 { - return nil, errors.New("Decode: the value is not valid") - } - // 3. Verify date ranges. - var t1 int64 - if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil { - return nil, errors.New("Decode: invalid timestamp") - } - t2 := time.Now().UTC().Unix() - if t1 > t2 { - return nil, errors.New("Decode: timestamp is too new") - } - if t1 < t2-gcmaxlifetime { - return nil, errors.New("Decode: expired timestamp") - } - // 4. Decrypt (optional). - b, err = decode(parts[1]) - if err != nil { - return nil, err - } - if b, err = decrypt(block, b); err != nil { - return nil, err - } - // 5. DecodeGob. - dst, err := DecodeGob(b) - if err != nil { - return nil, err - } - return dst, nil -} - -// Encoding ------------------------------------------------------------------- - -// encode encodes a value using base64. -func encode(value []byte) []byte { - encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value))) - base64.URLEncoding.Encode(encoded, value) - return encoded -} - -// decode decodes a cookie using base64. -func decode(value []byte) ([]byte, error) { - decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value))) - b, err := base64.URLEncoding.Decode(decoded, value) - if err != nil { - return nil, err - } - return decoded[:b], nil -} diff --git a/session/session.go b/session/session.go deleted file mode 100644 index eb85360a..00000000 --- a/session/session.go +++ /dev/null @@ -1,377 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package session provider -// -// Usage: -// import( -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package session - -import ( - "crypto/rand" - "encoding/hex" - "errors" - "fmt" - "io" - "log" - "net/http" - "net/textproto" - "net/url" - "os" - "time" -) - -// Store contains all data for one session process with specific id. -type Store interface { - Set(key, value interface{}) error //set session value - Get(key interface{}) interface{} //get session value - Delete(key interface{}) error //delete session value - SessionID() string //back current sessionID - SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data - Flush() error //delete all data -} - -// Provider contains global session methods and saved SessionStores. -// it can operate a SessionStore by its id. -type Provider interface { - SessionInit(gclifetime int64, config string) error - SessionRead(sid string) (Store, error) - SessionExist(sid string) bool - SessionRegenerate(oldsid, sid string) (Store, error) - SessionDestroy(sid string) error - SessionAll() int //get all active session - SessionGC() -} - -var provides = make(map[string]Provider) - -// SLogger a helpful variable to log information about session -var SLogger = NewSessionLog(os.Stderr) - -// Register makes a session provide available by the provided name. -// If Register is called twice with the same name or if driver is nil, -// it panics. -func Register(name string, provide Provider) { - if provide == nil { - panic("session: Register provide is nil") - } - if _, dup := provides[name]; dup { - panic("session: Register called twice for provider " + name) - } - provides[name] = provide -} - -//GetProvider -func GetProvider(name string) (Provider, error) { - provider, ok := provides[name] - if !ok { - return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", name) - } - return provider, nil -} - -// ManagerConfig define the session config -type ManagerConfig struct { - CookieName string `json:"cookieName"` - EnableSetCookie bool `json:"enableSetCookie,omitempty"` - Gclifetime int64 `json:"gclifetime"` - Maxlifetime int64 `json:"maxLifetime"` - DisableHTTPOnly bool `json:"disableHTTPOnly"` - Secure bool `json:"secure"` - CookieLifeTime int `json:"cookieLifeTime"` - ProviderConfig string `json:"providerConfig"` - Domain string `json:"domain"` - SessionIDLength int64 `json:"sessionIDLength"` - EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` - SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` - EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` - SessionIDPrefix string `json:"sessionIDPrefix"` -} - -// Manager contains Provider and its configuration. -type Manager struct { - provider Provider - config *ManagerConfig -} - -// NewManager Create new Manager with provider name and json config string. -// provider name: -// 1. cookie -// 2. file -// 3. memory -// 4. redis -// 5. mysql -// json config: -// 1. is https default false -// 2. hashfunc default sha1 -// 3. hashkey default beegosessionkey -// 4. maxage default is none -func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { - provider, ok := provides[provideName] - if !ok { - return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) - } - - if cf.Maxlifetime == 0 { - cf.Maxlifetime = cf.Gclifetime - } - - if cf.EnableSidInHTTPHeader { - if cf.SessionNameInHTTPHeader == "" { - panic(errors.New("SessionNameInHTTPHeader is empty")) - } - - strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHTTPHeader) - if cf.SessionNameInHTTPHeader != strMimeHeader { - strErrMsg := "SessionNameInHTTPHeader (" + cf.SessionNameInHTTPHeader + ") has the wrong format, it should be like this : " + strMimeHeader - panic(errors.New(strErrMsg)) - } - } - - err := provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) - if err != nil { - return nil, err - } - - if cf.SessionIDLength == 0 { - cf.SessionIDLength = 16 - } - - return &Manager{ - provider, - cf, - }, nil -} - -// GetProvider return current manager's provider -func (manager *Manager) GetProvider() Provider { - return manager.provider -} - -// getSid retrieves session identifier from HTTP Request. -// First try to retrieve id by reading from cookie, session cookie name is configurable, -// if not exist, then retrieve id from querying parameters. -// -// error is not nil when there is anything wrong. -// sid is empty when need to generate a new session id -// otherwise return an valid session id. -func (manager *Manager) getSid(r *http.Request) (string, error) { - cookie, errs := r.Cookie(manager.config.CookieName) - if errs != nil || cookie.Value == "" { - var sid string - if manager.config.EnableSidInURLQuery { - errs := r.ParseForm() - if errs != nil { - return "", errs - } - - sid = r.FormValue(manager.config.CookieName) - } - - // if not found in Cookie / param, then read it from request headers - if manager.config.EnableSidInHTTPHeader && sid == "" { - sids, isFound := r.Header[manager.config.SessionNameInHTTPHeader] - if isFound && len(sids) != 0 { - return sids[0], nil - } - } - - return sid, nil - } - - // HTTP Request contains cookie for sessionid info. - return url.QueryUnescape(cookie.Value) -} - -// SessionStart generate or read the session id from http request. -// if session id exists, return SessionStore with this id. -func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session Store, err error) { - sid, errs := manager.getSid(r) - if errs != nil { - return nil, errs - } - - if sid != "" && manager.provider.SessionExist(sid) { - return manager.provider.SessionRead(sid) - } - - // Generate a new session - sid, errs = manager.sessionID() - if errs != nil { - return nil, errs - } - - session, err = manager.provider.SessionRead(sid) - if err != nil { - return nil, err - } - cookie := &http.Cookie{ - Name: manager.config.CookieName, - Value: url.QueryEscape(sid), - Path: "/", - HttpOnly: !manager.config.DisableHTTPOnly, - Secure: manager.isSecure(r), - Domain: manager.config.Domain, - } - if manager.config.CookieLifeTime > 0 { - cookie.MaxAge = manager.config.CookieLifeTime - cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) - } - if manager.config.EnableSetCookie { - http.SetCookie(w, cookie) - } - r.AddCookie(cookie) - - if manager.config.EnableSidInHTTPHeader { - r.Header.Set(manager.config.SessionNameInHTTPHeader, sid) - w.Header().Set(manager.config.SessionNameInHTTPHeader, sid) - } - - return -} - -// SessionDestroy Destroy session by its id in http request cookie. -func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { - if manager.config.EnableSidInHTTPHeader { - r.Header.Del(manager.config.SessionNameInHTTPHeader) - w.Header().Del(manager.config.SessionNameInHTTPHeader) - } - - cookie, err := r.Cookie(manager.config.CookieName) - if err != nil || cookie.Value == "" { - return - } - - sid, _ := url.QueryUnescape(cookie.Value) - manager.provider.SessionDestroy(sid) - if manager.config.EnableSetCookie { - expiration := time.Now() - cookie = &http.Cookie{Name: manager.config.CookieName, - Path: "/", - HttpOnly: !manager.config.DisableHTTPOnly, - Expires: expiration, - MaxAge: -1, - Domain: manager.config.Domain} - - http.SetCookie(w, cookie) - } -} - -// GetSessionStore Get SessionStore by its id. -func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) { - sessions, err = manager.provider.SessionRead(sid) - return -} - -// GC Start session gc process. -// it can do gc in times after gc lifetime. -func (manager *Manager) GC() { - manager.provider.SessionGC() - time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) -} - -// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. -func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (session Store) { - sid, err := manager.sessionID() - if err != nil { - return - } - cookie, err := r.Cookie(manager.config.CookieName) - if err != nil || cookie.Value == "" { - //delete old cookie - session, _ = manager.provider.SessionRead(sid) - cookie = &http.Cookie{Name: manager.config.CookieName, - Value: url.QueryEscape(sid), - Path: "/", - HttpOnly: !manager.config.DisableHTTPOnly, - Secure: manager.isSecure(r), - Domain: manager.config.Domain, - } - } else { - oldsid, _ := url.QueryUnescape(cookie.Value) - session, _ = manager.provider.SessionRegenerate(oldsid, sid) - cookie.Value = url.QueryEscape(sid) - cookie.HttpOnly = true - cookie.Path = "/" - } - if manager.config.CookieLifeTime > 0 { - cookie.MaxAge = manager.config.CookieLifeTime - cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) - } - if manager.config.EnableSetCookie { - http.SetCookie(w, cookie) - } - r.AddCookie(cookie) - - if manager.config.EnableSidInHTTPHeader { - r.Header.Set(manager.config.SessionNameInHTTPHeader, sid) - w.Header().Set(manager.config.SessionNameInHTTPHeader, sid) - } - - return -} - -// GetActiveSession Get all active sessions count number. -func (manager *Manager) GetActiveSession() int { - return manager.provider.SessionAll() -} - -// SetSecure Set cookie with https. -func (manager *Manager) SetSecure(secure bool) { - manager.config.Secure = secure -} - -func (manager *Manager) sessionID() (string, error) { - b := make([]byte, manager.config.SessionIDLength) - n, err := rand.Read(b) - if n != len(b) || err != nil { - return "", fmt.Errorf("Could not successfully read from the system CSPRNG") - } - return manager.config.SessionIDPrefix + hex.EncodeToString(b), nil -} - -// Set cookie with https. -func (manager *Manager) isSecure(req *http.Request) bool { - if !manager.config.Secure { - return false - } - if req.URL.Scheme != "" { - return req.URL.Scheme == "https" - } - if req.TLS == nil { - return false - } - return true -} - -// Log implement the log.Logger -type Log struct { - *log.Logger -} - -// NewSessionLog set io.Writer to create a Logger for session. -func NewSessionLog(out io.Writer) *Log { - sl := new(Log) - sl.Logger = log.New(out, "[SESSION]", 1e9) - return sl -} diff --git a/session/ssdb/sess_ssdb.go b/session/ssdb/sess_ssdb.go deleted file mode 100644 index de0c6360..00000000 --- a/session/ssdb/sess_ssdb.go +++ /dev/null @@ -1,199 +0,0 @@ -package ssdb - -import ( - "errors" - "net/http" - "strconv" - "strings" - "sync" - - "github.com/astaxie/beego/session" - "github.com/ssdb/gossdb/ssdb" -) - -var ssdbProvider = &Provider{} - -// Provider holds ssdb client and configs -type Provider struct { - client *ssdb.Client - host string - port int - maxLifetime int64 -} - -func (p *Provider) connectInit() error { - var err error - if p.host == "" || p.port == 0 { - return errors.New("SessionInit First") - } - p.client, err = ssdb.Connect(p.host, p.port) - return err -} - -// SessionInit init the ssdb with the config -func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { - p.maxLifetime = maxLifetime - address := strings.Split(savePath, ":") - p.host = address[0] - - var err error - if p.port, err = strconv.Atoi(address[1]); err != nil { - return err - } - return p.connectInit() -} - -// SessionRead return a ssdb client session Store -func (p *Provider) SessionRead(sid string) (session.Store, error) { - if p.client == nil { - if err := p.connectInit(); err != nil { - return nil, err - } - } - var kv map[interface{}]interface{} - value, err := p.client.Get(sid) - if err != nil { - return nil, err - } - if value == nil || len(value.(string)) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob([]byte(value.(string))) - if err != nil { - return nil, err - } - } - rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client} - return rs, nil -} - -// SessionExist judged whether sid is exist in session -func (p *Provider) SessionExist(sid string) bool { - if p.client == nil { - if err := p.connectInit(); err != nil { - panic(err) - } - } - value, err := p.client.Get(sid) - if err != nil { - panic(err) - } - if value == nil || len(value.(string)) == 0 { - return false - } - return true -} - -// SessionRegenerate regenerate session with new sid and delete oldsid -func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - //conn.Do("setx", key, v, ttl) - if p.client == nil { - if err := p.connectInit(); err != nil { - return nil, err - } - } - value, err := p.client.Get(oldsid) - if err != nil { - return nil, err - } - var kv map[interface{}]interface{} - if value == nil || len(value.(string)) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob([]byte(value.(string))) - if err != nil { - return nil, err - } - _, err = p.client.Del(oldsid) - if err != nil { - return nil, err - } - } - _, e := p.client.Do("setx", sid, value, p.maxLifetime) - if e != nil { - return nil, e - } - rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client} - return rs, nil -} - -// SessionDestroy destroy the sid -func (p *Provider) SessionDestroy(sid string) error { - if p.client == nil { - if err := p.connectInit(); err != nil { - return err - } - } - _, err := p.client.Del(sid) - return err -} - -// SessionGC not implemented -func (p *Provider) SessionGC() { -} - -// SessionAll not implemented -func (p *Provider) SessionAll() int { - return 0 -} - -// SessionStore holds the session information which stored in ssdb -type SessionStore struct { - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxLifetime int64 - client *ssdb.Client -} - -// Set the key and value -func (s *SessionStore) Set(key, value interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - s.values[key] = value - return nil -} - -// Get return the value by the key -func (s *SessionStore) Get(key interface{}) interface{} { - s.lock.Lock() - defer s.lock.Unlock() - if value, ok := s.values[key]; ok { - return value - } - return nil -} - -// Delete the key in session store -func (s *SessionStore) Delete(key interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - delete(s.values, key) - return nil -} - -// Flush delete all keys and values -func (s *SessionStore) Flush() error { - s.lock.Lock() - defer s.lock.Unlock() - s.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID return the sessionID -func (s *SessionStore) SessionID() string { - return s.sid -} - -// SessionRelease Store the keyvalues into ssdb -func (s *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(s.values) - if err != nil { - return - } - s.client.Do("setx", s.sid, string(b), s.maxLifetime) -} - -func init() { - session.Register("ssdb", ssdbProvider) -} diff --git a/staticfile.go b/staticfile.go deleted file mode 100644 index e26776c5..00000000 --- a/staticfile.go +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "bytes" - "errors" - "net/http" - "os" - "path" - "path/filepath" - "strconv" - "strings" - "sync" - "time" - - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" - "github.com/hashicorp/golang-lru" -) - -var errNotStaticRequest = errors.New("request not a static file request") - -func serverStaticRouter(ctx *context.Context) { - if ctx.Input.Method() != "GET" && ctx.Input.Method() != "HEAD" { - return - } - - forbidden, filePath, fileInfo, err := lookupFile(ctx) - if err == errNotStaticRequest { - return - } - - if forbidden { - exception("403", ctx) - return - } - - if filePath == "" || fileInfo == nil { - if BConfig.RunMode == DEV { - logs.Warn("Can't find/open the file:", filePath, err) - } - http.NotFound(ctx.ResponseWriter, ctx.Request) - return - } - if fileInfo.IsDir() { - requestURL := ctx.Input.URL() - if requestURL[len(requestURL)-1] != '/' { - redirectURL := requestURL + "/" - if ctx.Request.URL.RawQuery != "" { - redirectURL = redirectURL + "?" + ctx.Request.URL.RawQuery - } - ctx.Redirect(302, redirectURL) - } else { - //serveFile will list dir - http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) - } - return - } else if fileInfo.Size() > int64(BConfig.WebConfig.StaticCacheFileSize) { - //over size file serve with http module - http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) - return - } - - var enableCompress = BConfig.EnableGzip && isStaticCompress(filePath) - var acceptEncoding string - if enableCompress { - acceptEncoding = context.ParseEncoding(ctx.Request) - } - b, n, sch, reader, err := openFile(filePath, fileInfo, acceptEncoding) - if err != nil { - if BConfig.RunMode == DEV { - logs.Warn("Can't compress the file:", filePath, err) - } - http.NotFound(ctx.ResponseWriter, ctx.Request) - return - } - - if b { - ctx.Output.Header("Content-Encoding", n) - } else { - ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10)) - } - - http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, reader) -} - -type serveContentHolder struct { - data []byte - modTime time.Time - size int64 - originSize int64 //original file size:to judge file changed - encoding string -} - -type serveContentReader struct { - *bytes.Reader -} - -var ( - staticFileLruCache *lru.Cache - lruLock sync.RWMutex -) - -func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, *serveContentReader, error) { - if staticFileLruCache == nil { - //avoid lru cache error - if BConfig.WebConfig.StaticCacheFileNum >= 1 { - staticFileLruCache, _ = lru.New(BConfig.WebConfig.StaticCacheFileNum) - } else { - staticFileLruCache, _ = lru.New(1) - } - } - mapKey := acceptEncoding + ":" + filePath - lruLock.RLock() - var mapFile *serveContentHolder - if cacheItem, ok := staticFileLruCache.Get(mapKey); ok { - mapFile = cacheItem.(*serveContentHolder) - } - lruLock.RUnlock() - if isOk(mapFile, fi) { - reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} - return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil - } - lruLock.Lock() - defer lruLock.Unlock() - if cacheItem, ok := staticFileLruCache.Get(mapKey); ok { - mapFile = cacheItem.(*serveContentHolder) - } - if !isOk(mapFile, fi) { - file, err := os.Open(filePath) - if err != nil { - return false, "", nil, nil, err - } - defer file.Close() - var bufferWriter bytes.Buffer - _, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file) - if err != nil { - return false, "", nil, nil, err - } - mapFile = &serveContentHolder{data: bufferWriter.Bytes(), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), originSize: fi.Size(), encoding: n} - if isOk(mapFile, fi) { - staticFileLruCache.Add(mapKey, mapFile) - } - } - - reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} - return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil -} - -func isOk(s *serveContentHolder, fi os.FileInfo) bool { - if s == nil { - return false - } else if s.size > int64(BConfig.WebConfig.StaticCacheFileSize) { - return false - } - return s.modTime == fi.ModTime() && s.originSize == fi.Size() -} - -// isStaticCompress detect static files -func isStaticCompress(filePath string) bool { - for _, statExtension := range BConfig.WebConfig.StaticExtensionsToGzip { - if strings.HasSuffix(strings.ToLower(filePath), strings.ToLower(statExtension)) { - return true - } - } - return false -} - -// searchFile search the file by url path -// if none the static file prefix matches ,return notStaticRequestErr -func searchFile(ctx *context.Context) (string, os.FileInfo, error) { - requestPath := filepath.ToSlash(filepath.Clean(ctx.Request.URL.Path)) - // special processing : favicon.ico/robots.txt can be in any static dir - if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { - file := path.Join(".", requestPath) - if fi, _ := os.Stat(file); fi != nil { - return file, fi, nil - } - for _, staticDir := range BConfig.WebConfig.StaticDir { - filePath := path.Join(staticDir, requestPath) - if fi, _ := os.Stat(filePath); fi != nil { - return filePath, fi, nil - } - } - return "", nil, errNotStaticRequest - } - - for prefix, staticDir := range BConfig.WebConfig.StaticDir { - if !strings.Contains(requestPath, prefix) { - continue - } - if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { - continue - } - filePath := path.Join(staticDir, requestPath[len(prefix):]) - if fi, err := os.Stat(filePath); fi != nil { - return filePath, fi, err - } - } - return "", nil, errNotStaticRequest -} - -// lookupFile find the file to serve -// if the file is dir ,search the index.html as default file( MUST NOT A DIR also) -// if the index.html not exist or is a dir, give a forbidden response depending on DirectoryIndex -func lookupFile(ctx *context.Context) (bool, string, os.FileInfo, error) { - fp, fi, err := searchFile(ctx) - if fp == "" || fi == nil { - return false, "", nil, err - } - if !fi.IsDir() { - return false, fp, fi, err - } - if requestURL := ctx.Input.URL(); requestURL[len(requestURL)-1] == '/' { - ifp := filepath.Join(fp, "index.html") - if ifi, _ := os.Stat(ifp); ifi != nil && ifi.Mode().IsRegular() { - return false, ifp, ifi, err - } - } - return !BConfig.WebConfig.DirectoryIndex, fp, fi, err -} diff --git a/staticfile_test.go b/staticfile_test.go deleted file mode 100644 index e46c13ec..00000000 --- a/staticfile_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package beego - -import ( - "bytes" - "compress/gzip" - "compress/zlib" - "fmt" - "io" - "io/ioutil" - "os" - "path/filepath" - "testing" -) - -var currentWorkDir, _ = os.Getwd() -var licenseFile = filepath.Join(currentWorkDir, "LICENSE") - -func testOpenFile(encoding string, content []byte, t *testing.T) { - fi, _ := os.Stat(licenseFile) - b, n, sch, reader, err := openFile(licenseFile, fi, encoding) - if err != nil { - t.Log(err) - t.Fail() - } - - t.Log("open static file encoding "+n, b) - - assetOpenFileAndContent(sch, reader, content, t) -} -func TestOpenStaticFile_1(t *testing.T) { - file, _ := os.Open(licenseFile) - content, _ := ioutil.ReadAll(file) - testOpenFile("", content, t) -} - -func TestOpenStaticFileGzip_1(t *testing.T) { - file, _ := os.Open(licenseFile) - var zipBuf bytes.Buffer - fileWriter, _ := gzip.NewWriterLevel(&zipBuf, gzip.BestCompression) - io.Copy(fileWriter, file) - fileWriter.Close() - content, _ := ioutil.ReadAll(&zipBuf) - - testOpenFile("gzip", content, t) -} -func TestOpenStaticFileDeflate_1(t *testing.T) { - file, _ := os.Open(licenseFile) - var zipBuf bytes.Buffer - fileWriter, _ := zlib.NewWriterLevel(&zipBuf, zlib.BestCompression) - io.Copy(fileWriter, file) - fileWriter.Close() - content, _ := ioutil.ReadAll(&zipBuf) - - testOpenFile("deflate", content, t) -} - -func TestStaticCacheWork(t *testing.T) { - encodings := []string{"", "gzip", "deflate"} - - fi, _ := os.Stat(licenseFile) - for _, encoding := range encodings { - _, _, first, _, err := openFile(licenseFile, fi, encoding) - if err != nil { - t.Error(err) - continue - } - - _, _, second, _, err := openFile(licenseFile, fi, encoding) - if err != nil { - t.Error(err) - continue - } - - address1 := fmt.Sprintf("%p", first) - address2 := fmt.Sprintf("%p", second) - if address1 != address2 { - t.Errorf("encoding '%v' can not hit cache", encoding) - } - } -} - -func assetOpenFileAndContent(sch *serveContentHolder, reader *serveContentReader, content []byte, t *testing.T) { - t.Log(sch.size, len(content)) - if sch.size != int64(len(content)) { - t.Log("static content file size not same") - t.Fail() - } - bs, _ := ioutil.ReadAll(reader) - for i, v := range content { - if v != bs[i] { - t.Log("content not same") - t.Fail() - } - } - if staticFileLruCache.Len() == 0 { - t.Log("men map is empty") - t.Fail() - } -} diff --git a/swagger/swagger.go b/swagger/swagger.go deleted file mode 100644 index a55676cd..00000000 --- a/swagger/swagger.go +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Swagger™ is a project used to describe and document RESTful APIs. -// -// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools. -// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software. - -// Package swagger struct definition -package swagger - -// Swagger list the resource -type Swagger struct { - SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"` - Infos Information `json:"info" yaml:"info"` - Host string `json:"host,omitempty" yaml:"host,omitempty"` - BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"` - Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` - Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` - Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` - Paths map[string]*Item `json:"paths" yaml:"paths"` - Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"` - SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"` - Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` - Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"` - ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` -} - -// Information Provides metadata about the API. The metadata can be used by the clients if needed. -type Information struct { - Title string `json:"title,omitempty" yaml:"title,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - Version string `json:"version,omitempty" yaml:"version,omitempty"` - TermsOfService string `json:"termsOfService,omitempty" yaml:"termsOfService,omitempty"` - - Contact Contact `json:"contact,omitempty" yaml:"contact,omitempty"` - License *License `json:"license,omitempty" yaml:"license,omitempty"` -} - -// Contact information for the exposed API. -type Contact struct { - Name string `json:"name,omitempty" yaml:"name,omitempty"` - URL string `json:"url,omitempty" yaml:"url,omitempty"` - EMail string `json:"email,omitempty" yaml:"email,omitempty"` -} - -// License information for the exposed API. -type License struct { - Name string `json:"name,omitempty" yaml:"name,omitempty"` - URL string `json:"url,omitempty" yaml:"url,omitempty"` -} - -// Item Describes the operations available on a single path. -type Item struct { - Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` - Get *Operation `json:"get,omitempty" yaml:"get,omitempty"` - Put *Operation `json:"put,omitempty" yaml:"put,omitempty"` - Post *Operation `json:"post,omitempty" yaml:"post,omitempty"` - Delete *Operation `json:"delete,omitempty" yaml:"delete,omitempty"` - Options *Operation `json:"options,omitempty" yaml:"options,omitempty"` - Head *Operation `json:"head,omitempty" yaml:"head,omitempty"` - Patch *Operation `json:"patch,omitempty" yaml:"patch,omitempty"` -} - -// Operation Describes a single API operation on a path. -type Operation struct { - Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"` - Summary string `json:"summary,omitempty" yaml:"summary,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"` - Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` - Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` - Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` - Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"` - Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"` - Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` - Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"` -} - -// Parameter Describes a single operation parameter. -type Parameter struct { - In string `json:"in,omitempty" yaml:"in,omitempty"` - Name string `json:"name,omitempty" yaml:"name,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - Required bool `json:"required,omitempty" yaml:"required,omitempty"` - Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"` - Type string `json:"type,omitempty" yaml:"type,omitempty"` - Format string `json:"format,omitempty" yaml:"format,omitempty"` - Items *ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` - Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` -} - -// ParameterItems A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body". -// http://swagger.io/specification/#itemsObject -type ParameterItems struct { - Type string `json:"type,omitempty" yaml:"type,omitempty"` - Format string `json:"format,omitempty" yaml:"format,omitempty"` - Items []*ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` //Required if type is "array". Describes the type of items in the array. - CollectionFormat string `json:"collectionFormat,omitempty" yaml:"collectionFormat,omitempty"` - Default string `json:"default,omitempty" yaml:"default,omitempty"` -} - -// Schema Object allows the definition of input and output data types. -type Schema struct { - Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` - Title string `json:"title,omitempty" yaml:"title,omitempty"` - Format string `json:"format,omitempty" yaml:"format,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - Required []string `json:"required,omitempty" yaml:"required,omitempty"` - Type string `json:"type,omitempty" yaml:"type,omitempty"` - Items *Schema `json:"items,omitempty" yaml:"items,omitempty"` - Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"` - Enum []interface{} `json:"enum,omitempty" yaml:"enum,omitempty"` - Example interface{} `json:"example,omitempty" yaml:"example,omitempty"` -} - -// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification -type Propertie struct { - Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` - Title string `json:"title,omitempty" yaml:"title,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` - Type string `json:"type,omitempty" yaml:"type,omitempty"` - Example interface{} `json:"example,omitempty" yaml:"example,omitempty"` - Required []string `json:"required,omitempty" yaml:"required,omitempty"` - Format string `json:"format,omitempty" yaml:"format,omitempty"` - ReadOnly bool `json:"readOnly,omitempty" yaml:"readOnly,omitempty"` - Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"` - Items *Propertie `json:"items,omitempty" yaml:"items,omitempty"` - AdditionalProperties *Propertie `json:"additionalProperties,omitempty" yaml:"additionalProperties,omitempty"` -} - -// Response as they are returned from executing this operation. -type Response struct { - Description string `json:"description" yaml:"description"` - Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"` - Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` -} - -// Security Allows the definition of a security scheme that can be used by the operations -type Security struct { - Type string `json:"type,omitempty" yaml:"type,omitempty"` // Valid values are "basic", "apiKey" or "oauth2". - Description string `json:"description,omitempty" yaml:"description,omitempty"` - Name string `json:"name,omitempty" yaml:"name,omitempty"` - In string `json:"in,omitempty" yaml:"in,omitempty"` // Valid values are "query" or "header". - Flow string `json:"flow,omitempty" yaml:"flow,omitempty"` // Valid values are "implicit", "password", "application" or "accessCode". - AuthorizationURL string `json:"authorizationUrl,omitempty" yaml:"authorizationUrl,omitempty"` - TokenURL string `json:"tokenUrl,omitempty" yaml:"tokenUrl,omitempty"` - Scopes map[string]string `json:"scopes,omitempty" yaml:"scopes,omitempty"` // The available scopes for the OAuth2 security scheme. -} - -// Tag Allows adding meta data to a single tag that is used by the Operation Object -type Tag struct { - Name string `json:"name,omitempty" yaml:"name,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` -} - -// ExternalDocs include Additional external documentation -type ExternalDocs struct { - Description string `json:"description,omitempty" yaml:"description,omitempty"` - URL string `json:"url,omitempty" yaml:"url,omitempty"` -} diff --git a/template.go b/template.go deleted file mode 100644 index 69b178ca..00000000 --- a/template.go +++ /dev/null @@ -1,417 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "errors" - "fmt" - "html/template" - "io" - "io/ioutil" - "net/http" - "os" - "path/filepath" - "regexp" - "strings" - "sync" - - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" -) - -var ( - beegoTplFuncMap = make(template.FuncMap) - beeViewPathTemplateLocked = false - // beeViewPathTemplates caching map and supported template file extensions per view - beeViewPathTemplates = make(map[string]map[string]*template.Template) - templatesLock sync.RWMutex - // beeTemplateExt stores the template extension which will build - beeTemplateExt = []string{"tpl", "html", "gohtml"} - // beeTemplatePreprocessors stores associations of extension -> preprocessor handler - beeTemplateEngines = map[string]templatePreProcessor{} - beeTemplateFS = defaultFSFunc -) - -// ExecuteTemplate applies the template with name to the specified data object, -// writing the output to wr. -// A template will be executed safely in parallel. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { - return ExecuteViewPathTemplate(wr, name, BConfig.WebConfig.ViewsPath, data) -} - -// ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object, -// writing the output to wr. -// A template will be executed safely in parallel. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error { - if BConfig.RunMode == DEV { - templatesLock.RLock() - defer templatesLock.RUnlock() - } - if beeTemplates, ok := beeViewPathTemplates[viewPath]; ok { - if t, ok := beeTemplates[name]; ok { - var err error - if t.Lookup(name) != nil { - err = t.ExecuteTemplate(wr, name, data) - } else { - err = t.Execute(wr, data) - } - if err != nil { - logs.Trace("template Execute err:", err) - } - return err - } - panic("can't find templatefile in the path:" + viewPath + "/" + name) - } - panic("Unknown view path:" + viewPath) -} - -func init() { - beegoTplFuncMap["dateformat"] = DateFormat - beegoTplFuncMap["date"] = Date - beegoTplFuncMap["compare"] = Compare - beegoTplFuncMap["compare_not"] = CompareNot - beegoTplFuncMap["not_nil"] = NotNil - beegoTplFuncMap["not_null"] = NotNil - beegoTplFuncMap["substr"] = Substr - beegoTplFuncMap["html2str"] = HTML2str - beegoTplFuncMap["str2html"] = Str2html - beegoTplFuncMap["htmlquote"] = Htmlquote - beegoTplFuncMap["htmlunquote"] = Htmlunquote - beegoTplFuncMap["renderform"] = RenderForm - beegoTplFuncMap["assets_js"] = AssetsJs - beegoTplFuncMap["assets_css"] = AssetsCSS - beegoTplFuncMap["config"] = GetConfig - beegoTplFuncMap["map_get"] = MapGet - - // Comparisons - beegoTplFuncMap["eq"] = eq // == - beegoTplFuncMap["ge"] = ge // >= - beegoTplFuncMap["gt"] = gt // > - beegoTplFuncMap["le"] = le // <= - beegoTplFuncMap["lt"] = lt // < - beegoTplFuncMap["ne"] = ne // != - - beegoTplFuncMap["urlfor"] = URLFor // build a URL to match a Controller and it's method -} - -// AddFuncMap let user to register a func in the template. -func AddFuncMap(key string, fn interface{}) error { - beegoTplFuncMap[key] = fn - return nil -} - -type templatePreProcessor func(root, path string, funcs template.FuncMap) (*template.Template, error) - -type templateFile struct { - root string - files map[string][]string -} - -// visit will make the paths into two part,the first is subDir (without tf.root),the second is full path(without tf.root). -// if tf.root="views" and -// paths is "views/errors/404.html",the subDir will be "errors",the file will be "errors/404.html" -// paths is "views/admin/errors/404.html",the subDir will be "admin/errors",the file will be "admin/errors/404.html" -func (tf *templateFile) visit(paths string, f os.FileInfo, err error) error { - if f == nil { - return err - } - if f.IsDir() || (f.Mode()&os.ModeSymlink) > 0 { - return nil - } - if !HasTemplateExt(paths) { - return nil - } - - replace := strings.NewReplacer("\\", "/") - file := strings.TrimLeft(replace.Replace(paths[len(tf.root):]), "/") - subDir := filepath.Dir(file) - - tf.files[subDir] = append(tf.files[subDir], file) - return nil -} - -// HasTemplateExt return this path contains supported template extension of beego or not. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func HasTemplateExt(paths string) bool { - for _, v := range beeTemplateExt { - if strings.HasSuffix(paths, "."+v) { - return true - } - } - return false -} - -// AddTemplateExt add new extension for template. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AddTemplateExt(ext string) { - for _, v := range beeTemplateExt { - if v == ext { - return - } - } - beeTemplateExt = append(beeTemplateExt, ext) -} - -// AddViewPath adds a new path to the supported view paths. -//Can later be used by setting a controller ViewPath to this folder -//will panic if called after beego.Run() -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AddViewPath(viewPath string) error { - if beeViewPathTemplateLocked { - if _, exist := beeViewPathTemplates[viewPath]; exist { - return nil //Ignore if viewpath already exists - } - panic("Can not add new view paths after beego.Run()") - } - beeViewPathTemplates[viewPath] = make(map[string]*template.Template) - return BuildTemplate(viewPath) -} - -func lockViewPaths() { - beeViewPathTemplateLocked = true -} - -// BuildTemplate will build all template files in a directory. -// it makes beego can render any template file in view directory. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func BuildTemplate(dir string, files ...string) error { - var err error - fs := beeTemplateFS() - f, err := fs.Open(dir) - if err != nil { - if os.IsNotExist(err) { - return nil - } - return errors.New("dir open err") - } - defer f.Close() - - beeTemplates, ok := beeViewPathTemplates[dir] - if !ok { - panic("Unknown view path: " + dir) - } - self := &templateFile{ - root: dir, - files: make(map[string][]string), - } - err = Walk(fs, dir, func(path string, f os.FileInfo, err error) error { - return self.visit(path, f, err) - }) - if err != nil { - fmt.Printf("Walk() returned %v\n", err) - return err - } - buildAllFiles := len(files) == 0 - for _, v := range self.files { - for _, file := range v { - if buildAllFiles || utils.InSlice(file, files) { - templatesLock.Lock() - ext := filepath.Ext(file) - var t *template.Template - if len(ext) == 0 { - t, err = getTemplate(self.root, fs, file, v...) - } else if fn, ok := beeTemplateEngines[ext[1:]]; ok { - t, err = fn(self.root, file, beegoTplFuncMap) - } else { - t, err = getTemplate(self.root, fs, file, v...) - } - if err != nil { - logs.Error("parse template err:", file, err) - templatesLock.Unlock() - return err - } - beeTemplates[file] = t - templatesLock.Unlock() - } - } - } - return nil -} - -func getTplDeep(root string, fs http.FileSystem, file string, parent string, t *template.Template) (*template.Template, [][]string, error) { - var fileAbsPath string - var rParent string - var err error - if strings.HasPrefix(file, "../") { - rParent = filepath.Join(filepath.Dir(parent), file) - fileAbsPath = filepath.Join(root, filepath.Dir(parent), file) - } else { - rParent = file - fileAbsPath = filepath.Join(root, file) - } - f, err := fs.Open(fileAbsPath) - if err != nil { - panic("can't find template file:" + file) - } - defer f.Close() - data, err := ioutil.ReadAll(f) - if err != nil { - return nil, [][]string{}, err - } - t, err = t.New(file).Parse(string(data)) - if err != nil { - return nil, [][]string{}, err - } - reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*template[ ]+\"([^\"]+)\"") - allSub := reg.FindAllStringSubmatch(string(data), -1) - for _, m := range allSub { - if len(m) == 2 { - tl := t.Lookup(m[1]) - if tl != nil { - continue - } - if !HasTemplateExt(m[1]) { - continue - } - _, _, err = getTplDeep(root, fs, m[1], rParent, t) - if err != nil { - return nil, [][]string{}, err - } - } - } - return t, allSub, nil -} - -func getTemplate(root string, fs http.FileSystem, file string, others ...string) (t *template.Template, err error) { - t = template.New(file).Delims(BConfig.WebConfig.TemplateLeft, BConfig.WebConfig.TemplateRight).Funcs(beegoTplFuncMap) - var subMods [][]string - t, subMods, err = getTplDeep(root, fs, file, "", t) - if err != nil { - return nil, err - } - t, err = _getTemplate(t, root, fs, subMods, others...) - - if err != nil { - return nil, err - } - return -} - -func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMods [][]string, others ...string) (t *template.Template, err error) { - t = t0 - for _, m := range subMods { - if len(m) == 2 { - tpl := t.Lookup(m[1]) - if tpl != nil { - continue - } - //first check filename - for _, otherFile := range others { - if otherFile == m[1] { - var subMods1 [][]string - t, subMods1, err = getTplDeep(root, fs, otherFile, "", t) - if err != nil { - logs.Trace("template parse file err:", err) - } else if len(subMods1) > 0 { - t, err = _getTemplate(t, root, fs, subMods1, others...) - } - break - } - } - //second check define - for _, otherFile := range others { - var data []byte - fileAbsPath := filepath.Join(root, otherFile) - f, err := fs.Open(fileAbsPath) - if err != nil { - f.Close() - logs.Trace("template file parse error, not success open file:", err) - continue - } - data, err = ioutil.ReadAll(f) - f.Close() - if err != nil { - logs.Trace("template file parse error, not success read file:", err) - continue - } - reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"") - allSub := reg.FindAllStringSubmatch(string(data), -1) - for _, sub := range allSub { - if len(sub) == 2 && sub[1] == m[1] { - var subMods1 [][]string - t, subMods1, err = getTplDeep(root, fs, otherFile, "", t) - if err != nil { - logs.Trace("template parse file err:", err) - } else if len(subMods1) > 0 { - t, err = _getTemplate(t, root, fs, subMods1, others...) - if err != nil { - logs.Trace("template parse file err:", err) - } - } - break - } - } - } - } - - } - return -} - -type templateFSFunc func() http.FileSystem - -func defaultFSFunc() http.FileSystem { - return FileSystem{} -} - -// SetTemplateFSFunc set default filesystem function -// Deprecated: using pkg/, we will delete this in v2.1.0 -func SetTemplateFSFunc(fnt templateFSFunc) { - beeTemplateFS = fnt -} - -// SetViewsPath sets view directory path in beego application. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func SetViewsPath(path string) *App { - BConfig.WebConfig.ViewsPath = path - return BeeApp -} - -// SetStaticPath sets static directory path and proper url pattern in beego application. -// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public". -// Deprecated: using pkg/, we will delete this in v2.1.0 -func SetStaticPath(url string, path string) *App { - if !strings.HasPrefix(url, "/") { - url = "/" + url - } - if url != "/" { - url = strings.TrimRight(url, "/") - } - BConfig.WebConfig.StaticDir[url] = path - return BeeApp -} - -// DelStaticPath removes the static folder setting in this url pattern in beego application. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func DelStaticPath(url string) *App { - if !strings.HasPrefix(url, "/") { - url = "/" + url - } - if url != "/" { - url = strings.TrimRight(url, "/") - } - delete(BConfig.WebConfig.StaticDir, url) - return BeeApp -} - -// AddTemplateEngine add a new templatePreProcessor which support extension -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AddTemplateEngine(extension string, fn templatePreProcessor) *App { - AddTemplateExt(extension) - beeTemplateEngines[extension] = fn - return BeeApp -} diff --git a/template_test.go b/template_test.go deleted file mode 100644 index bde9c100..00000000 --- a/template_test.go +++ /dev/null @@ -1,316 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "bytes" - "github.com/astaxie/beego/test" - "github.com/elazarl/go-bindata-assetfs" - "net/http" - "os" - "path/filepath" - "testing" -) - -var header = `{{define "header"}} -

Hello, astaxie!

-{{end}}` - -var index = ` - - - beego welcome template - - -{{template "block"}} -{{template "header"}} -{{template "blocks/block.tpl"}} - - -` - -var block = `{{define "block"}} -

Hello, blocks!

-{{end}}` - -func TestTemplate(t *testing.T) { - dir := "_beeTmp" - files := []string{ - "header.tpl", - "index.tpl", - "blocks/block.tpl", - } - if err := os.MkdirAll(dir, 0777); err != nil { - t.Fatal(err) - } - for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) - if f, err := os.Create(filepath.Join(dir, name)); err != nil { - t.Fatal(err) - } else { - if k == 0 { - f.WriteString(header) - } else if k == 1 { - f.WriteString(index) - } else if k == 2 { - f.WriteString(block) - } - - f.Close() - } - } - if err := AddViewPath(dir); err != nil { - t.Fatal(err) - } - beeTemplates := beeViewPathTemplates[dir] - if len(beeTemplates) != 3 { - t.Fatalf("should be 3 but got %v", len(beeTemplates)) - } - if err := beeTemplates["index.tpl"].ExecuteTemplate(os.Stdout, "index.tpl", nil); err != nil { - t.Fatal(err) - } - for _, name := range files { - os.RemoveAll(filepath.Join(dir, name)) - } - os.RemoveAll(dir) -} - -var menu = ` -` -var user = ` - - - beego welcome template - - -{{template "../public/menu.tpl"}} - - -` - -func TestRelativeTemplate(t *testing.T) { - dir := "_beeTmp" - - //Just add dir to known viewPaths - if err := AddViewPath(dir); err != nil { - t.Fatal(err) - } - - files := []string{ - "easyui/public/menu.tpl", - "easyui/rbac/user.tpl", - } - if err := os.MkdirAll(dir, 0777); err != nil { - t.Fatal(err) - } - for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) - if f, err := os.Create(filepath.Join(dir, name)); err != nil { - t.Fatal(err) - } else { - if k == 0 { - f.WriteString(menu) - } else if k == 1 { - f.WriteString(user) - } - f.Close() - } - } - if err := BuildTemplate(dir, files[1]); err != nil { - t.Fatal(err) - } - beeTemplates := beeViewPathTemplates[dir] - if err := beeTemplates["easyui/rbac/user.tpl"].ExecuteTemplate(os.Stdout, "easyui/rbac/user.tpl", nil); err != nil { - t.Fatal(err) - } - for _, name := range files { - os.RemoveAll(filepath.Join(dir, name)) - } - os.RemoveAll(dir) -} - -var add = `{{ template "layout_blog.tpl" . }} -{{ define "css" }} - -{{ end}} - - -{{ define "content" }} -

{{ .Title }}

-

This is SomeVar: {{ .SomeVar }}

-{{ end }} - -{{ define "js" }} - -{{ end}}` - -var layoutBlog = ` - - - Lin Li - - - - - {{ block "css" . }}{{ end }} - - - -
- {{ block "content" . }}{{ end }} -
- - - {{ block "js" . }}{{ end }} - -` - -var output = ` - - - Lin Li - - - - - - - - - - -
- -

Hello

-

This is SomeVar: val

- -
- - - - - - - - - - - - -` - -func TestTemplateLayout(t *testing.T) { - dir := "_beeTmp" - files := []string{ - "add.tpl", - "layout_blog.tpl", - } - if err := os.MkdirAll(dir, 0777); err != nil { - t.Fatal(err) - } - for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) - if f, err := os.Create(filepath.Join(dir, name)); err != nil { - t.Fatal(err) - } else { - if k == 0 { - f.WriteString(add) - } else if k == 1 { - f.WriteString(layoutBlog) - } - f.Close() - } - } - if err := AddViewPath(dir); err != nil { - t.Fatal(err) - } - beeTemplates := beeViewPathTemplates[dir] - if len(beeTemplates) != 2 { - t.Fatalf("should be 2 but got %v", len(beeTemplates)) - } - out := bytes.NewBufferString("") - if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { - t.Fatal(err) - } - if out.String() != output { - t.Log(out.String()) - t.Fatal("Compare failed") - } - for _, name := range files { - os.RemoveAll(filepath.Join(dir, name)) - } - os.RemoveAll(dir) -} - -type TestingFileSystem struct { - assetfs *assetfs.AssetFS -} - -func (d TestingFileSystem) Open(name string) (http.File, error) { - return d.assetfs.Open(name) -} - -var outputBinData = ` - - - beego welcome template - - - - -

Hello, blocks!

- - -

Hello, astaxie!

- - - -

Hello

-

This is SomeVar: val

- - -` - -func TestFsBinData(t *testing.T) { - SetTemplateFSFunc(func() http.FileSystem { - return TestingFileSystem{&assetfs.AssetFS{Asset: test.Asset, AssetDir: test.AssetDir, AssetInfo: test.AssetInfo}} - }) - dir := "views" - if err := AddViewPath("views"); err != nil { - t.Fatal(err) - } - beeTemplates := beeViewPathTemplates[dir] - if len(beeTemplates) != 3 { - t.Fatalf("should be 3 but got %v", len(beeTemplates)) - } - if err := beeTemplates["index.tpl"].ExecuteTemplate(os.Stdout, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { - t.Fatal(err) - } - out := bytes.NewBufferString("") - if err := beeTemplates["index.tpl"].ExecuteTemplate(out, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { - t.Fatal(err) - } - - if out.String() != outputBinData { - t.Log(out.String()) - t.Fatal("Compare failed") - } -} diff --git a/templatefunc.go b/templatefunc.go deleted file mode 100644 index 9e7c42fc..00000000 --- a/templatefunc.go +++ /dev/null @@ -1,798 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "errors" - "fmt" - "html" - "html/template" - "net/url" - "reflect" - "regexp" - "strconv" - "strings" - "time" -) - -const ( - formatTime = "15:04:05" - formatDate = "2006-01-02" - formatDateTime = "2006-01-02 15:04:05" - formatDateTimeT = "2006-01-02T15:04:05" -) - -// Substr returns the substr from start to length. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Substr(s string, start, length int) string { - bt := []rune(s) - if start < 0 { - start = 0 - } - if start > len(bt) { - start = start % len(bt) - } - var end int - if (start + length) > (len(bt) - 1) { - end = len(bt) - } else { - end = start + length - } - return string(bt[start:end]) -} - -// HTML2str returns escaping text convert from html. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func HTML2str(html string) string { - - re := regexp.MustCompile(`\<[\S\s]+?\>`) - html = re.ReplaceAllStringFunc(html, strings.ToLower) - - //remove STYLE - re = regexp.MustCompile(`\`) - html = re.ReplaceAllString(html, "") - - //remove SCRIPT - re = regexp.MustCompile(`\`) - html = re.ReplaceAllString(html, "") - - re = regexp.MustCompile(`\<[\S\s]+?\>`) - html = re.ReplaceAllString(html, "\n") - - re = regexp.MustCompile(`\s{2,}`) - html = re.ReplaceAllString(html, "\n") - - return strings.TrimSpace(html) -} - -// DateFormat takes a time and a layout string and returns a string with the formatted date. Used by the template parser as "dateformat" -// Deprecated: using pkg/, we will delete this in v2.1.0 -func DateFormat(t time.Time, layout string) (datestring string) { - datestring = t.Format(layout) - return -} - -// DateFormat pattern rules. -var datePatterns = []string{ - // year - "Y", "2006", // A full numeric representation of a year, 4 digits Examples: 1999 or 2003 - "y", "06", //A two digit representation of a year Examples: 99 or 03 - - // month - "m", "01", // Numeric representation of a month, with leading zeros 01 through 12 - "n", "1", // Numeric representation of a month, without leading zeros 1 through 12 - "M", "Jan", // A short textual representation of a month, three letters Jan through Dec - "F", "January", // A full textual representation of a month, such as January or March January through December - - // day - "d", "02", // Day of the month, 2 digits with leading zeros 01 to 31 - "j", "2", // Day of the month without leading zeros 1 to 31 - - // week - "D", "Mon", // A textual representation of a day, three letters Mon through Sun - "l", "Monday", // A full textual representation of the day of the week Sunday through Saturday - - // time - "g", "3", // 12-hour format of an hour without leading zeros 1 through 12 - "G", "15", // 24-hour format of an hour without leading zeros 0 through 23 - "h", "03", // 12-hour format of an hour with leading zeros 01 through 12 - "H", "15", // 24-hour format of an hour with leading zeros 00 through 23 - - "a", "pm", // Lowercase Ante meridiem and Post meridiem am or pm - "A", "PM", // Uppercase Ante meridiem and Post meridiem AM or PM - - "i", "04", // Minutes with leading zeros 00 to 59 - "s", "05", // Seconds, with leading zeros 00 through 59 - - // time zone - "T", "MST", - "P", "-07:00", - "O", "-0700", - - // RFC 2822 - "r", time.RFC1123Z, -} - -// DateParse Parse Date use PHP time format. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func DateParse(dateString, format string) (time.Time, error) { - replacer := strings.NewReplacer(datePatterns...) - format = replacer.Replace(format) - return time.ParseInLocation(format, dateString, time.Local) -} - -// Date takes a PHP like date func to Go's time format. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Date(t time.Time, format string) string { - replacer := strings.NewReplacer(datePatterns...) - format = replacer.Replace(format) - return t.Format(format) -} - -// Compare is a quick and dirty comparison function. It will convert whatever you give it to strings and see if the two values are equal. -// Whitespace is trimmed. Used by the template parser as "eq". -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Compare(a, b interface{}) (equal bool) { - equal = false - if strings.TrimSpace(fmt.Sprintf("%v", a)) == strings.TrimSpace(fmt.Sprintf("%v", b)) { - equal = true - } - return -} - -// CompareNot !Compare -// Deprecated: using pkg/, we will delete this in v2.1.0 -func CompareNot(a, b interface{}) (equal bool) { - return !Compare(a, b) -} - -// NotNil the same as CompareNot -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NotNil(a interface{}) (isNil bool) { - return CompareNot(a, nil) -} - -// GetConfig get the Appconfig -// Deprecated: using pkg/, we will delete this in v2.1.0 -func GetConfig(returnType, key string, defaultVal interface{}) (value interface{}, err error) { - switch returnType { - case "String": - value = AppConfig.String(key) - case "Bool": - value, err = AppConfig.Bool(key) - case "Int": - value, err = AppConfig.Int(key) - case "Int64": - value, err = AppConfig.Int64(key) - case "Float": - value, err = AppConfig.Float(key) - case "DIY": - value, err = AppConfig.DIY(key) - default: - err = errors.New("config keys must be of type String, Bool, Int, Int64, Float, or DIY") - } - - if err != nil { - if reflect.TypeOf(returnType) != reflect.TypeOf(defaultVal) { - err = errors.New("defaultVal type does not match returnType") - } else { - value, err = defaultVal, nil - } - } else if reflect.TypeOf(value).Kind() == reflect.String { - if value == "" { - if reflect.TypeOf(defaultVal).Kind() != reflect.String { - err = errors.New("defaultVal type must be a String if the returnType is a String") - } else { - value = defaultVal.(string) - } - } - } - - return -} - -// Str2html Convert string to template.HTML type. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Str2html(raw string) template.HTML { - return template.HTML(raw) -} - -// Htmlquote returns quoted html string. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Htmlquote(text string) string { - //HTML编码为实体符号 - /* - Encodes `text` for raw use in HTML. - >>> htmlquote("<'&\\">") - '<'&">' - */ - - text = html.EscapeString(text) - text = strings.NewReplacer( - `“`, "“", - `”`, "”", - ` `, " ", - ).Replace(text) - - return strings.TrimSpace(text) -} - -// Htmlunquote returns unquoted html string. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Htmlunquote(text string) string { - //实体符号解释为HTML - /* - Decodes `text` that's HTML quoted. - >>> htmlunquote('<'&">') - '<\\'&">' - */ - - text = html.UnescapeString(text) - - return strings.TrimSpace(text) -} - -// URLFor returns url string with another registered controller handler with params. -// usage: -// -// URLFor(".index") -// print URLFor("index") -// router /login -// print URLFor("login") -// print URLFor("login", "next","/"") -// router /profile/:username -// print UrlFor("profile", ":username","John Doe") -// result: -// / -// /login -// /login?next=/ -// /user/John%20Doe -// -// more detail http://beego.me/docs/mvc/controller/urlbuilding.md -// Deprecated: using pkg/, we will delete this in v2.1.0 -func URLFor(endpoint string, values ...interface{}) string { - return BeeApp.Handlers.URLFor(endpoint, values...) -} - -// AssetsJs returns script tag with src string. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AssetsJs(text string) template.HTML { - - text = "" - - return template.HTML(text) -} - -// AssetsCSS returns stylesheet link tag with src string. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AssetsCSS(text string) template.HTML { - - text = "" - - return template.HTML(text) -} - -// ParseForm will parse form values to struct via tag. -// Support for anonymous struct. -func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) error { - for i := 0; i < objT.NumField(); i++ { - fieldV := objV.Field(i) - if !fieldV.CanSet() { - continue - } - - fieldT := objT.Field(i) - if fieldT.Anonymous && fieldT.Type.Kind() == reflect.Struct { - err := parseFormToStruct(form, fieldT.Type, fieldV) - if err != nil { - return err - } - continue - } - - tags := strings.Split(fieldT.Tag.Get("form"), ",") - var tag string - if len(tags) == 0 || len(tags[0]) == 0 { - tag = fieldT.Name - } else if tags[0] == "-" { - continue - } else { - tag = tags[0] - } - - formValues := form[tag] - var value string - if len(formValues) == 0 { - defaultValue := fieldT.Tag.Get("default") - if defaultValue != "" { - value = defaultValue - } else { - continue - } - } - if len(formValues) == 1 { - value = formValues[0] - if value == "" { - continue - } - } - - switch fieldT.Type.Kind() { - case reflect.Bool: - if strings.ToLower(value) == "on" || strings.ToLower(value) == "1" || strings.ToLower(value) == "yes" { - fieldV.SetBool(true) - continue - } - if strings.ToLower(value) == "off" || strings.ToLower(value) == "0" || strings.ToLower(value) == "no" { - fieldV.SetBool(false) - continue - } - b, err := strconv.ParseBool(value) - if err != nil { - return err - } - fieldV.SetBool(b) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - x, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return err - } - fieldV.SetInt(x) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - x, err := strconv.ParseUint(value, 10, 64) - if err != nil { - return err - } - fieldV.SetUint(x) - case reflect.Float32, reflect.Float64: - x, err := strconv.ParseFloat(value, 64) - if err != nil { - return err - } - fieldV.SetFloat(x) - case reflect.Interface: - fieldV.Set(reflect.ValueOf(value)) - case reflect.String: - fieldV.SetString(value) - case reflect.Struct: - switch fieldT.Type.String() { - case "time.Time": - var ( - t time.Time - err error - ) - if len(value) >= 25 { - value = value[:25] - t, err = time.ParseInLocation(time.RFC3339, value, time.Local) - } else if strings.HasSuffix(strings.ToUpper(value), "Z") { - t, err = time.ParseInLocation(time.RFC3339, value, time.Local) - } else if len(value) >= 19 { - if strings.Contains(value, "T") { - value = value[:19] - t, err = time.ParseInLocation(formatDateTimeT, value, time.Local) - } else { - value = value[:19] - t, err = time.ParseInLocation(formatDateTime, value, time.Local) - } - } else if len(value) >= 10 { - if len(value) > 10 { - value = value[:10] - } - t, err = time.ParseInLocation(formatDate, value, time.Local) - } else if len(value) >= 8 { - if len(value) > 8 { - value = value[:8] - } - t, err = time.ParseInLocation(formatTime, value, time.Local) - } - if err != nil { - return err - } - fieldV.Set(reflect.ValueOf(t)) - } - case reflect.Slice: - if fieldT.Type == sliceOfInts { - formVals := form[tag] - fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(int(1))), len(formVals), len(formVals))) - for i := 0; i < len(formVals); i++ { - val, err := strconv.Atoi(formVals[i]) - if err != nil { - return err - } - fieldV.Index(i).SetInt(int64(val)) - } - } else if fieldT.Type == sliceOfStrings { - formVals := form[tag] - fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf("")), len(formVals), len(formVals))) - for i := 0; i < len(formVals); i++ { - fieldV.Index(i).SetString(formVals[i]) - } - } - } - } - return nil -} - -// ParseForm will parse form values to struct via tag. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func ParseForm(form url.Values, obj interface{}) error { - objT := reflect.TypeOf(obj) - objV := reflect.ValueOf(obj) - if !isStructPtr(objT) { - return fmt.Errorf("%v must be a struct pointer", obj) - } - objT = objT.Elem() - objV = objV.Elem() - - return parseFormToStruct(form, objT, objV) -} - -var sliceOfInts = reflect.TypeOf([]int(nil)) -var sliceOfStrings = reflect.TypeOf([]string(nil)) - -var unKind = map[reflect.Kind]bool{ - reflect.Uintptr: true, - reflect.Complex64: true, - reflect.Complex128: true, - reflect.Array: true, - reflect.Chan: true, - reflect.Func: true, - reflect.Map: true, - reflect.Ptr: true, - reflect.Slice: true, - reflect.Struct: true, - reflect.UnsafePointer: true, -} - -// RenderForm will render object to form html. -// obj must be a struct pointer. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func RenderForm(obj interface{}) template.HTML { - objT := reflect.TypeOf(obj) - objV := reflect.ValueOf(obj) - if !isStructPtr(objT) { - return template.HTML("") - } - objT = objT.Elem() - objV = objV.Elem() - - var raw []string - for i := 0; i < objT.NumField(); i++ { - fieldV := objV.Field(i) - if !fieldV.CanSet() || unKind[fieldV.Kind()] { - continue - } - - fieldT := objT.Field(i) - - label, name, fType, id, class, ignored, required := parseFormTag(fieldT) - if ignored { - continue - } - - raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class, required)) - } - return template.HTML(strings.Join(raw, "
")) -} - -// renderFormField returns a string containing HTML of a single form field. -func renderFormField(label, name, fType string, value interface{}, id string, class string, required bool) string { - if id != "" { - id = " id=\"" + id + "\"" - } - - if class != "" { - class = " class=\"" + class + "\"" - } - - requiredString := "" - if required { - requiredString = " required" - } - - if isValidForInput(fType) { - return fmt.Sprintf(`%v`, label, id, class, name, fType, value, requiredString) - } - - return fmt.Sprintf(`%v<%v%v%v name="%v"%v>%v`, label, fType, id, class, name, requiredString, value, fType) -} - -// isValidForInput checks if fType is a valid value for the `type` property of an HTML input element. -func isValidForInput(fType string) bool { - validInputTypes := strings.Fields("text password checkbox radio submit reset hidden image file button search email url tel number range date month week time datetime datetime-local color") - for _, validType := range validInputTypes { - if fType == validType { - return true - } - } - return false -} - -// parseFormTag takes the stuct-tag of a StructField and parses the `form` value. -// returned are the form label, name-property, type and wether the field should be ignored. -func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool, required bool) { - tags := strings.Split(fieldT.Tag.Get("form"), ",") - label = fieldT.Name + ": " - name = fieldT.Name - fType = "text" - ignored = false - id = fieldT.Tag.Get("id") - class = fieldT.Tag.Get("class") - - required = false - requiredField := fieldT.Tag.Get("required") - if requiredField != "-" && requiredField != "" { - required, _ = strconv.ParseBool(requiredField) - } - - switch len(tags) { - case 1: - if tags[0] == "-" { - ignored = true - } - if len(tags[0]) > 0 { - name = tags[0] - } - case 2: - if len(tags[0]) > 0 { - name = tags[0] - } - if len(tags[1]) > 0 { - fType = tags[1] - } - case 3: - if len(tags[0]) > 0 { - name = tags[0] - } - if len(tags[1]) > 0 { - fType = tags[1] - } - if len(tags[2]) > 0 { - label = tags[2] - } - } - - return -} - -func isStructPtr(t reflect.Type) bool { - return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct -} - -// go1.2 added template funcs. begin -var ( - errBadComparisonType = errors.New("invalid type for comparison") - errBadComparison = errors.New("incompatible types for comparison") - errNoComparison = errors.New("missing argument for comparison") -) - -type kind int - -const ( - invalidKind kind = iota - boolKind - complexKind - intKind - floatKind - stringKind - uintKind -) - -func basicKind(v reflect.Value) (kind, error) { - switch v.Kind() { - case reflect.Bool: - return boolKind, nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return intKind, nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return uintKind, nil - case reflect.Float32, reflect.Float64: - return floatKind, nil - case reflect.Complex64, reflect.Complex128: - return complexKind, nil - case reflect.String: - return stringKind, nil - } - return invalidKind, errBadComparisonType -} - -// eq evaluates the comparison a == b || a == c || ... -func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) { - v1 := reflect.ValueOf(arg1) - k1, err := basicKind(v1) - if err != nil { - return false, err - } - if len(arg2) == 0 { - return false, errNoComparison - } - for _, arg := range arg2 { - v2 := reflect.ValueOf(arg) - k2, err := basicKind(v2) - if err != nil { - return false, err - } - if k1 != k2 { - return false, errBadComparison - } - truth := false - switch k1 { - case boolKind: - truth = v1.Bool() == v2.Bool() - case complexKind: - truth = v1.Complex() == v2.Complex() - case floatKind: - truth = v1.Float() == v2.Float() - case intKind: - truth = v1.Int() == v2.Int() - case stringKind: - truth = v1.String() == v2.String() - case uintKind: - truth = v1.Uint() == v2.Uint() - default: - panic("invalid kind") - } - if truth { - return true, nil - } - } - return false, nil -} - -// ne evaluates the comparison a != b. -func ne(arg1, arg2 interface{}) (bool, error) { - // != is the inverse of ==. - equal, err := eq(arg1, arg2) - return !equal, err -} - -// lt evaluates the comparison a < b. -func lt(arg1, arg2 interface{}) (bool, error) { - v1 := reflect.ValueOf(arg1) - k1, err := basicKind(v1) - if err != nil { - return false, err - } - v2 := reflect.ValueOf(arg2) - k2, err := basicKind(v2) - if err != nil { - return false, err - } - if k1 != k2 { - return false, errBadComparison - } - truth := false - switch k1 { - case boolKind, complexKind: - return false, errBadComparisonType - case floatKind: - truth = v1.Float() < v2.Float() - case intKind: - truth = v1.Int() < v2.Int() - case stringKind: - truth = v1.String() < v2.String() - case uintKind: - truth = v1.Uint() < v2.Uint() - default: - panic("invalid kind") - } - return truth, nil -} - -// le evaluates the comparison <= b. -func le(arg1, arg2 interface{}) (bool, error) { - // <= is < or ==. - lessThan, err := lt(arg1, arg2) - if lessThan || err != nil { - return lessThan, err - } - return eq(arg1, arg2) -} - -// gt evaluates the comparison a > b. -func gt(arg1, arg2 interface{}) (bool, error) { - // > is the inverse of <=. - lessOrEqual, err := le(arg1, arg2) - if err != nil { - return false, err - } - return !lessOrEqual, nil -} - -// ge evaluates the comparison a >= b. -func ge(arg1, arg2 interface{}) (bool, error) { - // >= is the inverse of <. - lessThan, err := lt(arg1, arg2) - if err != nil { - return false, err - } - return !lessThan, nil -} - -// MapGet getting value from map by keys -// usage: -// Data["m"] = M{ -// "a": 1, -// "1": map[string]float64{ -// "c": 4, -// }, -// } -// -// {{ map_get m "a" }} // return 1 -// {{ map_get m 1 "c" }} // return 4 -// Deprecated: using pkg/, we will delete this in v2.1.0 -func MapGet(arg1 interface{}, arg2 ...interface{}) (interface{}, error) { - arg1Type := reflect.TypeOf(arg1) - arg1Val := reflect.ValueOf(arg1) - - if arg1Type.Kind() == reflect.Map && len(arg2) > 0 { - // check whether arg2[0] type equals to arg1 key type - // if they are different, make conversion - arg2Val := reflect.ValueOf(arg2[0]) - arg2Type := reflect.TypeOf(arg2[0]) - if arg2Type.Kind() != arg1Type.Key().Kind() { - // convert arg2Value to string - var arg2ConvertedVal interface{} - arg2String := fmt.Sprintf("%v", arg2[0]) - - // convert string representation to any other type - switch arg1Type.Key().Kind() { - case reflect.Bool: - arg2ConvertedVal, _ = strconv.ParseBool(arg2String) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - arg2ConvertedVal, _ = strconv.ParseInt(arg2String, 0, 64) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - arg2ConvertedVal, _ = strconv.ParseUint(arg2String, 0, 64) - case reflect.Float32, reflect.Float64: - arg2ConvertedVal, _ = strconv.ParseFloat(arg2String, 64) - case reflect.String: - arg2ConvertedVal = arg2String - default: - arg2ConvertedVal = arg2Val.Interface() - } - arg2Val = reflect.ValueOf(arg2ConvertedVal) - } - - storedVal := arg1Val.MapIndex(arg2Val) - - if storedVal.IsValid() { - var result interface{} - - switch arg1Type.Elem().Kind() { - case reflect.Bool: - result = storedVal.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - result = storedVal.Int() - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - result = storedVal.Uint() - case reflect.Float32, reflect.Float64: - result = storedVal.Float() - case reflect.String: - result = storedVal.String() - default: - result = storedVal.Interface() - } - - // if there is more keys, handle this recursively - if len(arg2) > 1 { - return MapGet(result, arg2[1:]...) - } - return result, nil - } - return nil, nil - - } - return nil, nil -} diff --git a/templatefunc_test.go b/templatefunc_test.go deleted file mode 100644 index b4c19c2e..00000000 --- a/templatefunc_test.go +++ /dev/null @@ -1,380 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "html/template" - "net/url" - "reflect" - "testing" - "time" -) - -func TestSubstr(t *testing.T) { - s := `012345` - if Substr(s, 0, 2) != "01" { - t.Error("should be equal") - } - if Substr(s, 0, 100) != "012345" { - t.Error("should be equal") - } - if Substr(s, 12, 100) != "012345" { - t.Error("should be equal") - } -} - -func TestHtml2str(t *testing.T) { - h := `<123> 123\n - - - \n` - if HTML2str(h) != "123\\n\n\\n" { - t.Error("should be equal") - } -} - -func TestDateFormat(t *testing.T) { - ts := "Mon, 01 Jul 2013 13:27:42 CST" - tt, _ := time.Parse(time.RFC1123, ts) - - if ss := DateFormat(tt, "2006-01-02 15:04:05"); ss != "2013-07-01 13:27:42" { - t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) - } -} - -func TestDate(t *testing.T) { - ts := "Mon, 01 Jul 2013 13:27:42 CST" - tt, _ := time.Parse(time.RFC1123, ts) - - if ss := Date(tt, "Y-m-d H:i:s"); ss != "2013-07-01 13:27:42" { - t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) - } - if ss := Date(tt, "y-n-j h:i:s A"); ss != "13-7-1 01:27:42 PM" { - t.Errorf("13-7-1 01:27:42 PM does not equal %v", ss) - } - if ss := Date(tt, "D, d M Y g:i:s a"); ss != "Mon, 01 Jul 2013 1:27:42 pm" { - t.Errorf("Mon, 01 Jul 2013 1:27:42 pm does not equal %v", ss) - } - if ss := Date(tt, "l, d F Y G:i:s"); ss != "Monday, 01 July 2013 13:27:42" { - t.Errorf("Monday, 01 July 2013 13:27:42 does not equal %v", ss) - } -} - -func TestCompareRelated(t *testing.T) { - if !Compare("abc", "abc") { - t.Error("should be equal") - } - if Compare("abc", "aBc") { - t.Error("should be not equal") - } - if !Compare("1", 1) { - t.Error("should be equal") - } - if CompareNot("abc", "abc") { - t.Error("should be equal") - } - if !CompareNot("abc", "aBc") { - t.Error("should be not equal") - } - if !NotNil("a string") { - t.Error("should not be nil") - } -} - -func TestHtmlquote(t *testing.T) { - h := `<' ”“&">` - s := `<' ”“&">` - if Htmlquote(s) != h { - t.Error("should be equal") - } -} - -func TestHtmlunquote(t *testing.T) { - h := `<' ”“&">` - s := `<' ”“&">` - if Htmlunquote(h) != s { - t.Error("should be equal") - } -} - -func TestParseForm(t *testing.T) { - type ExtendInfo struct { - Hobby []string `form:"hobby"` - Memo string - } - - type OtherInfo struct { - Organization string `form:"organization"` - Title string `form:"title"` - ExtendInfo - } - - type user struct { - ID int `form:"-"` - tag string `form:"tag"` - Name interface{} `form:"username"` - Age int `form:"age,text"` - Email string - Intro string `form:",textarea"` - StrBool bool `form:"strbool"` - Date time.Time `form:"date,2006-01-02"` - OtherInfo - } - - u := user{} - form := url.Values{ - "ID": []string{"1"}, - "-": []string{"1"}, - "tag": []string{"no"}, - "username": []string{"test"}, - "age": []string{"40"}, - "Email": []string{"test@gmail.com"}, - "Intro": []string{"I am an engineer!"}, - "strbool": []string{"yes"}, - "date": []string{"2014-11-12"}, - "organization": []string{"beego"}, - "title": []string{"CXO"}, - "hobby": []string{"", "Basketball", "Football"}, - "memo": []string{"nothing"}, - } - if err := ParseForm(form, u); err == nil { - t.Fatal("nothing will be changed") - } - if err := ParseForm(form, &u); err != nil { - t.Fatal(err) - } - if u.ID != 0 { - t.Errorf("ID should equal 0 but got %v", u.ID) - } - if len(u.tag) != 0 { - t.Errorf("tag's length should equal 0 but got %v", len(u.tag)) - } - if u.Name.(string) != "test" { - t.Errorf("Name should equal `test` but got `%v`", u.Name.(string)) - } - if u.Age != 40 { - t.Errorf("Age should equal 40 but got %v", u.Age) - } - if u.Email != "test@gmail.com" { - t.Errorf("Email should equal `test@gmail.com` but got `%v`", u.Email) - } - if u.Intro != "I am an engineer!" { - t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro) - } - if !u.StrBool { - t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool) - } - y, m, d := u.Date.Date() - if y != 2014 || m.String() != "November" || d != 12 { - t.Errorf("Date should equal `2014-11-12`, but got `%v`", u.Date.String()) - } - if u.Organization != "beego" { - t.Errorf("Organization should equal `beego`, but got `%v`", u.Organization) - } - if u.Title != "CXO" { - t.Errorf("Title should equal `CXO`, but got `%v`", u.Title) - } - if u.Hobby[0] != "" { - t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0]) - } - if u.Hobby[1] != "Basketball" { - t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1]) - } - if u.Hobby[2] != "Football" { - t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2]) - } - if len(u.Memo) != 0 { - t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo)) - } -} - -func TestRenderForm(t *testing.T) { - type user struct { - ID int `form:"-"` - Name interface{} `form:"username"` - Age int `form:"age,text,年龄:"` - Sex string - Email []string - Intro string `form:",textarea"` - Ignored string `form:"-"` - } - - u := user{Name: "test", Intro: "Some Text"} - output := RenderForm(u) - if output != template.HTML("") { - t.Errorf("output should be empty but got %v", output) - } - output = RenderForm(&u) - result := template.HTML( - `Name:
` + - `年龄:
` + - `Sex:
` + - `Intro: `) - if output != result { - t.Errorf("output should equal `%v` but got `%v`", result, output) - } -} - -func TestRenderFormField(t *testing.T) { - html := renderFormField("Label: ", "Name", "text", "Value", "", "", false) - if html != `Label: ` { - t.Errorf("Wrong html output for input[type=text]: %v ", html) - } - - html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", false) - if html != `Label: ` { - t.Errorf("Wrong html output for textarea: %v ", html) - } - - html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", true) - if html != `Label: ` { - t.Errorf("Wrong html output for textarea: %v ", html) - } -} - -func TestParseFormTag(t *testing.T) { - // create struct to contain field with different types of struct-tag `form` - type user struct { - All int `form:"name,text,年龄:"` - NoName int `form:",hidden,年龄:"` - OnlyLabel int `form:",,年龄:"` - OnlyName int `form:"name" id:"name" class:"form-name"` - Ignored int `form:"-"` - Required int `form:"name" required:"true"` - IgnoreRequired int `form:"name"` - NotRequired int `form:"name" required:"false"` - } - - objT := reflect.TypeOf(&user{}).Elem() - - label, name, fType, _, _, ignored, _ := parseFormTag(objT.Field(0)) - if !(name == "name" && label == "年龄:" && fType == "text" && !ignored) { - t.Errorf("Form Tag with name, label and type was not correctly parsed.") - } - - label, name, fType, _, _, ignored, _ = parseFormTag(objT.Field(1)) - if !(name == "NoName" && label == "年龄:" && fType == "hidden" && !ignored) { - t.Errorf("Form Tag with label and type but without name was not correctly parsed.") - } - - label, name, fType, _, _, ignored, _ = parseFormTag(objT.Field(2)) - if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && !ignored) { - t.Errorf("Form Tag containing only label was not correctly parsed.") - } - - label, name, fType, id, class, ignored, _ := parseFormTag(objT.Field(3)) - if !(name == "name" && label == "OnlyName: " && fType == "text" && !ignored && - id == "name" && class == "form-name") { - t.Errorf("Form Tag containing only name was not correctly parsed.") - } - - _, _, _, _, _, ignored, _ = parseFormTag(objT.Field(4)) - if !ignored { - t.Errorf("Form Tag that should be ignored was not correctly parsed.") - } - - _, name, _, _, _, _, required := parseFormTag(objT.Field(5)) - if !(name == "name" && required) { - t.Errorf("Form Tag containing only name and required was not correctly parsed.") - } - - _, name, _, _, _, _, required = parseFormTag(objT.Field(6)) - if !(name == "name" && !required) { - t.Errorf("Form Tag containing only name and ignore required was not correctly parsed.") - } - - _, name, _, _, _, _, required = parseFormTag(objT.Field(7)) - if !(name == "name" && !required) { - t.Errorf("Form Tag containing only name and not required was not correctly parsed.") - } - -} - -func TestMapGet(t *testing.T) { - // test one level map - m1 := map[string]int64{ - "a": 1, - "1": 2, - } - - if res, err := MapGet(m1, "a"); err == nil { - if res.(int64) != 1 { - t.Errorf("Should return 1, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } - - if res, err := MapGet(m1, "1"); err == nil { - if res.(int64) != 2 { - t.Errorf("Should return 2, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } - - if res, err := MapGet(m1, 1); err == nil { - if res.(int64) != 2 { - t.Errorf("Should return 2, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } - - // test 2 level map - m2 := M{ - "1": map[string]float64{ - "2": 3.5, - }, - } - - if res, err := MapGet(m2, 1, 2); err == nil { - if res.(float64) != 3.5 { - t.Errorf("Should return 3.5, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } - - // test 5 level map - m5 := M{ - "1": M{ - "2": M{ - "3": M{ - "4": M{ - "5": 1.2, - }, - }, - }, - }, - } - - if res, err := MapGet(m5, 1, 2, 3, 4, 5); err == nil { - if res.(float64) != 1.2 { - t.Errorf("Should return 1.2, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } - - // check whether element not exists in map - if res, err := MapGet(m5, 5, 4, 3, 2, 1); err == nil { - if res != nil { - t.Errorf("Should return nil, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } -} diff --git a/testing/assertions.go b/testing/assertions.go deleted file mode 100644 index 96c5d4dd..00000000 --- a/testing/assertions.go +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testing diff --git a/testing/client.go b/testing/client.go deleted file mode 100644 index c3737e9c..00000000 --- a/testing/client.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testing - -import ( - "github.com/astaxie/beego/config" - "github.com/astaxie/beego/httplib" -) - -var port = "" -var baseURL = "http://localhost:" - -// TestHTTPRequest beego test request client -type TestHTTPRequest struct { - httplib.BeegoHTTPRequest -} - -func getPort() string { - if port == "" { - config, err := config.NewConfig("ini", "../conf/app.conf") - if err != nil { - return "8080" - } - port = config.String("httpport") - return port - } - return port -} - -// Get returns test client in GET method -func Get(path string) *TestHTTPRequest { - return &TestHTTPRequest{*httplib.Get(baseURL + getPort() + path)} -} - -// Post returns test client in POST method -func Post(path string) *TestHTTPRequest { - return &TestHTTPRequest{*httplib.Post(baseURL + getPort() + path)} -} - -// Put returns test client in PUT method -func Put(path string) *TestHTTPRequest { - return &TestHTTPRequest{*httplib.Put(baseURL + getPort() + path)} -} - -// Delete returns test client in DELETE method -func Delete(path string) *TestHTTPRequest { - return &TestHTTPRequest{*httplib.Delete(baseURL + getPort() + path)} -} - -// Head returns test client in HEAD method -func Head(path string) *TestHTTPRequest { - return &TestHTTPRequest{*httplib.Head(baseURL + getPort() + path)} -} diff --git a/toolbox/healthcheck.go b/toolbox/healthcheck.go deleted file mode 100644 index e3544b3a..00000000 --- a/toolbox/healthcheck.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package toolbox healthcheck -// -// type DatabaseCheck struct { -// } -// -// func (dc *DatabaseCheck) Check() error { -// if dc.isConnected() { -// return nil -// } else { -// return errors.New("can't connect database") -// } -// } -// -// AddHealthCheck("database",&DatabaseCheck{}) -// -// more docs: http://beego.me/docs/module/toolbox.md -package toolbox - -// AdminCheckList holds health checker map -var AdminCheckList map[string]HealthChecker - -// HealthChecker health checker interface -type HealthChecker interface { - Check() error -} - -// AddHealthCheck add health checker with name string -func AddHealthCheck(name string, hc HealthChecker) { - AdminCheckList[name] = hc -} - -func init() { - AdminCheckList = make(map[string]HealthChecker) -} diff --git a/toolbox/profile.go b/toolbox/profile.go deleted file mode 100644 index 06e40ede..00000000 --- a/toolbox/profile.go +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package toolbox - -import ( - "fmt" - "io" - "log" - "os" - "path" - "runtime" - "runtime/debug" - "runtime/pprof" - "strconv" - "time" -) - -var startTime = time.Now() -var pid int - -func init() { - pid = os.Getpid() -} - -// ProcessInput parse input command string -func ProcessInput(input string, w io.Writer) { - switch input { - case "lookup goroutine": - p := pprof.Lookup("goroutine") - p.WriteTo(w, 2) - case "lookup heap": - p := pprof.Lookup("heap") - p.WriteTo(w, 2) - case "lookup threadcreate": - p := pprof.Lookup("threadcreate") - p.WriteTo(w, 2) - case "lookup block": - p := pprof.Lookup("block") - p.WriteTo(w, 2) - case "get cpuprof": - GetCPUProfile(w) - case "get memprof": - MemProf(w) - case "gc summary": - PrintGCSummary(w) - } -} - -// MemProf record memory profile in pprof -func MemProf(w io.Writer) { - filename := "mem-" + strconv.Itoa(pid) + ".memprof" - if f, err := os.Create(filename); err != nil { - fmt.Fprintf(w, "create file %s error %s\n", filename, err.Error()) - log.Fatal("record heap profile failed: ", err) - } else { - runtime.GC() - pprof.WriteHeapProfile(f) - f.Close() - fmt.Fprintf(w, "create heap profile %s \n", filename) - _, fl := path.Split(os.Args[0]) - fmt.Fprintf(w, "Now you can use this to check it: go tool pprof %s %s\n", fl, filename) - } -} - -// GetCPUProfile start cpu profile monitor -func GetCPUProfile(w io.Writer) { - sec := 30 - filename := "cpu-" + strconv.Itoa(pid) + ".pprof" - f, err := os.Create(filename) - if err != nil { - fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err) - log.Fatal("record cpu profile failed: ", err) - } - pprof.StartCPUProfile(f) - time.Sleep(time.Duration(sec) * time.Second) - pprof.StopCPUProfile() - - fmt.Fprintf(w, "create cpu profile %s \n", filename) - _, fl := path.Split(os.Args[0]) - fmt.Fprintf(w, "Now you can use this to check it: go tool pprof %s %s\n", fl, filename) -} - -// PrintGCSummary print gc information to io.Writer -func PrintGCSummary(w io.Writer) { - memStats := &runtime.MemStats{} - runtime.ReadMemStats(memStats) - gcstats := &debug.GCStats{PauseQuantiles: make([]time.Duration, 100)} - debug.ReadGCStats(gcstats) - - printGC(memStats, gcstats, w) -} - -func printGC(memStats *runtime.MemStats, gcstats *debug.GCStats, w io.Writer) { - - if gcstats.NumGC > 0 { - lastPause := gcstats.Pause[0] - elapsed := time.Now().Sub(startTime) - overhead := float64(gcstats.PauseTotal) / float64(elapsed) * 100 - allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() - - fmt.Fprintf(w, "NumGC:%d Pause:%s Pause(Avg):%s Overhead:%3.2f%% Alloc:%s Sys:%s Alloc(Rate):%s/s Histogram:%s %s %s \n", - gcstats.NumGC, - toS(lastPause), - toS(avg(gcstats.Pause)), - overhead, - toH(memStats.Alloc), - toH(memStats.Sys), - toH(uint64(allocatedRate)), - toS(gcstats.PauseQuantiles[94]), - toS(gcstats.PauseQuantiles[98]), - toS(gcstats.PauseQuantiles[99])) - } else { - // while GC has disabled - elapsed := time.Now().Sub(startTime) - allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() - - fmt.Fprintf(w, "Alloc:%s Sys:%s Alloc(Rate):%s/s\n", - toH(memStats.Alloc), - toH(memStats.Sys), - toH(uint64(allocatedRate))) - } -} - -func avg(items []time.Duration) time.Duration { - var sum time.Duration - for _, item := range items { - sum += item - } - return time.Duration(int64(sum) / int64(len(items))) -} - -// format bytes number friendly -func toH(bytes uint64) string { - switch { - case bytes < 1024: - return fmt.Sprintf("%dB", bytes) - case bytes < 1024*1024: - return fmt.Sprintf("%.2fK", float64(bytes)/1024) - case bytes < 1024*1024*1024: - return fmt.Sprintf("%.2fM", float64(bytes)/1024/1024) - default: - return fmt.Sprintf("%.2fG", float64(bytes)/1024/1024/1024) - } -} - -// short string format -func toS(d time.Duration) string { - - u := uint64(d) - if u < uint64(time.Second) { - switch { - case u == 0: - return "0" - case u < uint64(time.Microsecond): - return fmt.Sprintf("%.2fns", float64(u)) - case u < uint64(time.Millisecond): - return fmt.Sprintf("%.2fus", float64(u)/1000) - default: - return fmt.Sprintf("%.2fms", float64(u)/1000/1000) - } - } else { - switch { - case u < uint64(time.Minute): - return fmt.Sprintf("%.2fs", float64(u)/1000/1000/1000) - case u < uint64(time.Hour): - return fmt.Sprintf("%.2fm", float64(u)/1000/1000/1000/60) - default: - return fmt.Sprintf("%.2fh", float64(u)/1000/1000/1000/60/60) - } - } - -} diff --git a/toolbox/profile_test.go b/toolbox/profile_test.go deleted file mode 100644 index 07a20c4e..00000000 --- a/toolbox/profile_test.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package toolbox - -import ( - "os" - "testing" -) - -func TestProcessInput(t *testing.T) { - ProcessInput("lookup goroutine", os.Stdout) - ProcessInput("lookup heap", os.Stdout) - ProcessInput("lookup threadcreate", os.Stdout) - ProcessInput("lookup block", os.Stdout) - ProcessInput("gc summary", os.Stdout) -} diff --git a/toolbox/statistics.go b/toolbox/statistics.go deleted file mode 100644 index fd73dfb3..00000000 --- a/toolbox/statistics.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package toolbox - -import ( - "fmt" - "sync" - "time" -) - -// Statistics struct -type Statistics struct { - RequestURL string - RequestController string - RequestNum int64 - MinTime time.Duration - MaxTime time.Duration - TotalTime time.Duration -} - -// URLMap contains several statistics struct to log different data -type URLMap struct { - lock sync.RWMutex - LengthLimit int //limit the urlmap's length if it's equal to 0 there's no limit - urlmap map[string]map[string]*Statistics -} - -// AddStatistics add statistics task. -// it needs request method, request url, request controller and statistics time duration -func (m *URLMap) AddStatistics(requestMethod, requestURL, requestController string, requesttime time.Duration) { - m.lock.Lock() - defer m.lock.Unlock() - if method, ok := m.urlmap[requestURL]; ok { - if s, ok := method[requestMethod]; ok { - s.RequestNum++ - if s.MaxTime < requesttime { - s.MaxTime = requesttime - } - if s.MinTime > requesttime { - s.MinTime = requesttime - } - s.TotalTime += requesttime - } else { - nb := &Statistics{ - RequestURL: requestURL, - RequestController: requestController, - RequestNum: 1, - MinTime: requesttime, - MaxTime: requesttime, - TotalTime: requesttime, - } - m.urlmap[requestURL][requestMethod] = nb - } - - } else { - if m.LengthLimit > 0 && m.LengthLimit <= len(m.urlmap) { - return - } - methodmap := make(map[string]*Statistics) - nb := &Statistics{ - RequestURL: requestURL, - RequestController: requestController, - RequestNum: 1, - MinTime: requesttime, - MaxTime: requesttime, - TotalTime: requesttime, - } - methodmap[requestMethod] = nb - m.urlmap[requestURL] = methodmap - } -} - -// GetMap put url statistics result in io.Writer -func (m *URLMap) GetMap() map[string]interface{} { - m.lock.RLock() - defer m.lock.RUnlock() - - var fields = []string{"requestUrl", "method", "times", "used", "max used", "min used", "avg used"} - - var resultLists [][]string - content := make(map[string]interface{}) - content["Fields"] = fields - - for k, v := range m.urlmap { - for kk, vv := range v { - result := []string{ - fmt.Sprintf("% -50s", k), - fmt.Sprintf("% -10s", kk), - fmt.Sprintf("% -16d", vv.RequestNum), - fmt.Sprintf("%d", vv.TotalTime), - fmt.Sprintf("% -16s", toS(vv.TotalTime)), - fmt.Sprintf("%d", vv.MaxTime), - fmt.Sprintf("% -16s", toS(vv.MaxTime)), - fmt.Sprintf("%d", vv.MinTime), - fmt.Sprintf("% -16s", toS(vv.MinTime)), - fmt.Sprintf("%d", time.Duration(int64(vv.TotalTime)/vv.RequestNum)), - fmt.Sprintf("% -16s", toS(time.Duration(int64(vv.TotalTime)/vv.RequestNum))), - } - resultLists = append(resultLists, result) - } - } - content["Data"] = resultLists - return content -} - -// GetMapData return all mapdata -func (m *URLMap) GetMapData() []map[string]interface{} { - m.lock.RLock() - defer m.lock.RUnlock() - - var resultLists []map[string]interface{} - - for k, v := range m.urlmap { - for kk, vv := range v { - result := map[string]interface{}{ - "request_url": k, - "method": kk, - "times": vv.RequestNum, - "total_time": toS(vv.TotalTime), - "max_time": toS(vv.MaxTime), - "min_time": toS(vv.MinTime), - "avg_time": toS(time.Duration(int64(vv.TotalTime) / vv.RequestNum)), - } - resultLists = append(resultLists, result) - } - } - return resultLists -} - -// StatisticsMap hosld global statistics data map -var StatisticsMap *URLMap - -func init() { - StatisticsMap = &URLMap{ - urlmap: make(map[string]map[string]*Statistics), - } -} diff --git a/toolbox/statistics_test.go b/toolbox/statistics_test.go deleted file mode 100644 index ac29476c..00000000 --- a/toolbox/statistics_test.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package toolbox - -import ( - "encoding/json" - "testing" - "time" -) - -func TestStatics(t *testing.T) { - StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(2000)) - StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(120000)) - StatisticsMap.AddStatistics("GET", "/api/user", "&admin.user", time.Duration(13000)) - StatisticsMap.AddStatistics("POST", "/api/admin", "&admin.user", time.Duration(14000)) - StatisticsMap.AddStatistics("POST", "/api/user/astaxie", "&admin.user", time.Duration(12000)) - StatisticsMap.AddStatistics("POST", "/api/user/xiemengjun", "&admin.user", time.Duration(13000)) - StatisticsMap.AddStatistics("DELETE", "/api/user", "&admin.user", time.Duration(1400)) - t.Log(StatisticsMap.GetMap()) - - data := StatisticsMap.GetMapData() - b, err := json.Marshal(data) - if err != nil { - t.Errorf(err.Error()) - } - - t.Log(string(b)) -} diff --git a/toolbox/task.go b/toolbox/task.go deleted file mode 100644 index fb2c5f16..00000000 --- a/toolbox/task.go +++ /dev/null @@ -1,640 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package toolbox - -import ( - "log" - "math" - "sort" - "strconv" - "strings" - "sync" - "time" -) - -// bounds provides a range of acceptable values (plus a map of name to value). -type bounds struct { - min, max uint - names map[string]uint -} - -// The bounds for each field. -var ( - AdminTaskList map[string]Tasker - taskLock sync.RWMutex - stop chan bool - changed chan bool - isstart bool - seconds = bounds{0, 59, nil} - minutes = bounds{0, 59, nil} - hours = bounds{0, 23, nil} - days = bounds{1, 31, nil} - months = bounds{1, 12, map[string]uint{ - "jan": 1, - "feb": 2, - "mar": 3, - "apr": 4, - "may": 5, - "jun": 6, - "jul": 7, - "aug": 8, - "sep": 9, - "oct": 10, - "nov": 11, - "dec": 12, - }} - weeks = bounds{0, 6, map[string]uint{ - "sun": 0, - "mon": 1, - "tue": 2, - "wed": 3, - "thu": 4, - "fri": 5, - "sat": 6, - }} -) - -const ( - // Set the top bit if a star was included in the expression. - starBit = 1 << 63 -) - -// Schedule time taks schedule -type Schedule struct { - Second uint64 - Minute uint64 - Hour uint64 - Day uint64 - Month uint64 - Week uint64 -} - -// TaskFunc task func type -type TaskFunc func() error - -// Tasker task interface -type Tasker interface { - GetSpec() string - GetStatus() string - Run() error - SetNext(time.Time) - GetNext() time.Time - SetPrev(time.Time) - GetPrev() time.Time -} - -// task error -type taskerr struct { - t time.Time - errinfo string -} - -// Task task struct -// It's not a thread-safe structure. -// Only nearest errors will be saved in ErrList -type Task struct { - Taskname string - Spec *Schedule - SpecStr string - DoFunc TaskFunc - Prev time.Time - Next time.Time - Errlist []*taskerr // like errtime:errinfo - ErrLimit int // max length for the errlist, 0 stand for no limit - errCnt int // records the error count during the execution -} - -// NewTask add new task with name, time and func -func NewTask(tname string, spec string, f TaskFunc) *Task { - - task := &Task{ - Taskname: tname, - DoFunc: f, - // Make configurable - ErrLimit: 100, - SpecStr: spec, - // we only store the pointer, so it won't use too many space - Errlist: make([]*taskerr, 100, 100), - } - task.SetCron(spec) - return task -} - -// GetSpec get spec string -func (t *Task) GetSpec() string { - return t.SpecStr -} - -// GetStatus get current task status -func (t *Task) GetStatus() string { - var str string - for _, v := range t.Errlist { - str += v.t.String() + ":" + v.errinfo + "
" - } - return str -} - -// Run run all tasks -func (t *Task) Run() error { - err := t.DoFunc() - if err != nil { - index := t.errCnt % t.ErrLimit - t.Errlist[index] = &taskerr{t: t.Next, errinfo: err.Error()} - t.errCnt++ - } - return err -} - -// SetNext set next time for this task -func (t *Task) SetNext(now time.Time) { - t.Next = t.Spec.Next(now) -} - -// GetNext get the next call time of this task -func (t *Task) GetNext() time.Time { - return t.Next -} - -// SetPrev set prev time of this task -func (t *Task) SetPrev(now time.Time) { - t.Prev = now -} - -// GetPrev get prev time of this task -func (t *Task) GetPrev() time.Time { - return t.Prev -} - -// six columns mean: -// second:0-59 -// minute:0-59 -// hour:1-23 -// day:1-31 -// month:1-12 -// week:0-6(0 means Sunday) - -// SetCron some signals: -// *: any time -// ,:  separate signal -//   -:duration -// /n : do as n times of time duration -///////////////////////////////////////////////////////// -// 0/30 * * * * * every 30s -// 0 43 21 * * * 21:43 -// 0 15 05 * * *    05:15 -// 0 0 17 * * * 17:00 -// 0 0 17 * * 1 17:00 in every Monday -// 0 0,10 17 * * 0,2,3 17:00 and 17:10 in every Sunday, Tuesday and Wednesday -// 0 0-10 17 1 * * 17:00 to 17:10 in 1 min duration each time on the first day of month -// 0 0 0 1,15 * 1 0:00 on the 1st day and 15th day of month -// 0 42 4 1 * *     4:42 on the 1st day of month -// 0 0 21 * * 1-6   21:00 from Monday to Saturday -// 0 0,10,20,30,40,50 * * * *  every 10 min duration -// 0 */10 * * * *        every 10 min duration -// 0 * 1 * * *         1:00 to 1:59 in 1 min duration each time -// 0 0 1 * * *         1:00 -// 0 0 */1 * * *        0 min of hour in 1 hour duration -// 0 0 * * * *         0 min of hour in 1 hour duration -// 0 2 8-20/3 * * *       8:02, 11:02, 14:02, 17:02, 20:02 -// 0 30 5 1,15 * *       5:30 on the 1st day and 15th day of month -func (t *Task) SetCron(spec string) { - t.Spec = t.parse(spec) -} - -func (t *Task) parse(spec string) *Schedule { - if len(spec) > 0 && spec[0] == '@' { - return t.parseSpec(spec) - } - // Split on whitespace. We require 5 or 6 fields. - // (second) (minute) (hour) (day of month) (month) (day of week, optional) - fields := strings.Fields(spec) - if len(fields) != 5 && len(fields) != 6 { - log.Panicf("Expected 5 or 6 fields, found %d: %s", len(fields), spec) - } - - // If a sixth field is not provided (DayOfWeek), then it is equivalent to star. - if len(fields) == 5 { - fields = append(fields, "*") - } - - schedule := &Schedule{ - Second: getField(fields[0], seconds), - Minute: getField(fields[1], minutes), - Hour: getField(fields[2], hours), - Day: getField(fields[3], days), - Month: getField(fields[4], months), - Week: getField(fields[5], weeks), - } - - return schedule -} - -func (t *Task) parseSpec(spec string) *Schedule { - switch spec { - case "@yearly", "@annually": - return &Schedule{ - Second: 1 << seconds.min, - Minute: 1 << minutes.min, - Hour: 1 << hours.min, - Day: 1 << days.min, - Month: 1 << months.min, - Week: all(weeks), - } - - case "@monthly": - return &Schedule{ - Second: 1 << seconds.min, - Minute: 1 << minutes.min, - Hour: 1 << hours.min, - Day: 1 << days.min, - Month: all(months), - Week: all(weeks), - } - - case "@weekly": - return &Schedule{ - Second: 1 << seconds.min, - Minute: 1 << minutes.min, - Hour: 1 << hours.min, - Day: all(days), - Month: all(months), - Week: 1 << weeks.min, - } - - case "@daily", "@midnight": - return &Schedule{ - Second: 1 << seconds.min, - Minute: 1 << minutes.min, - Hour: 1 << hours.min, - Day: all(days), - Month: all(months), - Week: all(weeks), - } - - case "@hourly": - return &Schedule{ - Second: 1 << seconds.min, - Minute: 1 << minutes.min, - Hour: all(hours), - Day: all(days), - Month: all(months), - Week: all(weeks), - } - } - log.Panicf("Unrecognized descriptor: %s", spec) - return nil -} - -// Next set schedule to next time -func (s *Schedule) Next(t time.Time) time.Time { - - // Start at the earliest possible time (the upcoming second). - t = t.Add(1*time.Second - time.Duration(t.Nanosecond())*time.Nanosecond) - - // This flag indicates whether a field has been incremented. - added := false - - // If no time is found within five years, return zero. - yearLimit := t.Year() + 5 - -WRAP: - if t.Year() > yearLimit { - return time.Time{} - } - - // Find the first applicable month. - // If it's this month, then do nothing. - for 1< 0 - dowMatch = 1< 0 - ) - - if s.Day&starBit > 0 || s.Week&starBit > 0 { - return domMatch && dowMatch - } - return domMatch || dowMatch -} - -// StartTask start all tasks -func StartTask() { - taskLock.Lock() - defer taskLock.Unlock() - if isstart { - //If already started, no need to start another goroutine. - return - } - isstart = true - go run() -} - -func run() { - now := time.Now().Local() - for _, t := range AdminTaskList { - t.SetNext(now) - } - - for { - // we only use RLock here because NewMapSorter copy the reference, do not change any thing - taskLock.RLock() - sortList := NewMapSorter(AdminTaskList) - taskLock.RUnlock() - sortList.Sort() - var effective time.Time - if len(AdminTaskList) == 0 || sortList.Vals[0].GetNext().IsZero() { - // If there are no entries yet, just sleep - it still handles new entries - // and stop requests. - effective = now.AddDate(10, 0, 0) - } else { - effective = sortList.Vals[0].GetNext() - } - select { - case now = <-time.After(effective.Sub(now)): - // Run every entry whose next time was this effective time. - for _, e := range sortList.Vals { - if e.GetNext() != effective { - break - } - go e.Run() - e.SetPrev(e.GetNext()) - e.SetNext(effective) - } - continue - case <-changed: - now = time.Now().Local() - taskLock.Lock() - for _, t := range AdminTaskList { - t.SetNext(now) - } - taskLock.Unlock() - continue - case <-stop: - return - } - } -} - -// StopTask stop all tasks -func StopTask() { - taskLock.Lock() - defer taskLock.Unlock() - if isstart { - isstart = false - stop <- true - } - -} - -// AddTask add task with name -func AddTask(taskname string, t Tasker) { - taskLock.Lock() - defer taskLock.Unlock() - t.SetNext(time.Now().Local()) - AdminTaskList[taskname] = t - if isstart { - changed <- true - } -} - -// DeleteTask delete task with name -func DeleteTask(taskname string) { - taskLock.Lock() - defer taskLock.Unlock() - delete(AdminTaskList, taskname) - if isstart { - changed <- true - } -} - -// MapSorter sort map for tasker -type MapSorter struct { - Keys []string - Vals []Tasker -} - -// NewMapSorter create new tasker map -func NewMapSorter(m map[string]Tasker) *MapSorter { - ms := &MapSorter{ - Keys: make([]string, 0, len(m)), - Vals: make([]Tasker, 0, len(m)), - } - for k, v := range m { - ms.Keys = append(ms.Keys, k) - ms.Vals = append(ms.Vals, v) - } - return ms -} - -// Sort sort tasker map -func (ms *MapSorter) Sort() { - sort.Sort(ms) -} - -func (ms *MapSorter) Len() int { return len(ms.Keys) } -func (ms *MapSorter) Less(i, j int) bool { - if ms.Vals[i].GetNext().IsZero() { - return false - } - if ms.Vals[j].GetNext().IsZero() { - return true - } - return ms.Vals[i].GetNext().Before(ms.Vals[j].GetNext()) -} -func (ms *MapSorter) Swap(i, j int) { - ms.Vals[i], ms.Vals[j] = ms.Vals[j], ms.Vals[i] - ms.Keys[i], ms.Keys[j] = ms.Keys[j], ms.Keys[i] -} - -func getField(field string, r bounds) uint64 { - // list = range {"," range} - var bits uint64 - ranges := strings.FieldsFunc(field, func(r rune) bool { return r == ',' }) - for _, expr := range ranges { - bits |= getRange(expr, r) - } - return bits -} - -// getRange returns the bits indicated by the given expression: -// number | number "-" number [ "/" number ] -func getRange(expr string, r bounds) uint64 { - - var ( - start, end, step uint - rangeAndStep = strings.Split(expr, "/") - lowAndHigh = strings.Split(rangeAndStep[0], "-") - singleDigit = len(lowAndHigh) == 1 - ) - - var extrastar uint64 - if lowAndHigh[0] == "*" || lowAndHigh[0] == "?" { - start = r.min - end = r.max - extrastar = starBit - } else { - start = parseIntOrName(lowAndHigh[0], r.names) - switch len(lowAndHigh) { - case 1: - end = start - case 2: - end = parseIntOrName(lowAndHigh[1], r.names) - default: - log.Panicf("Too many hyphens: %s", expr) - } - } - - switch len(rangeAndStep) { - case 1: - step = 1 - case 2: - step = mustParseInt(rangeAndStep[1]) - - // Special handling: "N/step" means "N-max/step". - if singleDigit { - end = r.max - } - default: - log.Panicf("Too many slashes: %s", expr) - } - - if start < r.min { - log.Panicf("Beginning of range (%d) below minimum (%d): %s", start, r.min, expr) - } - if end > r.max { - log.Panicf("End of range (%d) above maximum (%d): %s", end, r.max, expr) - } - if start > end { - log.Panicf("Beginning of range (%d) beyond end of range (%d): %s", start, end, expr) - } - - return getBits(start, end, step) | extrastar -} - -// parseIntOrName returns the (possibly-named) integer contained in expr. -func parseIntOrName(expr string, names map[string]uint) uint { - if names != nil { - if namedInt, ok := names[strings.ToLower(expr)]; ok { - return namedInt - } - } - return mustParseInt(expr) -} - -// mustParseInt parses the given expression as an int or panics. -func mustParseInt(expr string) uint { - num, err := strconv.Atoi(expr) - if err != nil { - log.Panicf("Failed to parse int from %s: %s", expr, err) - } - if num < 0 { - log.Panicf("Negative number (%d) not allowed: %s", num, expr) - } - - return uint(num) -} - -// getBits sets all bits in the range [min, max], modulo the given step size. -func getBits(min, max, step uint) uint64 { - var bits uint64 - - // If step is 1, use shifts. - if step == 1 { - return ^(math.MaxUint64 << (max + 1)) & (math.MaxUint64 << min) - } - - // Else, use a simple loop. - for i := min; i <= max; i += step { - bits |= 1 << i - } - return bits -} - -// all returns all bits within the given bounds. (plus the star bit) -func all(r bounds) uint64 { - return getBits(r.min, r.max, 1) | starBit -} - -func init() { - AdminTaskList = make(map[string]Tasker) - stop = make(chan bool) - changed = make(chan bool) -} diff --git a/toolbox/task_test.go b/toolbox/task_test.go deleted file mode 100644 index b63f4391..00000000 --- a/toolbox/task_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package toolbox - -import ( - "errors" - "fmt" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestParse(t *testing.T) { - tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) - err := tk.Run() - if err != nil { - t.Fatal(err) - } - AddTask("taska", tk) - StartTask() - time.Sleep(6 * time.Second) - StopTask() -} - -func TestSpec(t *testing.T) { - wg := &sync.WaitGroup{} - wg.Add(2) - tk1 := NewTask("tk1", "0 12 * * * *", func() error { fmt.Println("tk1"); return nil }) - tk2 := NewTask("tk2", "0,10,20 * * * * *", func() error { fmt.Println("tk2"); wg.Done(); return nil }) - tk3 := NewTask("tk3", "0 10 * * * *", func() error { fmt.Println("tk3"); wg.Done(); return nil }) - - AddTask("tk1", tk1) - AddTask("tk2", tk2) - AddTask("tk3", tk3) - StartTask() - defer StopTask() - - select { - case <-time.After(200 * time.Second): - t.FailNow() - case <-wait(wg): - } -} - -func TestTask_Run(t *testing.T) { - cnt := -1 - task := func() error { - cnt++ - fmt.Printf("Hello, world! %d \n", cnt) - return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) - } - tk := NewTask("taska", "0/30 * * * * *", task) - for i := 0; i < 200; i++ { - e := tk.Run() - assert.NotNil(t, e) - } - - l := tk.Errlist - assert.Equal(t, 100, len(l)) - assert.Equal(t, "Hello, world! 100", l[0].errinfo) - assert.Equal(t, "Hello, world! 101", l[1].errinfo) -} - -func wait(wg *sync.WaitGroup) chan bool { - ch := make(chan bool) - go func() { - wg.Wait() - ch <- true - }() - return ch -} diff --git a/tree.go b/tree.go deleted file mode 100644 index 7fa3a7cb..00000000 --- a/tree.go +++ /dev/null @@ -1,590 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "path" - "regexp" - "strings" - - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/utils" -) - -var ( - allowSuffixExt = []string{".json", ".xml", ".html"} -) - -// Tree has three elements: FixRouter/wildcard/leaves -// fixRouter stores Fixed Router -// wildcard stores params -// leaves store the endpoint information -// Deprecated: using pkg/, we will delete this in v2.1.0 -type Tree struct { - //prefix set for static router - prefix string - //search fix route first - fixrouters []*Tree - //if set, failure to match fixrouters search then search wildcard - wildcard *Tree - //if set, failure to match wildcard search - leaves []*leafInfo -} - -// NewTree return a new Tree -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NewTree() *Tree { - return &Tree{} -} - -// AddTree will add tree to the exist Tree -// prefix should has no params -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (t *Tree) AddTree(prefix string, tree *Tree) { - t.addtree(splitPath(prefix), tree, nil, "") -} - -func (t *Tree) addtree(segments []string, tree *Tree, wildcards []string, reg string) { - if len(segments) == 0 { - panic("prefix should has path") - } - seg := segments[0] - iswild, params, regexpStr := splitSegment(seg) - // if it's ? meaning can igone this, so add one more rule for it - if len(params) > 0 && params[0] == ":" { - params = params[1:] - if len(segments[1:]) > 0 { - t.addtree(segments[1:], tree, append(wildcards, params...), reg) - } else { - filterTreeWithPrefix(tree, wildcards, reg) - } - } - //Rule: /login/*/access match /login/2009/11/access - //if already has *, and when loop the access, should as a regexpStr - if !iswild && utils.InSlice(":splat", wildcards) { - iswild = true - regexpStr = seg - } - //Rule: /user/:id/* - if seg == "*" && len(wildcards) > 0 && reg == "" { - regexpStr = "(.+)" - } - if len(segments) == 1 { - if iswild { - if regexpStr != "" { - if reg == "" { - rr := "" - for _, w := range wildcards { - if w == ":splat" { - rr = rr + "(.+)/" - } else { - rr = rr + "([^/]+)/" - } - } - regexpStr = rr + regexpStr - } else { - regexpStr = "/" + regexpStr - } - } else if reg != "" { - if seg == "*.*" { - regexpStr = "([^.]+).(.+)" - } else { - for _, w := range params { - if w == "." || w == ":" { - continue - } - regexpStr = "([^/]+)/" + regexpStr - } - } - } - reg = strings.Trim(reg+"/"+regexpStr, "/") - filterTreeWithPrefix(tree, append(wildcards, params...), reg) - t.wildcard = tree - } else { - reg = strings.Trim(reg+"/"+regexpStr, "/") - filterTreeWithPrefix(tree, append(wildcards, params...), reg) - tree.prefix = seg - t.fixrouters = append(t.fixrouters, tree) - } - return - } - - if iswild { - if t.wildcard == nil { - t.wildcard = NewTree() - } - if regexpStr != "" { - if reg == "" { - rr := "" - for _, w := range wildcards { - if w == ":splat" { - rr = rr + "(.+)/" - } else { - rr = rr + "([^/]+)/" - } - } - regexpStr = rr + regexpStr - } else { - regexpStr = "/" + regexpStr - } - } else if reg != "" { - if seg == "*.*" { - regexpStr = "([^.]+).(.+)" - params = params[1:] - } else { - for range params { - regexpStr = "([^/]+)/" + regexpStr - } - } - } else { - if seg == "*.*" { - params = params[1:] - } - } - reg = strings.TrimRight(strings.TrimRight(reg, "/")+"/"+regexpStr, "/") - t.wildcard.addtree(segments[1:], tree, append(wildcards, params...), reg) - } else { - subTree := NewTree() - subTree.prefix = seg - t.fixrouters = append(t.fixrouters, subTree) - subTree.addtree(segments[1:], tree, append(wildcards, params...), reg) - } -} - -func filterTreeWithPrefix(t *Tree, wildcards []string, reg string) { - for _, v := range t.fixrouters { - filterTreeWithPrefix(v, wildcards, reg) - } - if t.wildcard != nil { - filterTreeWithPrefix(t.wildcard, wildcards, reg) - } - for _, l := range t.leaves { - if reg != "" { - if l.regexps != nil { - l.wildcards = append(wildcards, l.wildcards...) - l.regexps = regexp.MustCompile("^" + reg + "/" + strings.Trim(l.regexps.String(), "^$") + "$") - } else { - for _, v := range l.wildcards { - if v == ":splat" { - reg = reg + "/(.+)" - } else { - reg = reg + "/([^/]+)" - } - } - l.regexps = regexp.MustCompile("^" + reg + "$") - l.wildcards = append(wildcards, l.wildcards...) - } - } else { - l.wildcards = append(wildcards, l.wildcards...) - if l.regexps != nil { - for _, w := range wildcards { - if w == ":splat" { - reg = "(.+)/" + reg - } else { - reg = "([^/]+)/" + reg - } - } - l.regexps = regexp.MustCompile("^" + reg + strings.Trim(l.regexps.String(), "^$") + "$") - } - } - } -} - -// AddRouter call addseg function -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (t *Tree) AddRouter(pattern string, runObject interface{}) { - t.addseg(splitPath(pattern), runObject, nil, "") -} - -// "/" -// "admin" -> -func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, reg string) { - if len(segments) == 0 { - if reg != "" { - t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}) - } else { - t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards}) - } - } else { - seg := segments[0] - iswild, params, regexpStr := splitSegment(seg) - // if it's ? meaning can igone this, so add one more rule for it - if len(params) > 0 && params[0] == ":" { - t.addseg(segments[1:], route, wildcards, reg) - params = params[1:] - } - //Rule: /login/*/access match /login/2009/11/access - //if already has *, and when loop the access, should as a regexpStr - if !iswild && utils.InSlice(":splat", wildcards) { - iswild = true - regexpStr = seg - } - //Rule: /user/:id/* - if seg == "*" && len(wildcards) > 0 && reg == "" { - regexpStr = "(.+)" - } - if iswild { - if t.wildcard == nil { - t.wildcard = NewTree() - } - if regexpStr != "" { - if reg == "" { - rr := "" - for _, w := range wildcards { - if w == ":splat" { - rr = rr + "(.+)/" - } else { - rr = rr + "([^/]+)/" - } - } - regexpStr = rr + regexpStr - } else { - regexpStr = "/" + regexpStr - } - } else if reg != "" { - if seg == "*.*" { - regexpStr = "/([^.]+).(.+)" - params = params[1:] - } else { - for range params { - regexpStr = "/([^/]+)" + regexpStr - } - } - } else { - if seg == "*.*" { - params = params[1:] - } - } - t.wildcard.addseg(segments[1:], route, append(wildcards, params...), reg+regexpStr) - } else { - var subTree *Tree - for _, sub := range t.fixrouters { - if sub.prefix == seg { - subTree = sub - break - } - } - if subTree == nil { - subTree = NewTree() - subTree.prefix = seg - t.fixrouters = append(t.fixrouters, subTree) - } - subTree.addseg(segments[1:], route, wildcards, reg) - } - } -} - -// Match router to runObject & params -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) { - if len(pattern) == 0 || pattern[0] != '/' { - return nil - } - w := make([]string, 0, 20) - return t.match(pattern[1:], pattern, w, ctx) -} - -func (t *Tree) match(treePattern string, pattern string, wildcardValues []string, ctx *context.Context) (runObject interface{}) { - if len(pattern) > 0 { - i := 0 - for ; i < len(pattern) && pattern[i] == '/'; i++ { - } - pattern = pattern[i:] - } - // Handle leaf nodes: - if len(pattern) == 0 { - for _, l := range t.leaves { - if ok := l.match(treePattern, wildcardValues, ctx); ok { - return l.runObject - } - } - if t.wildcard != nil { - for _, l := range t.wildcard.leaves { - if ok := l.match(treePattern, wildcardValues, ctx); ok { - return l.runObject - } - } - } - return nil - } - var seg string - i, l := 0, len(pattern) - for ; i < l && pattern[i] != '/'; i++ { - } - if i == 0 { - seg = pattern - pattern = "" - } else { - seg = pattern[:i] - pattern = pattern[i:] - } - for _, subTree := range t.fixrouters { - if subTree.prefix == seg { - if len(pattern) != 0 && pattern[0] == '/' { - treePattern = pattern[1:] - } else { - treePattern = pattern - } - runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) - if runObject != nil { - break - } - } - } - if runObject == nil && len(t.fixrouters) > 0 { - // Filter the .json .xml .html extension - for _, str := range allowSuffixExt { - if strings.HasSuffix(seg, str) { - for _, subTree := range t.fixrouters { - if subTree.prefix == seg[:len(seg)-len(str)] { - runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) - if runObject != nil { - ctx.Input.SetParam(":ext", str[1:]) - } - } - } - } - } - } - if runObject == nil && t.wildcard != nil { - runObject = t.wildcard.match(treePattern, pattern, append(wildcardValues, seg), ctx) - } - - if runObject == nil && len(t.leaves) > 0 { - wildcardValues = append(wildcardValues, seg) - start, i := 0, 0 - for ; i < len(pattern); i++ { - if pattern[i] == '/' { - if i != 0 && start < len(pattern) { - wildcardValues = append(wildcardValues, pattern[start:i]) - } - start = i + 1 - continue - } - } - if start > 0 { - wildcardValues = append(wildcardValues, pattern[start:i]) - } - for _, l := range t.leaves { - if ok := l.match(treePattern, wildcardValues, ctx); ok { - return l.runObject - } - } - } - return runObject -} - -type leafInfo struct { - // names of wildcards that lead to this leaf. eg, ["id" "name"] for the wildcard ":id" and ":name" - wildcards []string - - // if the leaf is regexp - regexps *regexp.Regexp - - runObject interface{} -} - -func (leaf *leafInfo) match(treePattern string, wildcardValues []string, ctx *context.Context) (ok bool) { - //fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) - if leaf.regexps == nil { - if len(wildcardValues) == 0 && len(leaf.wildcards) == 0 { // static path - return true - } - // match * - if len(leaf.wildcards) == 1 && leaf.wildcards[0] == ":splat" { - ctx.Input.SetParam(":splat", treePattern) - return true - } - // match *.* or :id - if len(leaf.wildcards) >= 2 && leaf.wildcards[len(leaf.wildcards)-2] == ":path" && leaf.wildcards[len(leaf.wildcards)-1] == ":ext" { - if len(leaf.wildcards) == 2 { - lastone := wildcardValues[len(wildcardValues)-1] - strs := strings.SplitN(lastone, ".", 2) - if len(strs) == 2 { - ctx.Input.SetParam(":ext", strs[1]) - } - ctx.Input.SetParam(":path", path.Join(path.Join(wildcardValues[:len(wildcardValues)-1]...), strs[0])) - return true - } else if len(wildcardValues) < 2 { - return false - } - var index int - for index = 0; index < len(leaf.wildcards)-2; index++ { - ctx.Input.SetParam(leaf.wildcards[index], wildcardValues[index]) - } - lastone := wildcardValues[len(wildcardValues)-1] - strs := strings.SplitN(lastone, ".", 2) - if len(strs) == 2 { - ctx.Input.SetParam(":ext", strs[1]) - } - if index > (len(wildcardValues) - 1) { - ctx.Input.SetParam(":path", "") - } else { - ctx.Input.SetParam(":path", path.Join(path.Join(wildcardValues[index:len(wildcardValues)-1]...), strs[0])) - } - return true - } - // match :id - if len(leaf.wildcards) != len(wildcardValues) { - return false - } - for j, v := range leaf.wildcards { - ctx.Input.SetParam(v, wildcardValues[j]) - } - return true - } - - if !leaf.regexps.MatchString(path.Join(wildcardValues...)) { - return false - } - matches := leaf.regexps.FindStringSubmatch(path.Join(wildcardValues...)) - for i, match := range matches[1:] { - if i < len(leaf.wildcards) { - ctx.Input.SetParam(leaf.wildcards[i], match) - } - } - return true -} - -// "/" -> [] -// "/admin" -> ["admin"] -// "/admin/" -> ["admin"] -// "/admin/users" -> ["admin", "users"] -func splitPath(key string) []string { - key = strings.Trim(key, "/ ") - if key == "" { - return []string{} - } - return strings.Split(key, "/") -} - -// "admin" -> false, nil, "" -// ":id" -> true, [:id], "" -// "?:id" -> true, [: :id], "" : meaning can empty -// ":id:int" -> true, [:id], ([0-9]+) -// ":name:string" -> true, [:name], ([\w]+) -// ":id([0-9]+)" -> true, [:id], ([0-9]+) -// ":id([0-9]+)_:name" -> true, [:id :name], ([0-9]+)_(.+) -// "cms_:id_:page.html" -> true, [:id_ :page], cms_(.+)(.+).html -// "cms_:id(.+)_:page.html" -> true, [:id :page], cms_(.+)_(.+).html -// "*" -> true, [:splat], "" -// "*.*" -> true,[. :path :ext], "" . meaning separator -func splitSegment(key string) (bool, []string, string) { - if strings.HasPrefix(key, "*") { - if key == "*.*" { - return true, []string{".", ":path", ":ext"}, "" - } - return true, []string{":splat"}, "" - } - if strings.ContainsAny(key, ":") { - var paramsNum int - var out []rune - var start bool - var startexp bool - var param []rune - var expt []rune - var skipnum int - params := []string{} - reg := regexp.MustCompile(`[a-zA-Z0-9_]+`) - for i, v := range key { - if skipnum > 0 { - skipnum-- - continue - } - if start { - //:id:int and :name:string - if v == ':' { - if len(key) >= i+4 { - if key[i+1:i+4] == "int" { - out = append(out, []rune("([0-9]+)")...) - params = append(params, ":"+string(param)) - start = false - startexp = false - skipnum = 3 - param = make([]rune, 0) - paramsNum++ - continue - } - } - if len(key) >= i+7 { - if key[i+1:i+7] == "string" { - out = append(out, []rune(`([\w]+)`)...) - params = append(params, ":"+string(param)) - paramsNum++ - start = false - startexp = false - skipnum = 6 - param = make([]rune, 0) - continue - } - } - } - // params only support a-zA-Z0-9 - if reg.MatchString(string(v)) { - param = append(param, v) - continue - } - if v != '(' { - out = append(out, []rune(`(.+)`)...) - params = append(params, ":"+string(param)) - param = make([]rune, 0) - paramsNum++ - start = false - startexp = false - } - } - if startexp { - if v != ')' { - expt = append(expt, v) - continue - } - } - // Escape Sequence '\' - if i > 0 && key[i-1] == '\\' { - out = append(out, v) - } else if v == ':' { - param = make([]rune, 0) - start = true - } else if v == '(' { - startexp = true - start = false - if len(param) > 0 { - params = append(params, ":"+string(param)) - param = make([]rune, 0) - } - paramsNum++ - expt = make([]rune, 0) - expt = append(expt, '(') - } else if v == ')' { - startexp = false - expt = append(expt, ')') - out = append(out, expt...) - param = make([]rune, 0) - } else if v == '?' { - params = append(params, ":") - } else { - out = append(out, v) - } - } - if len(param) > 0 { - if paramsNum > 0 { - out = append(out, []rune(`(.+)`)...) - } - params = append(params, ":"+string(param)) - } - return true, params, string(out) - } - return false, nil, "" -} diff --git a/tree_test.go b/tree_test.go deleted file mode 100644 index d412a348..00000000 --- a/tree_test.go +++ /dev/null @@ -1,306 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "strings" - "testing" - - "github.com/astaxie/beego/context" -) - -type testinfo struct { - url string - requesturl string - params map[string]string -} - -var routers []testinfo - -func init() { - routers = make([]testinfo, 0) - routers = append(routers, testinfo{"/topic/?:auth:int", "/topic", nil}) - routers = append(routers, testinfo{"/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"}}) - routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1", map[string]string{":id": "1"}}) - routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1/2", map[string]string{":id": "1", ":auth": "2"}}) - routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1", map[string]string{":id": "1"}}) - routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1/123", map[string]string{":id": "1", ":auth": "123"}}) - routers = append(routers, testinfo{"/:id", "/123", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/hello/?:id", "/hello", map[string]string{":id": ""}}) - routers = append(routers, testinfo{"/", "/", nil}) - routers = append(routers, testinfo{"/customer/login", "/customer/login", nil}) - routers = append(routers, testinfo{"/customer/login", "/customer/login.json", map[string]string{":ext": "json"}}) - routers = append(routers, testinfo{"/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"}}) - routers = append(routers, testinfo{"/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}}) - routers = append(routers, testinfo{"/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}}) - routers = append(routers, testinfo{"/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}}) - routers = append(routers, testinfo{"/cc/:id/*", "/cc/2009/11/dd", map[string]string{":id": "2009", ":splat": "11/dd"}}) - routers = append(routers, testinfo{"/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"}}) - routers = append(routers, testinfo{"/thumbnail/:size/uploads/*", - "/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg", - map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"}}) - routers = append(routers, testinfo{"/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"}}) - routers = append(routers, testinfo{"/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) - routers = append(routers, testinfo{"/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) - routers = append(routers, testinfo{"/dl/:width:int/:height:int/*.*", - "/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg", - map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"}}) - routers = append(routers, testinfo{"/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(a)", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(b)", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(c)", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"}}) - routers = append(routers, testinfo{"/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"}}) - routers = append(routers, testinfo{"/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"}}) - routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"}}) - routers = append(routers, testinfo{"/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"}}) - routers = append(routers, testinfo{"/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) - routers = append(routers, testinfo{"/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) - routers = append(routers, testinfo{"/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) - routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}}) - routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}}) -} - -func TestTreeRouters(t *testing.T) { - for _, r := range routers { - tr := NewTree() - tr.AddRouter(r.url, "astaxie") - ctx := context.NewContext() - obj := tr.Match(r.requesturl, ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal(r.url+" can't get obj, Expect ", r.requesturl) - } - if r.params != nil { - for k, v := range r.params { - if vv := ctx.Input.Param(k); vv != v { - t.Fatal("The Rule: " + r.url + "\nThe RequestURL:" + r.requesturl + "\nThe Key is " + k + ", The Value should be: " + v + ", but get: " + vv) - } else if vv == "" && v != "" { - t.Fatal(r.url + " " + r.requesturl + " get param empty:" + k) - } - } - } - } -} - -func TestStaticPath(t *testing.T) { - tr := NewTree() - tr.AddRouter("/topic/:id", "wildcard") - tr.AddRouter("/topic", "static") - ctx := context.NewContext() - obj := tr.Match("/topic", ctx) - if obj == nil || obj.(string) != "static" { - t.Fatal("/topic is a static route") - } - obj = tr.Match("/topic/1", ctx) - if obj == nil || obj.(string) != "wildcard" { - t.Fatal("/topic/1 is a wildcard route") - } -} - -func TestAddTree(t *testing.T) { - tr := NewTree() - tr.AddRouter("/shop/:id/account", "astaxie") - tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") - t1 := NewTree() - t1.AddTree("/v1/zl", tr) - ctx := context.NewContext() - obj := t1.Match("/v1/zl/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/v1/zl/shop/:id/account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":id") != "123" { - t.Fatal("get :id param error") - } - ctx.Input.Reset(ctx) - obj = t1.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/v1/zl//shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" { - t.Fatal("get :sd :id :page param error") - } - - t2 := NewTree() - t2.AddTree("/v1/:shopid", tr) - ctx.Input.Reset(ctx) - obj = t2.Match("/v1/zl/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/v1/:shopid/shop/:id/account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":shopid") != "zl" { - t.Fatal("get :id :shopid param error") - } - ctx.Input.Reset(ctx) - obj = t2.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/v1/:shopid/shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get :shopid param error") - } - if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" || ctx.Input.Param(":shopid") != "zl" { - t.Fatal("get :sd :id :page :shopid param error") - } -} - -func TestAddTree2(t *testing.T) { - tr := NewTree() - tr.AddRouter("/shop/:id/account", "astaxie") - tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") - t3 := NewTree() - t3.AddTree("/:version(v1|v2)/:prefix", tr) - ctx := context.NewContext() - obj := t3.Match("/v1/zl/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/:version(v1|v2)/:prefix/shop/:id/account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":prefix") != "zl" || ctx.Input.Param(":version") != "v1" { - t.Fatal("get :id :prefix :version param error") - } -} - -func TestAddTree3(t *testing.T) { - tr := NewTree() - tr.AddRouter("/create", "astaxie") - tr.AddRouter("/shop/:sd/account", "astaxie") - t3 := NewTree() - t3.AddTree("/table/:num", tr) - ctx := context.NewContext() - obj := t3.Match("/table/123/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/table/:num/shop/:sd/account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":num") != "123" || ctx.Input.Param(":sd") != "123" { - t.Fatal("get :num :sd param error") - } - ctx.Input.Reset(ctx) - obj = t3.Match("/table/123/create", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/table/:num/create can't get obj ") - } -} - -func TestAddTree4(t *testing.T) { - tr := NewTree() - tr.AddRouter("/create", "astaxie") - tr.AddRouter("/shop/:sd/:account", "astaxie") - t4 := NewTree() - t4.AddTree("/:info:int/:num/:id", tr) - ctx := context.NewContext() - obj := t4.Match("/12/123/456/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/:info:int/:num/:id/shop/:sd/:account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":info") != "12" || ctx.Input.Param(":num") != "123" || - ctx.Input.Param(":id") != "456" || ctx.Input.Param(":sd") != "123" || - ctx.Input.Param(":account") != "account" { - t.Fatal("get :info :num :id :sd :account param error") - } - ctx.Input.Reset(ctx) - obj = t4.Match("/12/123/456/create", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/:info:int/:num/:id/create can't get obj ") - } -} - -// Test for issue #1595 -func TestAddTree5(t *testing.T) { - tr := NewTree() - tr.AddRouter("/v1/shop/:id", "shopdetail") - tr.AddRouter("/v1/shop/", "shophome") - ctx := context.NewContext() - obj := tr.Match("/v1/shop/", ctx) - if obj == nil || obj.(string) != "shophome" { - t.Fatal("url /v1/shop/ need match router /v1/shop/ ") - } -} - -func TestSplitPath(t *testing.T) { - a := splitPath("") - if len(a) != 0 { - t.Fatal("/ should retrun []") - } - a = splitPath("/") - if len(a) != 0 { - t.Fatal("/ should retrun []") - } - a = splitPath("/admin") - if len(a) != 1 || a[0] != "admin" { - t.Fatal("/admin should retrun [admin]") - } - a = splitPath("/admin/") - if len(a) != 1 || a[0] != "admin" { - t.Fatal("/admin/ should retrun [admin]") - } - a = splitPath("/admin/users") - if len(a) != 2 || a[0] != "admin" || a[1] != "users" { - t.Fatal("/admin should retrun [admin users]") - } - a = splitPath("/admin/:id:int") - if len(a) != 2 || a[0] != "admin" || a[1] != ":id:int" { - t.Fatal("/admin should retrun [admin :id:int]") - } -} - -func TestSplitSegment(t *testing.T) { - - items := map[string]struct { - isReg bool - params []string - regStr string - }{ - "admin": {false, nil, ""}, - "*": {true, []string{":splat"}, ""}, - "*.*": {true, []string{".", ":path", ":ext"}, ""}, - ":id": {true, []string{":id"}, ""}, - "?:id": {true, []string{":", ":id"}, ""}, - ":id:int": {true, []string{":id"}, "([0-9]+)"}, - ":name:string": {true, []string{":name"}, `([\w]+)`}, - ":id([0-9]+)": {true, []string{":id"}, `([0-9]+)`}, - ":id([0-9]+)_:name": {true, []string{":id", ":name"}, `([0-9]+)_(.+)`}, - ":id(.+)_cms.html": {true, []string{":id"}, `(.+)_cms.html`}, - "cms_:id(.+)_:page(.+).html": {true, []string{":id", ":page"}, `cms_(.+)_(.+).html`}, - `:app(a|b|c)`: {true, []string{":app"}, `(a|b|c)`}, - `:app\((a|b|c)\)`: {true, []string{":app"}, `(.+)\((a|b|c)\)`}, - } - - for pattern, v := range items { - b, w, r := splitSegment(pattern) - if b != v.isReg || r != v.regStr || strings.Join(w, ",") != strings.Join(v.params, ",") { - t.Fatalf("%s should return %t,%s,%q, got %t,%s,%q", pattern, v.isReg, v.params, v.regStr, b, w, r) - } - } -} diff --git a/unregroute_test.go b/unregroute_test.go deleted file mode 100644 index 08b1b77b..00000000 --- a/unregroute_test.go +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -// -// The unregroute_test.go contains tests for the unregister route -// functionality, that allows overriding route paths in children project -// that embed parent routers. -// - -const contentRootOriginal = "ok-original-root" -const contentLevel1Original = "ok-original-level1" -const contentLevel2Original = "ok-original-level2" - -const contentRootReplacement = "ok-replacement-root" -const contentLevel1Replacement = "ok-replacement-level1" -const contentLevel2Replacement = "ok-replacement-level2" - -// TestPreUnregController will supply content for the original routes, -// before unregistration -type TestPreUnregController struct { - Controller -} - -func (tc *TestPreUnregController) GetFixedRoot() { - tc.Ctx.Output.Body([]byte(contentRootOriginal)) -} -func (tc *TestPreUnregController) GetFixedLevel1() { - tc.Ctx.Output.Body([]byte(contentLevel1Original)) -} -func (tc *TestPreUnregController) GetFixedLevel2() { - tc.Ctx.Output.Body([]byte(contentLevel2Original)) -} - -// TestPostUnregController will supply content for the overriding routes, -// after the original ones are unregistered. -type TestPostUnregController struct { - Controller -} - -func (tc *TestPostUnregController) GetFixedRoot() { - tc.Ctx.Output.Body([]byte(contentRootReplacement)) -} -func (tc *TestPostUnregController) GetFixedLevel1() { - tc.Ctx.Output.Body([]byte(contentLevel1Replacement)) -} -func (tc *TestPostUnregController) GetFixedLevel2() { - tc.Ctx.Output.Body([]byte(contentLevel2Replacement)) -} - -// TestUnregisterFixedRouteRoot replaces just the root fixed route path. -// In this case, for a path like "/level1/level2" or "/level1", those actions -// should remain intact, and continue to serve the original content. -func TestUnregisterFixedRouteRoot(t *testing.T) { - - var method = "GET" - - handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") - - // Test original root - testHelperFnContentCheck(t, handler, "Test original root", - method, "/", contentRootOriginal) - - // Test original level 1 - testHelperFnContentCheck(t, handler, "Test original level 1", - method, "/level1", contentLevel1Original) - - // Test original level 2 - testHelperFnContentCheck(t, handler, "Test original level 2", - method, "/level1/level2", contentLevel2Original) - - // Remove only the root path - findAndRemoveSingleTree(handler.routers[method]) - - // Replace the root path TestPreUnregController action with the action from - // TestPostUnregController - handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot") - - // Test replacement root (expect change) - testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) - - // Test level 1 (expect no change from the original) - testHelperFnContentCheck(t, handler, "Test level 1 (expect no change from the original)", method, "/level1", contentLevel1Original) - - // Test level 2 (expect no change from the original) - testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original) - -} - -// TestUnregisterFixedRouteLevel1 replaces just the "/level1" fixed route path. -// In this case, for a path like "/level1/level2" or "/", those actions -// should remain intact, and continue to serve the original content. -func TestUnregisterFixedRouteLevel1(t *testing.T) { - - var method = "GET" - - handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") - - // Test original root - testHelperFnContentCheck(t, handler, - "TestUnregisterFixedRouteLevel1.Test original root", - method, "/", contentRootOriginal) - - // Test original level 1 - testHelperFnContentCheck(t, handler, - "TestUnregisterFixedRouteLevel1.Test original level 1", - method, "/level1", contentLevel1Original) - - // Test original level 2 - testHelperFnContentCheck(t, handler, - "TestUnregisterFixedRouteLevel1.Test original level 2", - method, "/level1/level2", contentLevel2Original) - - // Remove only the level1 path - subPaths := splitPath("/level1") - if handler.routers[method].prefix == strings.Trim("/level1", "/ ") { - findAndRemoveSingleTree(handler.routers[method]) - } else { - findAndRemoveTree(subPaths, handler.routers[method], method) - } - - // Replace the "level1" path TestPreUnregController action with the action from - // TestPostUnregController - handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1") - - // Test replacement root (expect no change from the original) - testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) - - // Test level 1 (expect change) - testHelperFnContentCheck(t, handler, "Test level 1 (expect change)", method, "/level1", contentLevel1Replacement) - - // Test level 2 (expect no change from the original) - testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original) - -} - -// TestUnregisterFixedRouteLevel2 unregisters just the "/level1/level2" fixed -// route path. In this case, for a path like "/level1" or "/", those actions -// should remain intact, and continue to serve the original content. -func TestUnregisterFixedRouteLevel2(t *testing.T) { - - var method = "GET" - - handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") - - // Test original root - testHelperFnContentCheck(t, handler, - "TestUnregisterFixedRouteLevel1.Test original root", - method, "/", contentRootOriginal) - - // Test original level 1 - testHelperFnContentCheck(t, handler, - "TestUnregisterFixedRouteLevel1.Test original level 1", - method, "/level1", contentLevel1Original) - - // Test original level 2 - testHelperFnContentCheck(t, handler, - "TestUnregisterFixedRouteLevel1.Test original level 2", - method, "/level1/level2", contentLevel2Original) - - // Remove only the level2 path - subPaths := splitPath("/level1/level2") - if handler.routers[method].prefix == strings.Trim("/level1/level2", "/ ") { - findAndRemoveSingleTree(handler.routers[method]) - } else { - findAndRemoveTree(subPaths, handler.routers[method], method) - } - - // Replace the "/level1/level2" path TestPreUnregController action with the action from - // TestPostUnregController - handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2") - - // Test replacement root (expect no change from the original) - testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) - - // Test level 1 (expect no change from the original) - testHelperFnContentCheck(t, handler, "Test level 1 (expect no change from the original)", method, "/level1", contentLevel1Original) - - // Test level 2 (expect change) - testHelperFnContentCheck(t, handler, "Test level 2 (expect change)", method, "/level1/level2", contentLevel2Replacement) - -} - -func testHelperFnContentCheck(t *testing.T, handler *ControllerRegister, - testName, method, path, expectedBodyContent string) { - - r, err := http.NewRequest(method, path, nil) - if err != nil { - t.Errorf("httpRecorderBodyTest NewRequest error: %v", err) - return - } - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) - body := w.Body.String() - if body != expectedBodyContent { - t.Errorf("%s: expected [%s], got [%s];", testName, expectedBodyContent, body) - } -} diff --git a/utils/caller.go b/utils/caller.go deleted file mode 100644 index 73c52a62..00000000 --- a/utils/caller.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "reflect" - "runtime" -) - -// GetFuncName get function name -func GetFuncName(i interface{}) string { - return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() -} diff --git a/utils/caller_test.go b/utils/caller_test.go deleted file mode 100644 index 0675f0aa..00000000 --- a/utils/caller_test.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "strings" - "testing" -) - -func TestGetFuncName(t *testing.T) { - name := GetFuncName(TestGetFuncName) - t.Log(name) - if !strings.HasSuffix(name, ".TestGetFuncName") { - t.Error("get func name error") - } -} diff --git a/utils/captcha/LICENSE b/utils/captcha/LICENSE deleted file mode 100644 index 0ad73ae0..00000000 --- a/utils/captcha/LICENSE +++ /dev/null @@ -1,19 +0,0 @@ -Copyright (c) 2011-2014 Dmitry Chestnykh - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. diff --git a/utils/captcha/README.md b/utils/captcha/README.md deleted file mode 100644 index dbc2026b..00000000 --- a/utils/captcha/README.md +++ /dev/null @@ -1,45 +0,0 @@ -# Captcha - -an example for use captcha - -``` -package controllers - -import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/cache" - "github.com/astaxie/beego/utils/captcha" -) - -var cpt *captcha.Captcha - -func init() { - // use beego cache system store the captcha data - store := cache.NewMemoryCache() - cpt = captcha.NewWithFilter("/captcha/", store) -} - -type MainController struct { - beego.Controller -} - -func (this *MainController) Get() { - this.TplName = "index.tpl" -} - -func (this *MainController) Post() { - this.TplName = "index.tpl" - - this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) -} -``` - -template usage - -``` -{{.Success}} -
- {{create_captcha}} - -
-``` diff --git a/utils/captcha/captcha.go b/utils/captcha/captcha.go deleted file mode 100644 index 42ac70d3..00000000 --- a/utils/captcha/captcha.go +++ /dev/null @@ -1,270 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package captcha implements generation and verification of image CAPTCHAs. -// an example for use captcha -// -// ``` -// package controllers -// -// import ( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/cache" -// "github.com/astaxie/beego/utils/captcha" -// ) -// -// var cpt *captcha.Captcha -// -// func init() { -// // use beego cache system store the captcha data -// store := cache.NewMemoryCache() -// cpt = captcha.NewWithFilter("/captcha/", store) -// } -// -// type MainController struct { -// beego.Controller -// } -// -// func (this *MainController) Get() { -// this.TplName = "index.tpl" -// } -// -// func (this *MainController) Post() { -// this.TplName = "index.tpl" -// -// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) -// } -// ``` -// -// template usage -// -// ``` -// {{.Success}} -//
-// {{create_captcha}} -// -//
-// ``` -package captcha - -import ( - "fmt" - "html/template" - "net/http" - "path" - "strings" - "time" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/cache" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" -) - -var ( - defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} -) - -const ( - // default captcha attributes - challengeNums = 6 - expiration = 600 * time.Second - fieldIDName = "captcha_id" - fieldCaptchaName = "captcha" - cachePrefix = "captcha_" - defaultURLPrefix = "/captcha/" -) - -// Captcha struct -type Captcha struct { - // beego cache store - store cache.Cache - - // url prefix for captcha image - URLPrefix string - - // specify captcha id input field name - FieldIDName string - // specify captcha result input field name - FieldCaptchaName string - - // captcha image width and height - StdWidth int - StdHeight int - - // captcha chars nums - ChallengeNums int - - // captcha expiration seconds - Expiration time.Duration - - // cache key prefix - CachePrefix string -} - -// generate key string -func (c *Captcha) key(id string) string { - return c.CachePrefix + id -} - -// generate rand chars with default chars -func (c *Captcha) genRandChars() []byte { - return utils.RandomCreateBytes(c.ChallengeNums, defaultChars...) -} - -// Handler beego filter handler for serve captcha image -func (c *Captcha) Handler(ctx *context.Context) { - var chars []byte - - id := path.Base(ctx.Request.RequestURI) - if i := strings.Index(id, "."); i != -1 { - id = id[:i] - } - - key := c.key(id) - - if len(ctx.Input.Query("reload")) > 0 { - chars = c.genRandChars() - if err := c.store.Put(key, chars, c.Expiration); err != nil { - ctx.Output.SetStatus(500) - ctx.WriteString("captcha reload error") - logs.Error("Reload Create Captcha Error:", err) - return - } - } else { - if v, ok := c.store.Get(key).([]byte); ok { - chars = v - } else { - ctx.Output.SetStatus(404) - ctx.WriteString("captcha not found") - return - } - } - - img := NewImage(chars, c.StdWidth, c.StdHeight) - if _, err := img.WriteTo(ctx.ResponseWriter); err != nil { - logs.Error("Write Captcha Image Error:", err) - } -} - -// CreateCaptchaHTML template func for output html -func (c *Captcha) CreateCaptchaHTML() template.HTML { - value, err := c.CreateCaptcha() - if err != nil { - logs.Error("Create Captcha Error:", err) - return "" - } - - // create html - return template.HTML(fmt.Sprintf(``+ - ``+ - ``+ - ``, c.FieldIDName, value, c.URLPrefix, value, c.URLPrefix, value)) -} - -// CreateCaptcha create a new captcha id -func (c *Captcha) CreateCaptcha() (string, error) { - // generate captcha id - id := string(utils.RandomCreateBytes(15)) - - // get the captcha chars - chars := c.genRandChars() - - // save to store - if err := c.store.Put(c.key(id), chars, c.Expiration); err != nil { - return "", err - } - - return id, nil -} - -// VerifyReq verify from a request -func (c *Captcha) VerifyReq(req *http.Request) bool { - req.ParseForm() - return c.Verify(req.Form.Get(c.FieldIDName), req.Form.Get(c.FieldCaptchaName)) -} - -// Verify direct verify id and challenge string -func (c *Captcha) Verify(id string, challenge string) (success bool) { - if len(challenge) == 0 || len(id) == 0 { - return - } - - var chars []byte - - key := c.key(id) - - if v, ok := c.store.Get(key).([]byte); ok { - chars = v - } else { - return - } - - defer func() { - // finally remove it - c.store.Delete(key) - }() - - if len(chars) != len(challenge) { - return - } - // verify challenge - for i, c := range chars { - if c != challenge[i]-48 { - return - } - } - - return true -} - -// NewCaptcha create a new captcha.Captcha -func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { - cpt := &Captcha{} - cpt.store = store - cpt.FieldIDName = fieldIDName - cpt.FieldCaptchaName = fieldCaptchaName - cpt.ChallengeNums = challengeNums - cpt.Expiration = expiration - cpt.CachePrefix = cachePrefix - cpt.StdWidth = stdWidth - cpt.StdHeight = stdHeight - - if len(urlPrefix) == 0 { - urlPrefix = defaultURLPrefix - } - - if urlPrefix[len(urlPrefix)-1] != '/' { - urlPrefix += "/" - } - - cpt.URLPrefix = urlPrefix - - return cpt -} - -// NewWithFilter create a new captcha.Captcha and auto AddFilter for serve captacha image -// and add a template func for output html -func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { - cpt := NewCaptcha(urlPrefix, store) - - // create filter for serve captcha image - beego.InsertFilter(cpt.URLPrefix+"*", beego.BeforeRouter, cpt.Handler) - - // add to template func map - beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHTML) - - return cpt -} diff --git a/utils/captcha/image.go b/utils/captcha/image.go deleted file mode 100644 index c3c9a83a..00000000 --- a/utils/captcha/image.go +++ /dev/null @@ -1,501 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package captcha - -import ( - "bytes" - "image" - "image/color" - "image/png" - "io" - "math" -) - -const ( - fontWidth = 11 - fontHeight = 18 - blackChar = 1 - - // Standard width and height of a captcha image. - stdWidth = 240 - stdHeight = 80 - // Maximum absolute skew factor of a single digit. - maxSkew = 0.7 - // Number of background circles. - circleCount = 20 -) - -var font = [][]byte{ - { // 0 - 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, - 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, - 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, - }, - { // 1 - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, - 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, - 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - }, - { // 2 - 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, - 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, - 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - }, - { // 3 - 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, - 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, - 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, - }, - { // 4 - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, - 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, - 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, - 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - }, - { // 5 - 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, - }, - { // 6 - 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, - 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, - 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, - 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, - 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, - 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, - 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, - }, - { // 7 - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, - 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, - 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, - 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, - }, - { // 8 - 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, - 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, - 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, - 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, - 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, - 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, - }, - { // 9 - 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, - 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, - 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, - 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, - 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, - 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, - }, -} - -// Image struct -type Image struct { - *image.Paletted - numWidth int - numHeight int - dotSize int -} - -var prng = &siprng{} - -// randIntn returns a pseudorandom non-negative int in range [0, n). -func randIntn(n int) int { - return prng.Intn(n) -} - -// randInt returns a pseudorandom int in range [from, to]. -func randInt(from, to int) int { - return prng.Intn(to+1-from) + from -} - -// randFloat returns a pseudorandom float64 in range [from, to]. -func randFloat(from, to float64) float64 { - return (to-from)*prng.Float64() + from -} - -func randomPalette() color.Palette { - p := make([]color.Color, circleCount+1) - // Transparent color. - p[0] = color.RGBA{0xFF, 0xFF, 0xFF, 0x00} - // Primary color. - prim := color.RGBA{ - uint8(randIntn(129)), - uint8(randIntn(129)), - uint8(randIntn(129)), - 0xFF, - } - p[1] = prim - // Circle colors. - for i := 2; i <= circleCount; i++ { - p[i] = randomBrightness(prim, 255) - } - return p -} - -// NewImage returns a new captcha image of the given width and height with the -// given digits, where each digit must be in range 0-9. -func NewImage(digits []byte, width, height int) *Image { - m := new(Image) - m.Paletted = image.NewPaletted(image.Rect(0, 0, width, height), randomPalette()) - m.calculateSizes(width, height, len(digits)) - // Randomly position captcha inside the image. - maxx := width - (m.numWidth+m.dotSize)*len(digits) - m.dotSize - maxy := height - m.numHeight - m.dotSize*2 - var border int - if width > height { - border = height / 5 - } else { - border = width / 5 - } - x := randInt(border, maxx-border) - y := randInt(border, maxy-border) - // Draw digits. - for _, n := range digits { - m.drawDigit(font[n], x, y) - x += m.numWidth + m.dotSize - } - // Draw strike-through line. - m.strikeThrough() - // Apply wave distortion. - m.distort(randFloat(5, 10), randFloat(100, 200)) - // Fill image with random circles. - m.fillWithCircles(circleCount, m.dotSize) - return m -} - -// encodedPNG encodes an image to PNG and returns -// the result as a byte slice. -func (m *Image) encodedPNG() []byte { - var buf bytes.Buffer - if err := png.Encode(&buf, m.Paletted); err != nil { - panic(err.Error()) - } - return buf.Bytes() -} - -// WriteTo writes captcha image in PNG format into the given writer. -func (m *Image) WriteTo(w io.Writer) (int64, error) { - n, err := w.Write(m.encodedPNG()) - return int64(n), err -} - -func (m *Image) calculateSizes(width, height, ncount int) { - // Goal: fit all digits inside the image. - var border int - if width > height { - border = height / 4 - } else { - border = width / 4 - } - // Convert everything to floats for calculations. - w := float64(width - border*2) - h := float64(height - border*2) - // fw takes into account 1-dot spacing between digits. - fw := float64(fontWidth + 1) - fh := float64(fontHeight) - nc := float64(ncount) - // Calculate the width of a single digit taking into account only the - // width of the image. - nw := w / nc - // Calculate the height of a digit from this width. - nh := nw * fh / fw - // Digit too high? - if nh > h { - // Fit digits based on height. - nh = h - nw = fw / fh * nh - } - // Calculate dot size. - m.dotSize = int(nh / fh) - if m.dotSize < 1 { - m.dotSize = 1 - } - // Save everything, making the actual width smaller by 1 dot to account - // for spacing between digits. - m.numWidth = int(nw) - m.dotSize - m.numHeight = int(nh) -} - -func (m *Image) drawHorizLine(fromX, toX, y int, colorIdx uint8) { - for x := fromX; x <= toX; x++ { - m.SetColorIndex(x, y, colorIdx) - } -} - -func (m *Image) drawCircle(x, y, radius int, colorIdx uint8) { - f := 1 - radius - dfx := 1 - dfy := -2 * radius - xo := 0 - yo := radius - - m.SetColorIndex(x, y+radius, colorIdx) - m.SetColorIndex(x, y-radius, colorIdx) - m.drawHorizLine(x-radius, x+radius, y, colorIdx) - - for xo < yo { - if f >= 0 { - yo-- - dfy += 2 - f += dfy - } - xo++ - dfx += 2 - f += dfx - m.drawHorizLine(x-xo, x+xo, y+yo, colorIdx) - m.drawHorizLine(x-xo, x+xo, y-yo, colorIdx) - m.drawHorizLine(x-yo, x+yo, y+xo, colorIdx) - m.drawHorizLine(x-yo, x+yo, y-xo, colorIdx) - } -} - -func (m *Image) fillWithCircles(n, maxradius int) { - maxx := m.Bounds().Max.X - maxy := m.Bounds().Max.Y - for i := 0; i < n; i++ { - colorIdx := uint8(randInt(1, circleCount-1)) - r := randInt(1, maxradius) - m.drawCircle(randInt(r, maxx-r), randInt(r, maxy-r), r, colorIdx) - } -} - -func (m *Image) strikeThrough() { - maxx := m.Bounds().Max.X - maxy := m.Bounds().Max.Y - y := randInt(maxy/3, maxy-maxy/3) - amplitude := randFloat(5, 20) - period := randFloat(80, 180) - dx := 2.0 * math.Pi / period - for x := 0; x < maxx; x++ { - xo := amplitude * math.Cos(float64(y)*dx) - yo := amplitude * math.Sin(float64(x)*dx) - for yn := 0; yn < m.dotSize; yn++ { - r := randInt(0, m.dotSize) - m.drawCircle(x+int(xo), y+int(yo)+(yn*m.dotSize), r/2, 1) - } - } -} - -func (m *Image) drawDigit(digit []byte, x, y int) { - skf := randFloat(-maxSkew, maxSkew) - xs := float64(x) - r := m.dotSize / 2 - y += randInt(-r, r) - for yo := 0; yo < fontHeight; yo++ { - for xo := 0; xo < fontWidth; xo++ { - if digit[yo*fontWidth+xo] != blackChar { - continue - } - m.drawCircle(x+xo*m.dotSize, y+yo*m.dotSize, r, 1) - } - xs += skf - x = int(xs) - } -} - -func (m *Image) distort(amplude float64, period float64) { - w := m.Bounds().Max.X - h := m.Bounds().Max.Y - - oldm := m.Paletted - newm := image.NewPaletted(image.Rect(0, 0, w, h), oldm.Palette) - - dx := 2.0 * math.Pi / period - for x := 0; x < w; x++ { - for y := 0; y < h; y++ { - xo := amplude * math.Sin(float64(y)*dx) - yo := amplude * math.Cos(float64(x)*dx) - newm.SetColorIndex(x, y, oldm.ColorIndexAt(x+int(xo), y+int(yo))) - } - } - m.Paletted = newm -} - -func randomBrightness(c color.RGBA, max uint8) color.RGBA { - minc := min3(c.R, c.G, c.B) - maxc := max3(c.R, c.G, c.B) - if maxc > max { - return c - } - n := randIntn(int(max-maxc)) - int(minc) - return color.RGBA{ - uint8(int(c.R) + n), - uint8(int(c.G) + n), - uint8(int(c.B) + n), - c.A, - } -} - -func min3(x, y, z uint8) (m uint8) { - m = x - if y < m { - m = y - } - if z < m { - m = z - } - return -} - -func max3(x, y, z uint8) (m uint8) { - m = x - if y > m { - m = y - } - if z > m { - m = z - } - return -} diff --git a/utils/captcha/image_test.go b/utils/captcha/image_test.go deleted file mode 100644 index 5e35b7f7..00000000 --- a/utils/captcha/image_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package captcha - -import ( - "testing" - - "github.com/astaxie/beego/utils" -) - -type byteCounter struct { - n int64 -} - -func (bc *byteCounter) Write(b []byte) (int, error) { - bc.n += int64(len(b)) - return len(b), nil -} - -func BenchmarkNewImage(b *testing.B) { - b.StopTimer() - d := utils.RandomCreateBytes(challengeNums, defaultChars...) - b.StartTimer() - for i := 0; i < b.N; i++ { - NewImage(d, stdWidth, stdHeight) - } -} - -func BenchmarkImageWriteTo(b *testing.B) { - b.StopTimer() - d := utils.RandomCreateBytes(challengeNums, defaultChars...) - b.StartTimer() - counter := &byteCounter{} - for i := 0; i < b.N; i++ { - img := NewImage(d, stdWidth, stdHeight) - img.WriteTo(counter) - b.SetBytes(counter.n) - counter.n = 0 - } -} diff --git a/utils/captcha/siprng.go b/utils/captcha/siprng.go deleted file mode 100644 index 5e256cf9..00000000 --- a/utils/captcha/siprng.go +++ /dev/null @@ -1,277 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package captcha - -import ( - "crypto/rand" - "encoding/binary" - "io" - "sync" -) - -// siprng is PRNG based on SipHash-2-4. -type siprng struct { - mu sync.Mutex - k0, k1, ctr uint64 -} - -// siphash implements SipHash-2-4, accepting a uint64 as a message. -func siphash(k0, k1, m uint64) uint64 { - // Initialization. - v0 := k0 ^ 0x736f6d6570736575 - v1 := k1 ^ 0x646f72616e646f6d - v2 := k0 ^ 0x6c7967656e657261 - v3 := k1 ^ 0x7465646279746573 - t := uint64(8) << 56 - - // Compression. - v3 ^= m - - // Round 1. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - // Round 2. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - v0 ^= m - - // Compress last block. - v3 ^= t - - // Round 1. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - // Round 2. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - v0 ^= t - - // Finalization. - v2 ^= 0xff - - // Round 1. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - // Round 2. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - // Round 3. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - // Round 4. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - return v0 ^ v1 ^ v2 ^ v3 -} - -// rekey sets a new PRNG key, which is read from crypto/rand. -func (p *siprng) rekey() { - var k [16]byte - if _, err := io.ReadFull(rand.Reader, k[:]); err != nil { - panic(err.Error()) - } - p.k0 = binary.LittleEndian.Uint64(k[0:8]) - p.k1 = binary.LittleEndian.Uint64(k[8:16]) - p.ctr = 1 -} - -// Uint64 returns a new pseudorandom uint64. -// It rekeys PRNG on the first call and every 64 MB of generated data. -func (p *siprng) Uint64() uint64 { - p.mu.Lock() - if p.ctr == 0 || p.ctr > 8*1024*1024 { - p.rekey() - } - v := siphash(p.k0, p.k1, p.ctr) - p.ctr++ - p.mu.Unlock() - return v -} - -func (p *siprng) Int63() int64 { - return int64(p.Uint64() & 0x7fffffffffffffff) -} - -func (p *siprng) Uint32() uint32 { - return uint32(p.Uint64()) -} - -func (p *siprng) Int31() int32 { - return int32(p.Uint32() & 0x7fffffff) -} - -func (p *siprng) Intn(n int) int { - if n <= 0 { - panic("invalid argument to Intn") - } - if n <= 1<<31-1 { - return int(p.Int31n(int32(n))) - } - return int(p.Int63n(int64(n))) -} - -func (p *siprng) Int63n(n int64) int64 { - if n <= 0 { - panic("invalid argument to Int63n") - } - max := int64((1 << 63) - 1 - (1<<63)%uint64(n)) - v := p.Int63() - for v > max { - v = p.Int63() - } - return v % n -} - -func (p *siprng) Int31n(n int32) int32 { - if n <= 0 { - panic("invalid argument to Int31n") - } - max := int32((1 << 31) - 1 - (1<<31)%uint32(n)) - v := p.Int31() - for v > max { - v = p.Int31() - } - return v % n -} - -func (p *siprng) Float64() float64 { return float64(p.Int63()) / (1 << 63) } diff --git a/utils/captcha/siprng_test.go b/utils/captcha/siprng_test.go deleted file mode 100644 index 189d3d3c..00000000 --- a/utils/captcha/siprng_test.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package captcha - -import "testing" - -func TestSiphash(t *testing.T) { - good := uint64(0xe849e8bb6ffe2567) - cur := siphash(0, 0, 0) - if cur != good { - t.Fatalf("siphash: expected %x, got %x", good, cur) - } -} - -func BenchmarkSiprng(b *testing.B) { - b.SetBytes(8) - p := &siprng{} - for i := 0; i < b.N; i++ { - p.Uint64() - } -} diff --git a/utils/debug.go b/utils/debug.go deleted file mode 100644 index 93c27b70..00000000 --- a/utils/debug.go +++ /dev/null @@ -1,478 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "bytes" - "fmt" - "log" - "reflect" - "runtime" -) - -var ( - dunno = []byte("???") - centerDot = []byte("·") - dot = []byte(".") -) - -type pointerInfo struct { - prev *pointerInfo - n int - addr uintptr - pos int - used []int -} - -// Display print the data in console -func Display(data ...interface{}) { - display(true, data...) -} - -// GetDisplayString return data print string -func GetDisplayString(data ...interface{}) string { - return display(false, data...) -} - -func display(displayed bool, data ...interface{}) string { - var pc, file, line, ok = runtime.Caller(2) - - if !ok { - return "" - } - - var buf = new(bytes.Buffer) - - fmt.Fprintf(buf, "[Debug] at %s() [%s:%d]\n", function(pc), file, line) - - fmt.Fprintf(buf, "\n[Variables]\n") - - for i := 0; i < len(data); i += 2 { - var output = fomateinfo(len(data[i].(string))+3, data[i+1]) - fmt.Fprintf(buf, "%s = %s", data[i], output) - } - - if displayed { - log.Print(buf) - } - return buf.String() -} - -// return data dump and format bytes -func fomateinfo(headlen int, data ...interface{}) []byte { - var buf = new(bytes.Buffer) - - if len(data) > 1 { - fmt.Fprint(buf, " ") - - fmt.Fprint(buf, "[") - - fmt.Fprintln(buf) - } - - for k, v := range data { - var buf2 = new(bytes.Buffer) - var pointers *pointerInfo - var interfaces = make([]reflect.Value, 0, 10) - - printKeyValue(buf2, reflect.ValueOf(v), &pointers, &interfaces, nil, true, " ", 1) - - if k < len(data)-1 { - fmt.Fprint(buf2, ", ") - } - - fmt.Fprintln(buf2) - - buf.Write(buf2.Bytes()) - } - - if len(data) > 1 { - fmt.Fprintln(buf) - - fmt.Fprint(buf, " ") - - fmt.Fprint(buf, "]") - } - - return buf.Bytes() -} - -// check data is golang basic type -func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo, interfaces *[]reflect.Value) bool { - switch kind { - case reflect.Bool: - return true - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return true - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - return true - case reflect.Float32, reflect.Float64: - return true - case reflect.Complex64, reflect.Complex128: - return true - case reflect.String: - return true - case reflect.Chan: - return true - case reflect.Invalid: - return true - case reflect.Interface: - for _, in := range *interfaces { - if reflect.DeepEqual(in, val) { - return true - } - } - return false - case reflect.UnsafePointer: - if val.IsNil() { - return true - } - - var elem = val.Elem() - - if isSimpleType(elem, elem.Kind(), pointers, interfaces) { - return true - } - - var addr = val.Elem().UnsafeAddr() - - for p := *pointers; p != nil; p = p.prev { - if addr == p.addr { - return true - } - } - - return false - } - - return false -} - -// dump value -func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, interfaces *[]reflect.Value, structFilter func(string, string) bool, formatOutput bool, indent string, level int) { - var t = val.Kind() - - switch t { - case reflect.Bool: - fmt.Fprint(buf, val.Bool()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fmt.Fprint(buf, val.Int()) - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - fmt.Fprint(buf, val.Uint()) - case reflect.Float32, reflect.Float64: - fmt.Fprint(buf, val.Float()) - case reflect.Complex64, reflect.Complex128: - fmt.Fprint(buf, val.Complex()) - case reflect.UnsafePointer: - fmt.Fprintf(buf, "unsafe.Pointer(0x%X)", val.Pointer()) - case reflect.Ptr: - if val.IsNil() { - fmt.Fprint(buf, "nil") - return - } - - var addr = val.Elem().UnsafeAddr() - - for p := *pointers; p != nil; p = p.prev { - if addr == p.addr { - p.used = append(p.used, buf.Len()) - fmt.Fprintf(buf, "0x%X", addr) - return - } - } - - *pointers = &pointerInfo{ - prev: *pointers, - addr: addr, - pos: buf.Len(), - used: make([]int, 0), - } - - fmt.Fprint(buf, "&") - - printKeyValue(buf, val.Elem(), pointers, interfaces, structFilter, formatOutput, indent, level) - case reflect.String: - fmt.Fprint(buf, "\"", val.String(), "\"") - case reflect.Interface: - var value = val.Elem() - - if !value.IsValid() { - fmt.Fprint(buf, "nil") - } else { - for _, in := range *interfaces { - if reflect.DeepEqual(in, val) { - fmt.Fprint(buf, "repeat") - return - } - } - - *interfaces = append(*interfaces, val) - - printKeyValue(buf, value, pointers, interfaces, structFilter, formatOutput, indent, level+1) - } - case reflect.Struct: - var t = val.Type() - - fmt.Fprint(buf, t) - fmt.Fprint(buf, "{") - - for i := 0; i < val.NumField(); i++ { - if formatOutput { - fmt.Fprintln(buf) - } else { - fmt.Fprint(buf, " ") - } - - var name = t.Field(i).Name - - if formatOutput { - for ind := 0; ind < level; ind++ { - fmt.Fprint(buf, indent) - } - } - - fmt.Fprint(buf, name) - fmt.Fprint(buf, ": ") - - if structFilter != nil && structFilter(t.String(), name) { - fmt.Fprint(buf, "ignore") - } else { - printKeyValue(buf, val.Field(i), pointers, interfaces, structFilter, formatOutput, indent, level+1) - } - - fmt.Fprint(buf, ",") - } - - if formatOutput { - fmt.Fprintln(buf) - - for ind := 0; ind < level-1; ind++ { - fmt.Fprint(buf, indent) - } - } else { - fmt.Fprint(buf, " ") - } - - fmt.Fprint(buf, "}") - case reflect.Array, reflect.Slice: - fmt.Fprint(buf, val.Type()) - fmt.Fprint(buf, "{") - - var allSimple = true - - for i := 0; i < val.Len(); i++ { - var elem = val.Index(i) - - var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces) - - if !isSimple { - allSimple = false - } - - if formatOutput && !isSimple { - fmt.Fprintln(buf) - } else { - fmt.Fprint(buf, " ") - } - - if formatOutput && !isSimple { - for ind := 0; ind < level; ind++ { - fmt.Fprint(buf, indent) - } - } - - printKeyValue(buf, elem, pointers, interfaces, structFilter, formatOutput, indent, level+1) - - if i != val.Len()-1 || !allSimple { - fmt.Fprint(buf, ",") - } - } - - if formatOutput && !allSimple { - fmt.Fprintln(buf) - - for ind := 0; ind < level-1; ind++ { - fmt.Fprint(buf, indent) - } - } else { - fmt.Fprint(buf, " ") - } - - fmt.Fprint(buf, "}") - case reflect.Map: - var t = val.Type() - var keys = val.MapKeys() - - fmt.Fprint(buf, t) - fmt.Fprint(buf, "{") - - var allSimple = true - - for i := 0; i < len(keys); i++ { - var elem = val.MapIndex(keys[i]) - - var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces) - - if !isSimple { - allSimple = false - } - - if formatOutput && !isSimple { - fmt.Fprintln(buf) - } else { - fmt.Fprint(buf, " ") - } - - if formatOutput && !isSimple { - for ind := 0; ind <= level; ind++ { - fmt.Fprint(buf, indent) - } - } - - printKeyValue(buf, keys[i], pointers, interfaces, structFilter, formatOutput, indent, level+1) - fmt.Fprint(buf, ": ") - printKeyValue(buf, elem, pointers, interfaces, structFilter, formatOutput, indent, level+1) - - if i != val.Len()-1 || !allSimple { - fmt.Fprint(buf, ",") - } - } - - if formatOutput && !allSimple { - fmt.Fprintln(buf) - - for ind := 0; ind < level-1; ind++ { - fmt.Fprint(buf, indent) - } - } else { - fmt.Fprint(buf, " ") - } - - fmt.Fprint(buf, "}") - case reflect.Chan: - fmt.Fprint(buf, val.Type()) - case reflect.Invalid: - fmt.Fprint(buf, "invalid") - default: - fmt.Fprint(buf, "unknow") - } -} - -// PrintPointerInfo dump pointer value -func PrintPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) { - var anyused = false - var pointerNum = 0 - - for p := pointers; p != nil; p = p.prev { - if len(p.used) > 0 { - anyused = true - } - pointerNum++ - p.n = pointerNum - } - - if anyused { - var pointerBufs = make([][]rune, pointerNum+1) - - for i := 0; i < len(pointerBufs); i++ { - var pointerBuf = make([]rune, buf.Len()+headlen) - - for j := 0; j < len(pointerBuf); j++ { - pointerBuf[j] = ' ' - } - - pointerBufs[i] = pointerBuf - } - - for pn := 0; pn <= pointerNum; pn++ { - for p := pointers; p != nil; p = p.prev { - if len(p.used) > 0 && p.n >= pn { - if pn == p.n { - pointerBufs[pn][p.pos+headlen] = '└' - - var maxpos = 0 - - for i, pos := range p.used { - if i < len(p.used)-1 { - pointerBufs[pn][pos+headlen] = '┴' - } else { - pointerBufs[pn][pos+headlen] = '┘' - } - - maxpos = pos - } - - for i := 0; i < maxpos-p.pos-1; i++ { - if pointerBufs[pn][i+p.pos+headlen+1] == ' ' { - pointerBufs[pn][i+p.pos+headlen+1] = '─' - } - } - } else { - pointerBufs[pn][p.pos+headlen] = '│' - - for _, pos := range p.used { - if pointerBufs[pn][pos+headlen] == ' ' { - pointerBufs[pn][pos+headlen] = '│' - } else { - pointerBufs[pn][pos+headlen] = '┼' - } - } - } - } - } - - buf.WriteString(string(pointerBufs[pn]) + "\n") - } - } -} - -// Stack get stack bytes -func Stack(skip int, indent string) []byte { - var buf = new(bytes.Buffer) - - for i := skip; ; i++ { - var pc, file, line, ok = runtime.Caller(i) - - if !ok { - break - } - - buf.WriteString(indent) - - fmt.Fprintf(buf, "at %s() [%s:%d]\n", function(pc), file, line) - } - - return buf.Bytes() -} - -// return the name of the function containing the PC if possible, -func function(pc uintptr) []byte { - fn := runtime.FuncForPC(pc) - if fn == nil { - return dunno - } - name := []byte(fn.Name()) - // The name includes the path name to the package, which is unnecessary - // since the file name is already included. Plus, it has center dots. - // That is, we see - // runtime/debug.*T·ptrmethod - // and want - // *T.ptrmethod - if period := bytes.Index(name, dot); period >= 0 { - name = name[period+1:] - } - name = bytes.Replace(name, centerDot, dot, -1) - return name -} diff --git a/utils/debug_test.go b/utils/debug_test.go deleted file mode 100644 index efb8924e..00000000 --- a/utils/debug_test.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "testing" -) - -type mytype struct { - next *mytype - prev *mytype -} - -func TestPrint(t *testing.T) { - Display("v1", 1, "v2", 2, "v3", 3) -} - -func TestPrintPoint(t *testing.T) { - var v1 = new(mytype) - var v2 = new(mytype) - - v1.prev = nil - v1.next = v2 - - v2.prev = v1 - v2.next = nil - - Display("v1", v1, "v2", v2) -} - -func TestPrintString(t *testing.T) { - str := GetDisplayString("v1", 1, "v2", 2) - println(str) -} diff --git a/utils/file.go b/utils/file.go deleted file mode 100644 index 6090eb17..00000000 --- a/utils/file.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "bufio" - "errors" - "io" - "os" - "path/filepath" - "regexp" -) - -// SelfPath gets compiled executable file absolute path -func SelfPath() string { - path, _ := filepath.Abs(os.Args[0]) - return path -} - -// SelfDir gets compiled executable file directory -func SelfDir() string { - return filepath.Dir(SelfPath()) -} - -// FileExists reports whether the named file or directory exists. -func FileExists(name string) bool { - if _, err := os.Stat(name); err != nil { - if os.IsNotExist(err) { - return false - } - } - return true -} - -// SearchFile Search a file in paths. -// this is often used in search config file in /etc ~/ -func SearchFile(filename string, paths ...string) (fullpath string, err error) { - for _, path := range paths { - if fullpath = filepath.Join(path, filename); FileExists(fullpath) { - return - } - } - err = errors.New(fullpath + " not found in paths") - return -} - -// GrepFile like command grep -E -// for example: GrepFile(`^hello`, "hello.txt") -// \n is striped while read -func GrepFile(patten string, filename string) (lines []string, err error) { - re, err := regexp.Compile(patten) - if err != nil { - return - } - - fd, err := os.Open(filename) - if err != nil { - return - } - lines = make([]string, 0) - reader := bufio.NewReader(fd) - prefix := "" - var isLongLine bool - for { - byteLine, isPrefix, er := reader.ReadLine() - if er != nil && er != io.EOF { - return nil, er - } - if er == io.EOF { - break - } - line := string(byteLine) - if isPrefix { - prefix += line - continue - } else { - isLongLine = true - } - - line = prefix + line - if isLongLine { - prefix = "" - } - if re.MatchString(line) { - lines = append(lines, line) - } - } - return lines, nil -} diff --git a/utils/file_test.go b/utils/file_test.go deleted file mode 100644 index 84443e20..00000000 --- a/utils/file_test.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "path/filepath" - "reflect" - "testing" - - "github.com/stretchr/testify/assert" -) - -var noExistedFile = "/tmp/not_existed_file" - -func TestSelfPath(t *testing.T) { - path := SelfPath() - if path == "" { - t.Error("path cannot be empty") - } - t.Logf("SelfPath: %s", path) -} - -func TestSelfDir(t *testing.T) { - dir := SelfDir() - t.Logf("SelfDir: %s", dir) -} - -func TestFileExists(t *testing.T) { - if !FileExists("./file.go") { - t.Errorf("./file.go should exists, but it didn't") - } - - if FileExists(noExistedFile) { - t.Errorf("Weird, how could this file exists: %s", noExistedFile) - } -} - -func TestSearchFile(t *testing.T) { - path, err := SearchFile(filepath.Base(SelfPath()), SelfDir()) - if err != nil { - t.Error(err) - } - t.Log(path) - - _, err = SearchFile(noExistedFile, ".") - if err == nil { - t.Errorf("err shouldnt be nil, got path: %s", SelfDir()) - } -} - -func TestGrepFile(t *testing.T) { - _, err := GrepFile("", noExistedFile) - if err == nil { - t.Error("expect file-not-existed error, but got nothing") - } - - path := filepath.Join(".", "testdata", "grepe.test") - lines, err := GrepFile(`^\s*[^#]+`, path) - assert.Nil(t, err) - - if !reflect.DeepEqual(lines, []string{"hello", "world"}) { - t.Errorf("expect [hello world], but receive %v", lines) - } -} diff --git a/utils/mail.go b/utils/mail.go deleted file mode 100644 index 80a366ca..00000000 --- a/utils/mail.go +++ /dev/null @@ -1,424 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "bytes" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "mime" - "mime/multipart" - "net/mail" - "net/smtp" - "net/textproto" - "os" - "path" - "path/filepath" - "strconv" - "strings" - "sync" -) - -const ( - maxLineLength = 76 - - upperhex = "0123456789ABCDEF" -) - -// Email is the type used for email messages -type Email struct { - Auth smtp.Auth - Identity string `json:"identity"` - Username string `json:"username"` - Password string `json:"password"` - Host string `json:"host"` - Port int `json:"port"` - From string `json:"from"` - To []string - Bcc []string - Cc []string - Subject string - Text string // Plaintext message (optional) - HTML string // Html message (optional) - Headers textproto.MIMEHeader - Attachments []*Attachment - ReadReceipt []string -} - -// Attachment is a struct representing an email attachment. -// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question -type Attachment struct { - Filename string - Header textproto.MIMEHeader - Content []byte -} - -// NewEMail create new Email struct with config json. -// config json is followed from Email struct fields. -func NewEMail(config string) *Email { - e := new(Email) - e.Headers = textproto.MIMEHeader{} - err := json.Unmarshal([]byte(config), e) - if err != nil { - return nil - } - return e -} - -// Bytes Make all send information to byte -func (e *Email) Bytes() ([]byte, error) { - buff := &bytes.Buffer{} - w := multipart.NewWriter(buff) - // Set the appropriate headers (overwriting any conflicts) - // Leave out Bcc (only included in envelope headers) - e.Headers.Set("To", strings.Join(e.To, ",")) - if e.Cc != nil { - e.Headers.Set("Cc", strings.Join(e.Cc, ",")) - } - e.Headers.Set("From", e.From) - e.Headers.Set("Subject", e.Subject) - if len(e.ReadReceipt) != 0 { - e.Headers.Set("Disposition-Notification-To", strings.Join(e.ReadReceipt, ",")) - } - e.Headers.Set("MIME-Version", "1.0") - - // Write the envelope headers (including any custom headers) - if err := headerToBytes(buff, e.Headers); err != nil { - return nil, fmt.Errorf("Failed to render message headers: %s", err) - } - - e.Headers.Set("Content-Type", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary())) - fmt.Fprintf(buff, "%s:", "Content-Type") - fmt.Fprintf(buff, " %s\r\n", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary())) - - // Start the multipart/mixed part - fmt.Fprintf(buff, "--%s\r\n", w.Boundary()) - header := textproto.MIMEHeader{} - // Check to see if there is a Text or HTML field - if e.Text != "" || e.HTML != "" { - subWriter := multipart.NewWriter(buff) - // Create the multipart alternative part - header.Set("Content-Type", fmt.Sprintf("multipart/alternative;\r\n boundary=%s\r\n", subWriter.Boundary())) - // Write the header - if err := headerToBytes(buff, header); err != nil { - return nil, fmt.Errorf("Failed to render multipart message headers: %s", err) - } - // Create the body sections - if e.Text != "" { - header.Set("Content-Type", fmt.Sprintf("text/plain; charset=UTF-8")) - header.Set("Content-Transfer-Encoding", "quoted-printable") - if _, err := subWriter.CreatePart(header); err != nil { - return nil, err - } - // Write the text - if err := quotePrintEncode(buff, e.Text); err != nil { - return nil, err - } - } - if e.HTML != "" { - header.Set("Content-Type", fmt.Sprintf("text/html; charset=UTF-8")) - header.Set("Content-Transfer-Encoding", "quoted-printable") - if _, err := subWriter.CreatePart(header); err != nil { - return nil, err - } - // Write the text - if err := quotePrintEncode(buff, e.HTML); err != nil { - return nil, err - } - } - if err := subWriter.Close(); err != nil { - return nil, err - } - } - // Create attachment part, if necessary - for _, a := range e.Attachments { - ap, err := w.CreatePart(a.Header) - if err != nil { - return nil, err - } - // Write the base64Wrapped content to the part - base64Wrap(ap, a.Content) - } - if err := w.Close(); err != nil { - return nil, err - } - return buff.Bytes(), nil -} - -// AttachFile Add attach file to the send mail -func (e *Email) AttachFile(args ...string) (a *Attachment, err error) { - if len(args) < 1 || len(args) > 2 { // change && to || - err = errors.New("Must specify a file name and number of parameters can not exceed at least two") - return - } - filename := args[0] - id := "" - if len(args) > 1 { - id = args[1] - } - f, err := os.Open(filename) - if err != nil { - return - } - defer f.Close() - ct := mime.TypeByExtension(filepath.Ext(filename)) - basename := path.Base(filename) - return e.Attach(f, basename, ct, id) -} - -// Attach is used to attach content from an io.Reader to the email. -// Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type. -func (e *Email) Attach(r io.Reader, filename string, args ...string) (a *Attachment, err error) { - if len(args) < 1 || len(args) > 2 { // change && to || - err = errors.New("Must specify the file type and number of parameters can not exceed at least two") - return - } - c := args[0] //Content-Type - id := "" - if len(args) > 1 { - id = args[1] //Content-ID - } - var buffer bytes.Buffer - if _, err = io.Copy(&buffer, r); err != nil { - return - } - at := &Attachment{ - Filename: filename, - Header: textproto.MIMEHeader{}, - Content: buffer.Bytes(), - } - // Get the Content-Type to be used in the MIMEHeader - if c != "" { - at.Header.Set("Content-Type", c) - } else { - // If the Content-Type is blank, set the Content-Type to "application/octet-stream" - at.Header.Set("Content-Type", "application/octet-stream") - } - if id != "" { - at.Header.Set("Content-Disposition", fmt.Sprintf("inline;\r\n filename=\"%s\"", filename)) - at.Header.Set("Content-ID", fmt.Sprintf("<%s>", id)) - } else { - at.Header.Set("Content-Disposition", fmt.Sprintf("attachment;\r\n filename=\"%s\"", filename)) - } - at.Header.Set("Content-Transfer-Encoding", "base64") - e.Attachments = append(e.Attachments, at) - return at, nil -} - -// Send will send out the mail -func (e *Email) Send() error { - if e.Auth == nil { - e.Auth = smtp.PlainAuth(e.Identity, e.Username, e.Password, e.Host) - } - // Merge the To, Cc, and Bcc fields - to := make([]string, 0, len(e.To)+len(e.Cc)+len(e.Bcc)) - to = append(append(append(to, e.To...), e.Cc...), e.Bcc...) - // Check to make sure there is at least one recipient and one "From" address - if len(to) == 0 { - return errors.New("Must specify at least one To address") - } - - // Use the username if no From is provided - if len(e.From) == 0 { - e.From = e.Username - } - - from, err := mail.ParseAddress(e.From) - if err != nil { - return err - } - - // use mail's RFC 2047 to encode any string - e.Subject = qEncode("utf-8", e.Subject) - - raw, err := e.Bytes() - if err != nil { - return err - } - return smtp.SendMail(e.Host+":"+strconv.Itoa(e.Port), e.Auth, from.Address, to, raw) -} - -// quotePrintEncode writes the quoted-printable text to the IO Writer (according to RFC 2045) -func quotePrintEncode(w io.Writer, s string) error { - var buf [3]byte - mc := 0 - for i := 0; i < len(s); i++ { - c := s[i] - // We're assuming Unix style text formats as input (LF line break), and - // quoted-printble uses CRLF line breaks. (Literal CRs will become - // "=0D", but probably shouldn't be there to begin with!) - if c == '\n' { - io.WriteString(w, "\r\n") - mc = 0 - continue - } - - var nextOut []byte - if isPrintable(c) { - nextOut = append(buf[:0], c) - } else { - nextOut = buf[:] - qpEscape(nextOut, c) - } - - // Add a soft line break if the next (encoded) byte would push this line - // to or past the limit. - if mc+len(nextOut) >= maxLineLength { - if _, err := io.WriteString(w, "=\r\n"); err != nil { - return err - } - mc = 0 - } - - if _, err := w.Write(nextOut); err != nil { - return err - } - mc += len(nextOut) - } - // No trailing end-of-line?? Soft line break, then. TODO: is this sane? - if mc > 0 { - io.WriteString(w, "=\r\n") - } - return nil -} - -// isPrintable returns true if the rune given is "printable" according to RFC 2045, false otherwise -func isPrintable(c byte) bool { - return (c >= '!' && c <= '<') || (c >= '>' && c <= '~') || (c == ' ' || c == '\n' || c == '\t') -} - -// qpEscape is a helper function for quotePrintEncode which escapes a -// non-printable byte. Expects len(dest) == 3. -func qpEscape(dest []byte, c byte) { - const nums = "0123456789ABCDEF" - dest[0] = '=' - dest[1] = nums[(c&0xf0)>>4] - dest[2] = nums[(c & 0xf)] -} - -// headerToBytes enumerates the key and values in the header, and writes the results to the IO Writer -func headerToBytes(w io.Writer, t textproto.MIMEHeader) error { - for k, v := range t { - // Write the header key - _, err := fmt.Fprintf(w, "%s:", k) - if err != nil { - return err - } - // Write each value in the header - for _, c := range v { - _, err := fmt.Fprintf(w, " %s\r\n", c) - if err != nil { - return err - } - } - } - return nil -} - -// base64Wrap encodes the attachment content, and wraps it according to RFC 2045 standards (every 76 chars) -// The output is then written to the specified io.Writer -func base64Wrap(w io.Writer, b []byte) { - // 57 raw bytes per 76-byte base64 line. - const maxRaw = 57 - // Buffer for each line, including trailing CRLF. - var buffer [maxLineLength + len("\r\n")]byte - copy(buffer[maxLineLength:], "\r\n") - // Process raw chunks until there's no longer enough to fill a line. - for len(b) >= maxRaw { - base64.StdEncoding.Encode(buffer[:], b[:maxRaw]) - w.Write(buffer[:]) - b = b[maxRaw:] - } - // Handle the last chunk of bytes. - if len(b) > 0 { - out := buffer[:base64.StdEncoding.EncodedLen(len(b))] - base64.StdEncoding.Encode(out, b) - out = append(out, "\r\n"...) - w.Write(out) - } -} - -// Encode returns the encoded-word form of s. If s is ASCII without special -// characters, it is returned unchanged. The provided charset is the IANA -// charset name of s. It is case insensitive. -// RFC 2047 encoded-word -func qEncode(charset, s string) string { - if !needsEncoding(s) { - return s - } - return encodeWord(charset, s) -} - -func needsEncoding(s string) bool { - for _, b := range s { - if (b < ' ' || b > '~') && b != '\t' { - return true - } - } - return false -} - -// encodeWord encodes a string into an encoded-word. -func encodeWord(charset, s string) string { - buf := getBuffer() - - buf.WriteString("=?") - buf.WriteString(charset) - buf.WriteByte('?') - buf.WriteByte('q') - buf.WriteByte('?') - - enc := make([]byte, 3) - for i := 0; i < len(s); i++ { - b := s[i] - switch { - case b == ' ': - buf.WriteByte('_') - case b <= '~' && b >= '!' && b != '=' && b != '?' && b != '_': - buf.WriteByte(b) - default: - enc[0] = '=' - enc[1] = upperhex[b>>4] - enc[2] = upperhex[b&0x0f] - buf.Write(enc) - } - } - buf.WriteString("?=") - - es := buf.String() - putBuffer(buf) - return es -} - -var bufPool = sync.Pool{ - New: func() interface{} { - return new(bytes.Buffer) - }, -} - -func getBuffer() *bytes.Buffer { - return bufPool.Get().(*bytes.Buffer) -} - -func putBuffer(buf *bytes.Buffer) { - if buf.Len() > 1024 { - return - } - buf.Reset() - bufPool.Put(buf) -} diff --git a/utils/mail_test.go b/utils/mail_test.go deleted file mode 100644 index c38356a2..00000000 --- a/utils/mail_test.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import "testing" - -func TestMail(t *testing.T) { - config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}` - mail := NewEMail(config) - if mail.Username != "astaxie@gmail.com" { - t.Fatal("email parse get username error") - } - if mail.Password != "astaxie" { - t.Fatal("email parse get password error") - } - if mail.Host != "smtp.gmail.com" { - t.Fatal("email parse get host error") - } - if mail.Port != 587 { - t.Fatal("email parse get port error") - } - mail.To = []string{"xiemengjun@gmail.com"} - mail.From = "astaxie@gmail.com" - mail.Subject = "hi, just from beego!" - mail.Text = "Text Body is, of course, supported!" - mail.HTML = "

Fancy Html is supported, too!

" - mail.AttachFile("/Users/astaxie/github/beego/beego.go") - mail.Send() -} diff --git a/utils/pagination/controller.go b/utils/pagination/controller.go deleted file mode 100644 index 2f022d0c..00000000 --- a/utils/pagination/controller.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pagination - -import ( - "github.com/astaxie/beego/context" -) - -// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). -func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) { - paginator = NewPaginator(context.Request, per, nums) - context.Input.SetData("paginator", &paginator) - return -} diff --git a/utils/pagination/doc.go b/utils/pagination/doc.go deleted file mode 100644 index 9abc6d78..00000000 --- a/utils/pagination/doc.go +++ /dev/null @@ -1,58 +0,0 @@ -/* -Package pagination provides utilities to setup a paginator within the -context of a http request. - -Usage - -In your beego.Controller: - - package controllers - - import "github.com/astaxie/beego/utils/pagination" - - type PostsController struct { - beego.Controller - } - - func (this *PostsController) ListAllPosts() { - // sets this.Data["paginator"] with the current offset (from the url query param) - postsPerPage := 20 - paginator := pagination.SetPaginator(this.Ctx, postsPerPage, CountPosts()) - - // fetch the next 20 posts - this.Data["posts"] = ListPostsByOffsetAndLimit(paginator.Offset(), postsPerPage) - } - - -In your view templates: - - {{if .paginator.HasPages}} - - {{end}} - -See also - -http://beego.me/docs/mvc/view/page.md - -*/ -package pagination diff --git a/utils/pagination/paginator.go b/utils/pagination/paginator.go deleted file mode 100644 index c6db31e0..00000000 --- a/utils/pagination/paginator.go +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pagination - -import ( - "math" - "net/http" - "net/url" - "strconv" -) - -// Paginator within the state of a http request. -type Paginator struct { - Request *http.Request - PerPageNums int - MaxPages int - - nums int64 - pageRange []int - pageNums int - page int -} - -// PageNums Returns the total number of pages. -func (p *Paginator) PageNums() int { - if p.pageNums != 0 { - return p.pageNums - } - pageNums := math.Ceil(float64(p.nums) / float64(p.PerPageNums)) - if p.MaxPages > 0 { - pageNums = math.Min(pageNums, float64(p.MaxPages)) - } - p.pageNums = int(pageNums) - return p.pageNums -} - -// Nums Returns the total number of items (e.g. from doing SQL count). -func (p *Paginator) Nums() int64 { - return p.nums -} - -// SetNums Sets the total number of items. -func (p *Paginator) SetNums(nums interface{}) { - p.nums, _ = toInt64(nums) -} - -// Page Returns the current page. -func (p *Paginator) Page() int { - if p.page != 0 { - return p.page - } - if p.Request.Form == nil { - p.Request.ParseForm() - } - p.page, _ = strconv.Atoi(p.Request.Form.Get("p")) - if p.page > p.PageNums() { - p.page = p.PageNums() - } - if p.page <= 0 { - p.page = 1 - } - return p.page -} - -// Pages Returns a list of all pages. -// -// Usage (in a view template): -// -// {{range $index, $page := .paginator.Pages}} -// -// {{$page}} -// -// {{end}} -func (p *Paginator) Pages() []int { - if p.pageRange == nil && p.nums > 0 { - var pages []int - pageNums := p.PageNums() - page := p.Page() - switch { - case page >= pageNums-4 && pageNums > 9: - start := pageNums - 9 + 1 - pages = make([]int, 9) - for i := range pages { - pages[i] = start + i - } - case page >= 5 && pageNums > 9: - start := page - 5 + 1 - pages = make([]int, int(math.Min(9, float64(page+4+1)))) - for i := range pages { - pages[i] = start + i - } - default: - pages = make([]int, int(math.Min(9, float64(pageNums)))) - for i := range pages { - pages[i] = i + 1 - } - } - p.pageRange = pages - } - return p.pageRange -} - -// PageLink Returns URL for a given page index. -func (p *Paginator) PageLink(page int) string { - link, _ := url.ParseRequestURI(p.Request.URL.String()) - values := link.Query() - if page == 1 { - values.Del("p") - } else { - values.Set("p", strconv.Itoa(page)) - } - link.RawQuery = values.Encode() - return link.String() -} - -// PageLinkPrev Returns URL to the previous page. -func (p *Paginator) PageLinkPrev() (link string) { - if p.HasPrev() { - link = p.PageLink(p.Page() - 1) - } - return -} - -// PageLinkNext Returns URL to the next page. -func (p *Paginator) PageLinkNext() (link string) { - if p.HasNext() { - link = p.PageLink(p.Page() + 1) - } - return -} - -// PageLinkFirst Returns URL to the first page. -func (p *Paginator) PageLinkFirst() (link string) { - return p.PageLink(1) -} - -// PageLinkLast Returns URL to the last page. -func (p *Paginator) PageLinkLast() (link string) { - return p.PageLink(p.PageNums()) -} - -// HasPrev Returns true if the current page has a predecessor. -func (p *Paginator) HasPrev() bool { - return p.Page() > 1 -} - -// HasNext Returns true if the current page has a successor. -func (p *Paginator) HasNext() bool { - return p.Page() < p.PageNums() -} - -// IsActive Returns true if the given page index points to the current page. -func (p *Paginator) IsActive(page int) bool { - return p.Page() == page -} - -// Offset Returns the current offset. -func (p *Paginator) Offset() int { - return (p.Page() - 1) * p.PerPageNums -} - -// HasPages Returns true if there is more than one page. -func (p *Paginator) HasPages() bool { - return p.PageNums() > 1 -} - -// NewPaginator Instantiates a paginator struct for the current http request. -func NewPaginator(req *http.Request, per int, nums interface{}) *Paginator { - p := Paginator{} - p.Request = req - if per <= 0 { - per = 10 - } - p.PerPageNums = per - p.SetNums(nums) - return &p -} diff --git a/utils/pagination/utils.go b/utils/pagination/utils.go deleted file mode 100644 index 686e68b0..00000000 --- a/utils/pagination/utils.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pagination - -import ( - "fmt" - "reflect" -) - -// ToInt64 convert any numeric value to int64 -func toInt64(value interface{}) (d int64, err error) { - val := reflect.ValueOf(value) - switch value.(type) { - case int, int8, int16, int32, int64: - d = val.Int() - case uint, uint8, uint16, uint32, uint64: - d = int64(val.Uint()) - default: - err = fmt.Errorf("ToInt64 need numeric not `%T`", value) - } - return -} diff --git a/utils/rand.go b/utils/rand.go deleted file mode 100644 index 344d1cd5..00000000 --- a/utils/rand.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "crypto/rand" - r "math/rand" - "time" -) - -var alphaNum = []byte(`0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz`) - -// RandomCreateBytes generate random []byte by specify chars. -func RandomCreateBytes(n int, alphabets ...byte) []byte { - if len(alphabets) == 0 { - alphabets = alphaNum - } - var bytes = make([]byte, n) - var randBy bool - if num, err := rand.Read(bytes); num != n || err != nil { - r.Seed(time.Now().UnixNano()) - randBy = true - } - for i, b := range bytes { - if randBy { - bytes[i] = alphabets[r.Intn(len(alphabets))] - } else { - bytes[i] = alphabets[b%byte(len(alphabets))] - } - } - return bytes -} diff --git a/utils/rand_test.go b/utils/rand_test.go deleted file mode 100644 index 6c238b5e..00000000 --- a/utils/rand_test.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2016 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import "testing" - -func TestRand_01(t *testing.T) { - bs0 := RandomCreateBytes(16) - bs1 := RandomCreateBytes(16) - - t.Log(string(bs0), string(bs1)) - if string(bs0) == string(bs1) { - t.FailNow() - } - - bs0 = RandomCreateBytes(4, []byte(`a`)...) - - if string(bs0) != "aaaa" { - t.FailNow() - } -} diff --git a/utils/safemap.go b/utils/safemap.go deleted file mode 100644 index 1793030a..00000000 --- a/utils/safemap.go +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "sync" -) - -// BeeMap is a map with lock -type BeeMap struct { - lock *sync.RWMutex - bm map[interface{}]interface{} -} - -// NewBeeMap return new safemap -func NewBeeMap() *BeeMap { - return &BeeMap{ - lock: new(sync.RWMutex), - bm: make(map[interface{}]interface{}), - } -} - -// Get from maps return the k's value -func (m *BeeMap) Get(k interface{}) interface{} { - m.lock.RLock() - defer m.lock.RUnlock() - if val, ok := m.bm[k]; ok { - return val - } - return nil -} - -// Set Maps the given key and value. Returns false -// if the key is already in the map and changes nothing. -func (m *BeeMap) Set(k interface{}, v interface{}) bool { - m.lock.Lock() - defer m.lock.Unlock() - if val, ok := m.bm[k]; !ok { - m.bm[k] = v - } else if val != v { - m.bm[k] = v - } else { - return false - } - return true -} - -// Check Returns true if k is exist in the map. -func (m *BeeMap) Check(k interface{}) bool { - m.lock.RLock() - defer m.lock.RUnlock() - _, ok := m.bm[k] - return ok -} - -// Delete the given key and value. -func (m *BeeMap) Delete(k interface{}) { - m.lock.Lock() - defer m.lock.Unlock() - delete(m.bm, k) -} - -// Items returns all items in safemap. -func (m *BeeMap) Items() map[interface{}]interface{} { - m.lock.RLock() - defer m.lock.RUnlock() - r := make(map[interface{}]interface{}) - for k, v := range m.bm { - r[k] = v - } - return r -} - -// Count returns the number of items within the map. -func (m *BeeMap) Count() int { - m.lock.RLock() - defer m.lock.RUnlock() - return len(m.bm) -} diff --git a/utils/safemap_test.go b/utils/safemap_test.go deleted file mode 100644 index 65085195..00000000 --- a/utils/safemap_test.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import "testing" - -var safeMap *BeeMap - -func TestNewBeeMap(t *testing.T) { - safeMap = NewBeeMap() - if safeMap == nil { - t.Fatal("expected to return non-nil BeeMap", "got", safeMap) - } -} - -func TestSet(t *testing.T) { - safeMap = NewBeeMap() - if ok := safeMap.Set("astaxie", 1); !ok { - t.Error("expected", true, "got", false) - } -} - -func TestReSet(t *testing.T) { - safeMap := NewBeeMap() - if ok := safeMap.Set("astaxie", 1); !ok { - t.Error("expected", true, "got", false) - } - // set diff value - if ok := safeMap.Set("astaxie", -1); !ok { - t.Error("expected", true, "got", false) - } - - // set same value - if ok := safeMap.Set("astaxie", -1); ok { - t.Error("expected", false, "got", true) - } -} - -func TestCheck(t *testing.T) { - if exists := safeMap.Check("astaxie"); !exists { - t.Error("expected", true, "got", false) - } -} - -func TestGet(t *testing.T) { - if val := safeMap.Get("astaxie"); val.(int) != 1 { - t.Error("expected value", 1, "got", val) - } -} - -func TestDelete(t *testing.T) { - safeMap.Delete("astaxie") - if exists := safeMap.Check("astaxie"); exists { - t.Error("expected element to be deleted") - } -} - -func TestItems(t *testing.T) { - safeMap := NewBeeMap() - safeMap.Set("astaxie", "hello") - for k, v := range safeMap.Items() { - key := k.(string) - value := v.(string) - if key != "astaxie" { - t.Error("expected the key should be astaxie") - } - if value != "hello" { - t.Error("expected the value should be hello") - } - } -} - -func TestCount(t *testing.T) { - if count := safeMap.Count(); count != 0 { - t.Error("expected count to be", 0, "got", count) - } -} diff --git a/utils/slice.go b/utils/slice.go deleted file mode 100644 index 8f2cef98..00000000 --- a/utils/slice.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "math/rand" - "time" -) - -type reducetype func(interface{}) interface{} -type filtertype func(interface{}) bool - -// InSlice checks given string in string slice or not. -func InSlice(v string, sl []string) bool { - for _, vv := range sl { - if vv == v { - return true - } - } - return false -} - -// InSliceIface checks given interface in interface slice. -func InSliceIface(v interface{}, sl []interface{}) bool { - for _, vv := range sl { - if vv == v { - return true - } - } - return false -} - -// SliceRandList generate an int slice from min to max. -func SliceRandList(min, max int) []int { - if max < min { - min, max = max, min - } - length := max - min + 1 - t0 := time.Now() - rand.Seed(int64(t0.Nanosecond())) - list := rand.Perm(length) - for index := range list { - list[index] += min - } - return list -} - -// SliceMerge merges interface slices to one slice. -func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) { - c = append(slice1, slice2...) - return -} - -// SliceReduce generates a new slice after parsing every value by reduce function -func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) { - for _, v := range slice { - dslice = append(dslice, a(v)) - } - return -} - -// SliceRand returns random one from slice. -func SliceRand(a []interface{}) (b interface{}) { - randnum := rand.Intn(len(a)) - b = a[randnum] - return -} - -// SliceSum sums all values in int64 slice. -func SliceSum(intslice []int64) (sum int64) { - for _, v := range intslice { - sum += v - } - return -} - -// SliceFilter generates a new slice after filter function. -func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) { - for _, v := range slice { - if a(v) { - ftslice = append(ftslice, v) - } - } - return -} - -// SliceDiff returns diff slice of slice1 - slice2. -func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) { - for _, v := range slice1 { - if !InSliceIface(v, slice2) { - diffslice = append(diffslice, v) - } - } - return -} - -// SliceIntersect returns slice that are present in all the slice1 and slice2. -func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) { - for _, v := range slice1 { - if InSliceIface(v, slice2) { - diffslice = append(diffslice, v) - } - } - return -} - -// SliceChunk separates one slice to some sized slice. -func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) { - if size >= len(slice) { - chunkslice = append(chunkslice, slice) - return - } - end := size - for i := 0; i <= (len(slice) - size); i += size { - chunkslice = append(chunkslice, slice[i:end]) - end += size - } - return -} - -// SliceRange generates a new slice from begin to end with step duration of int64 number. -func SliceRange(start, end, step int64) (intslice []int64) { - for i := start; i <= end; i += step { - intslice = append(intslice, i) - } - return -} - -// SlicePad prepends size number of val into slice. -func SlicePad(slice []interface{}, size int, val interface{}) []interface{} { - if size <= len(slice) { - return slice - } - for i := 0; i < (size - len(slice)); i++ { - slice = append(slice, val) - } - return slice -} - -// SliceUnique cleans repeated values in slice. -func SliceUnique(slice []interface{}) (uniqueslice []interface{}) { - for _, v := range slice { - if !InSliceIface(v, uniqueslice) { - uniqueslice = append(uniqueslice, v) - } - } - return -} - -// SliceShuffle shuffles a slice. -func SliceShuffle(slice []interface{}) []interface{} { - for i := 0; i < len(slice); i++ { - a := rand.Intn(len(slice)) - b := rand.Intn(len(slice)) - slice[a], slice[b] = slice[b], slice[a] - } - return slice -} diff --git a/utils/slice_test.go b/utils/slice_test.go deleted file mode 100644 index 142dec96..00000000 --- a/utils/slice_test.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "testing" -) - -func TestInSlice(t *testing.T) { - sl := []string{"A", "b"} - if !InSlice("A", sl) { - t.Error("should be true") - } - if InSlice("B", sl) { - t.Error("should be false") - } -} diff --git a/utils/testdata/grepe.test b/utils/testdata/grepe.test deleted file mode 100644 index 6c014c40..00000000 --- a/utils/testdata/grepe.test +++ /dev/null @@ -1,7 +0,0 @@ -# empty lines - - - -hello -# comment -world diff --git a/utils/utils.go b/utils/utils.go deleted file mode 100644 index 3874b803..00000000 --- a/utils/utils.go +++ /dev/null @@ -1,89 +0,0 @@ -package utils - -import ( - "os" - "path/filepath" - "regexp" - "runtime" - "strconv" - "strings" -) - -// GetGOPATHs returns all paths in GOPATH variable. -func GetGOPATHs() []string { - gopath := os.Getenv("GOPATH") - if gopath == "" && compareGoVersion(runtime.Version(), "go1.8") >= 0 { - gopath = defaultGOPATH() - } - return filepath.SplitList(gopath) -} - -func compareGoVersion(a, b string) int { - reg := regexp.MustCompile("^\\d*") - - a = strings.TrimPrefix(a, "go") - b = strings.TrimPrefix(b, "go") - - versionsA := strings.Split(a, ".") - versionsB := strings.Split(b, ".") - - for i := 0; i < len(versionsA) && i < len(versionsB); i++ { - versionA := versionsA[i] - versionB := versionsB[i] - - vA, err := strconv.Atoi(versionA) - if err != nil { - str := reg.FindString(versionA) - if str != "" { - vA, _ = strconv.Atoi(str) - } else { - vA = -1 - } - } - - vB, err := strconv.Atoi(versionB) - if err != nil { - str := reg.FindString(versionB) - if str != "" { - vB, _ = strconv.Atoi(str) - } else { - vB = -1 - } - } - - if vA > vB { - // vA = 12, vB = 8 - return 1 - } else if vA < vB { - // vA = 6, vB = 8 - return -1 - } else if vA == -1 { - // vA = rc1, vB = rc3 - return strings.Compare(versionA, versionB) - } - - // vA = vB = 8 - continue - } - - if len(versionsA) > len(versionsB) { - return 1 - } else if len(versionsA) == len(versionsB) { - return 0 - } - - return -1 -} - -func defaultGOPATH() string { - env := "HOME" - if runtime.GOOS == "windows" { - env = "USERPROFILE" - } else if runtime.GOOS == "plan9" { - env = "home" - } - if home := os.Getenv(env); home != "" { - return filepath.Join(home, "go") - } - return "" -} diff --git a/utils/utils_test.go b/utils/utils_test.go deleted file mode 100644 index ced6f63f..00000000 --- a/utils/utils_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package utils - -import ( - "testing" -) - -func TestCompareGoVersion(t *testing.T) { - targetVersion := "go1.8" - if compareGoVersion("go1.12.4", targetVersion) != 1 { - t.Error("should be 1") - } - - if compareGoVersion("go1.8.7", targetVersion) != 1 { - t.Error("should be 1") - } - - if compareGoVersion("go1.8", targetVersion) != 0 { - t.Error("should be 0") - } - - if compareGoVersion("go1.7.6", targetVersion) != -1 { - t.Error("should be -1") - } - - if compareGoVersion("go1.12.1rc1", targetVersion) != 1 { - t.Error("should be 1") - } - - if compareGoVersion("go1.8rc1", targetVersion) != 0 { - t.Error("should be 0") - } - - if compareGoVersion("go1.7rc1", targetVersion) != -1 { - t.Error("should be -1") - } -} diff --git a/validation/README.md b/validation/README.md deleted file mode 100644 index 43373e47..00000000 --- a/validation/README.md +++ /dev/null @@ -1,147 +0,0 @@ -validation -============== - -validation is a form validation for a data validation and error collecting using Go. - -## Installation and tests - -Install: - - go get github.com/astaxie/beego/validation - -Test: - - go test github.com/astaxie/beego/validation - -## Example - -Direct Use: - - import ( - "github.com/astaxie/beego/validation" - "log" - ) - - type User struct { - Name string - Age int - } - - func main() { - u := User{"man", 40} - valid := validation.Validation{} - valid.Required(u.Name, "name") - valid.MaxSize(u.Name, 15, "nameMax") - valid.Range(u.Age, 0, 140, "age") - if valid.HasErrors() { - // validation does not pass - // print invalid message - for _, err := range valid.Errors { - log.Println(err.Key, err.Message) - } - } - // or use like this - if v := valid.Max(u.Age, 140, "ageMax"); !v.Ok { - log.Println(v.Error.Key, v.Error.Message) - } - } - -Struct Tag Use: - - import ( - "github.com/astaxie/beego/validation" - ) - - // validation function follow with "valid" tag - // functions divide with ";" - // parameters in parentheses "()" and divide with "," - // Match function's pattern string must in "//" - type user struct { - Id int - Name string `valid:"Required;Match(/^(test)?\\w*@;com$/)"` - Age int `valid:"Required;Range(1, 140)"` - } - - func main() { - valid := validation.Validation{} - // ignore empty field valid - // see CanSkipFuncs - // valid := validation.Validation{RequiredFirst:true} - u := user{Name: "test", Age: 40} - b, err := valid.Valid(u) - if err != nil { - // handle error - } - if !b { - // validation does not pass - // blabla... - } - } - -Use custom function: - - import ( - "github.com/astaxie/beego/validation" - ) - - type user struct { - Id int - Name string `valid:"Required;IsMe"` - Age int `valid:"Required;Range(1, 140)"` - } - - func IsMe(v *validation.Validation, obj interface{}, key string) { - name, ok:= obj.(string) - if !ok { - // wrong use case? - return - } - - if name != "me" { - // valid false - v.SetError("Name", "is not me!") - } - } - - func main() { - valid := validation.Validation{} - if err := validation.AddCustomFunc("IsMe", IsMe); err != nil { - // hadle error - } - u := user{Name: "test", Age: 40} - b, err := valid.Valid(u) - if err != nil { - // handle error - } - if !b { - // validation does not pass - // blabla... - } - } - -Struct Tag Functions: - - Required - Min(min int) - Max(max int) - Range(min, max int) - MinSize(min int) - MaxSize(max int) - Length(length int) - Alpha - Numeric - AlphaNumeric - Match(pattern string) - AlphaDash - Email - IP - Base64 - Mobile - Tel - Phone - ZipCode - - -## LICENSE - -BSD License http://creativecommons.org/licenses/BSD/ diff --git a/validation/util.go b/validation/util.go deleted file mode 100644 index 918b206c..00000000 --- a/validation/util.go +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package validation - -import ( - "fmt" - "reflect" - "regexp" - "strconv" - "strings" -) - -const ( - // ValidTag struct tag - ValidTag = "valid" - - LabelTag = "label" - - wordsize = 32 << (^uint(0) >> 32 & 1) -) - -var ( - // key: function name - // value: the number of parameters - funcs = make(Funcs) - - // doesn't belong to validation functions - unFuncs = map[string]bool{ - "Clear": true, - "HasErrors": true, - "ErrorMap": true, - "Error": true, - "apply": true, - "Check": true, - "Valid": true, - "NoMatch": true, - } - // ErrInt64On32 show 32 bit platform not support int64 - ErrInt64On32 = fmt.Errorf("not support int64 on 32-bit platform") -) - -func init() { - v := &Validation{} - t := reflect.TypeOf(v) - for i := 0; i < t.NumMethod(); i++ { - m := t.Method(i) - if !unFuncs[m.Name] { - funcs[m.Name] = m.Func - } - } -} - -// CustomFunc is for custom validate function -type CustomFunc func(v *Validation, obj interface{}, key string) - -// AddCustomFunc Add a custom function to validation -// The name can not be: -// Clear -// HasErrors -// ErrorMap -// Error -// Check -// Valid -// NoMatch -// If the name is same with exists function, it will replace the origin valid function -func AddCustomFunc(name string, f CustomFunc) error { - if unFuncs[name] { - return fmt.Errorf("invalid function name: %s", name) - } - - funcs[name] = reflect.ValueOf(f) - return nil -} - -// ValidFunc Valid function type -type ValidFunc struct { - Name string - Params []interface{} -} - -// Funcs Validate function map -type Funcs map[string]reflect.Value - -// Call validate values with named type string -func (f Funcs) Call(name string, params ...interface{}) (result []reflect.Value, err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("%v", r) - } - }() - if _, ok := f[name]; !ok { - err = fmt.Errorf("%s does not exist", name) - return - } - if len(params) != f[name].Type().NumIn() { - err = fmt.Errorf("The number of params is not adapted") - return - } - in := make([]reflect.Value, len(params)) - for k, param := range params { - in[k] = reflect.ValueOf(param) - } - result = f[name].Call(in) - return -} - -func isStruct(t reflect.Type) bool { - return t.Kind() == reflect.Struct -} - -func isStructPtr(t reflect.Type) bool { - return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct -} - -func getValidFuncs(f reflect.StructField) (vfs []ValidFunc, err error) { - tag := f.Tag.Get(ValidTag) - label := f.Tag.Get(LabelTag) - if len(tag) == 0 { - return - } - if vfs, tag, err = getRegFuncs(tag, f.Name); err != nil { - return - } - fs := strings.Split(tag, ";") - for _, vfunc := range fs { - var vf ValidFunc - if len(vfunc) == 0 { - continue - } - vf, err = parseFunc(vfunc, f.Name, label) - if err != nil { - return - } - vfs = append(vfs, vf) - } - return -} - -// Get Match function -// May be get NoMatch function in the future -func getRegFuncs(tag, key string) (vfs []ValidFunc, str string, err error) { - tag = strings.TrimSpace(tag) - index := strings.Index(tag, "Match(/") - if index == -1 { - str = tag - return - } - end := strings.LastIndex(tag, "/)") - if end < index { - err = fmt.Errorf("invalid Match function") - return - } - reg, err := regexp.Compile(tag[index+len("Match(/") : end]) - if err != nil { - return - } - vfs = []ValidFunc{{"Match", []interface{}{reg, key + ".Match"}}} - str = strings.TrimSpace(tag[:index]) + strings.TrimSpace(tag[end+len("/)"):]) - return -} - -func parseFunc(vfunc, key string, label string) (v ValidFunc, err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("%v", r) - } - }() - - vfunc = strings.TrimSpace(vfunc) - start := strings.Index(vfunc, "(") - var num int - - // doesn't need parameter valid function - if start == -1 { - if num, err = numIn(vfunc); err != nil { - return - } - if num != 0 { - err = fmt.Errorf("%s require %d parameters", vfunc, num) - return - } - v = ValidFunc{vfunc, []interface{}{key + "." + vfunc + "." + label}} - return - } - - end := strings.Index(vfunc, ")") - if end == -1 { - err = fmt.Errorf("invalid valid function") - return - } - - name := strings.TrimSpace(vfunc[:start]) - if num, err = numIn(name); err != nil { - return - } - - params := strings.Split(vfunc[start+1:end], ",") - // the num of param must be equal - if num != len(params) { - err = fmt.Errorf("%s require %d parameters", name, num) - return - } - - tParams, err := trim(name, key+"."+name+"."+label, params) - if err != nil { - return - } - v = ValidFunc{name, tParams} - return -} - -func numIn(name string) (num int, err error) { - fn, ok := funcs[name] - if !ok { - err = fmt.Errorf("doesn't exists %s valid function", name) - return - } - // sub *Validation obj and key - num = fn.Type().NumIn() - 3 - return -} - -func trim(name, key string, s []string) (ts []interface{}, err error) { - ts = make([]interface{}, len(s), len(s)+1) - fn, ok := funcs[name] - if !ok { - err = fmt.Errorf("doesn't exists %s valid function", name) - return - } - for i := 0; i < len(s); i++ { - var param interface{} - // skip *Validation and obj params - if param, err = parseParam(fn.Type().In(i+2), strings.TrimSpace(s[i])); err != nil { - return - } - ts[i] = param - } - ts = append(ts, key) - return -} - -// modify the parameters's type to adapt the function input parameters' type -func parseParam(t reflect.Type, s string) (i interface{}, err error) { - switch t.Kind() { - case reflect.Int: - i, err = strconv.Atoi(s) - case reflect.Int64: - if wordsize == 32 { - return nil, ErrInt64On32 - } - i, err = strconv.ParseInt(s, 10, 64) - case reflect.Int32: - var v int64 - v, err = strconv.ParseInt(s, 10, 32) - if err == nil { - i = int32(v) - } - case reflect.Int16: - var v int64 - v, err = strconv.ParseInt(s, 10, 16) - if err == nil { - i = int16(v) - } - case reflect.Int8: - var v int64 - v, err = strconv.ParseInt(s, 10, 8) - if err == nil { - i = int8(v) - } - case reflect.String: - i = s - case reflect.Ptr: - if t.Elem().String() != "regexp.Regexp" { - err = fmt.Errorf("not support %s", t.Elem().String()) - return - } - i, err = regexp.Compile(s) - default: - err = fmt.Errorf("not support %s", t.Kind().String()) - } - return -} - -func mergeParam(v *Validation, obj interface{}, params []interface{}) []interface{} { - return append([]interface{}{v, obj}, params...) -} diff --git a/validation/util_test.go b/validation/util_test.go deleted file mode 100644 index 58ca38db..00000000 --- a/validation/util_test.go +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package validation - -import ( - "log" - "reflect" - "testing" -) - -type user struct { - ID int - Tag string `valid:"Maxx(aa)"` - Name string `valid:"Required;"` - Age int `valid:"Required; Range(1, 140)"` - match string `valid:"Required; Match(/^(test)?\\w*@(/test/);com$/);Max(2)"` -} - -func TestGetValidFuncs(t *testing.T) { - u := user{Name: "test", Age: 1} - tf := reflect.TypeOf(u) - var vfs []ValidFunc - var err error - - f, _ := tf.FieldByName("ID") - if vfs, err = getValidFuncs(f); err != nil { - t.Fatal(err) - } - if len(vfs) != 0 { - t.Fatal("should get none ValidFunc") - } - - f, _ = tf.FieldByName("Tag") - if _, err = getValidFuncs(f); err.Error() != "doesn't exists Maxx valid function" { - t.Fatal(err) - } - - f, _ = tf.FieldByName("Name") - if vfs, err = getValidFuncs(f); err != nil { - t.Fatal(err) - } - if len(vfs) != 1 { - t.Fatal("should get 1 ValidFunc") - } - if vfs[0].Name != "Required" && len(vfs[0].Params) != 0 { - t.Error("Required funcs should be got") - } - - f, _ = tf.FieldByName("Age") - if vfs, err = getValidFuncs(f); err != nil { - t.Fatal(err) - } - if len(vfs) != 2 { - t.Fatal("should get 2 ValidFunc") - } - if vfs[0].Name != "Required" && len(vfs[0].Params) != 0 { - t.Error("Required funcs should be got") - } - if vfs[1].Name != "Range" && len(vfs[1].Params) != 2 { - t.Error("Range funcs should be got") - } - - f, _ = tf.FieldByName("match") - if vfs, err = getValidFuncs(f); err != nil { - t.Fatal(err) - } - if len(vfs) != 3 { - t.Fatal("should get 3 ValidFunc but now is", len(vfs)) - } -} - -type User struct { - Name string `valid:"Required;MaxSize(5)" ` - Sex string `valid:"Required;" label:"sex_label"` - Age int `valid:"Required;Range(1, 140);" label:"age_label"` -} - -func TestValidation(t *testing.T) { - u := User{"man1238888456", "", 1140} - valid := Validation{} - b, err := valid.Valid(&u) - if err != nil { - // handle error - } - if !b { - // validation does not pass - // blabla... - for _, err := range valid.Errors { - log.Println(err.Key, err.Message) - } - if len(valid.Errors) != 3 { - t.Error("must be has 3 error") - } - } else { - t.Error("must be has 3 error") - } -} - -func TestCall(t *testing.T) { - u := user{Name: "test", Age: 180} - tf := reflect.TypeOf(u) - var vfs []ValidFunc - var err error - f, _ := tf.FieldByName("Age") - if vfs, err = getValidFuncs(f); err != nil { - t.Fatal(err) - } - valid := &Validation{} - vfs[1].Params = append([]interface{}{valid, u.Age}, vfs[1].Params...) - if _, err = funcs.Call(vfs[1].Name, vfs[1].Params...); err != nil { - t.Fatal(err) - } - if len(valid.Errors) != 1 { - t.Error("age out of range should be has an error") - } -} diff --git a/validation/validation.go b/validation/validation.go deleted file mode 100644 index 190e0f0e..00000000 --- a/validation/validation.go +++ /dev/null @@ -1,456 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package validation for validations -// -// import ( -// "github.com/astaxie/beego/validation" -// "log" -// ) -// -// type User struct { -// Name string -// Age int -// } -// -// func main() { -// u := User{"man", 40} -// valid := validation.Validation{} -// valid.Required(u.Name, "name") -// valid.MaxSize(u.Name, 15, "nameMax") -// valid.Range(u.Age, 0, 140, "age") -// if valid.HasErrors() { -// // validation does not pass -// // print invalid message -// for _, err := range valid.Errors { -// log.Println(err.Key, err.Message) -// } -// } -// // or use like this -// if v := valid.Max(u.Age, 140, "ageMax"); !v.Ok { -// log.Println(v.Error.Key, v.Error.Message) -// } -// } -// -// more info: http://beego.me/docs/mvc/controller/validation.md -package validation - -import ( - "fmt" - "reflect" - "regexp" - "strings" -) - -// ValidFormer valid interface -type ValidFormer interface { - Valid(*Validation) -} - -// Error show the error -type Error struct { - Message, Key, Name, Field, Tmpl string - Value interface{} - LimitValue interface{} -} - -// String Returns the Message. -func (e *Error) String() string { - if e == nil { - return "" - } - return e.Message -} - -// Implement Error interface. -// Return e.String() -func (e *Error) Error() string { return e.String() } - -// Result is returned from every validation method. -// It provides an indication of success, and a pointer to the Error (if any). -type Result struct { - Error *Error - Ok bool -} - -// Key Get Result by given key string. -func (r *Result) Key(key string) *Result { - if r.Error != nil { - r.Error.Key = key - } - return r -} - -// Message Set Result message by string or format string with args -func (r *Result) Message(message string, args ...interface{}) *Result { - if r.Error != nil { - if len(args) == 0 { - r.Error.Message = message - } else { - r.Error.Message = fmt.Sprintf(message, args...) - } - } - return r -} - -// A Validation context manages data validation and error messages. -type Validation struct { - // if this field set true, in struct tag valid - // if the struct field vale is empty - // it will skip those valid functions, see CanSkipFuncs - RequiredFirst bool - - Errors []*Error - ErrorsMap map[string][]*Error -} - -// Clear Clean all ValidationError. -func (v *Validation) Clear() { - v.Errors = []*Error{} - v.ErrorsMap = nil -} - -// HasErrors Has ValidationError nor not. -func (v *Validation) HasErrors() bool { - return len(v.Errors) > 0 -} - -// ErrorMap Return the errors mapped by key. -// If there are multiple validation errors associated with a single key, the -// first one "wins". (Typically the first validation will be the more basic). -func (v *Validation) ErrorMap() map[string][]*Error { - return v.ErrorsMap -} - -// Error Add an error to the validation context. -func (v *Validation) Error(message string, args ...interface{}) *Result { - result := (&Result{ - Ok: false, - Error: &Error{}, - }).Message(message, args...) - v.Errors = append(v.Errors, result.Error) - return result -} - -// Required Test that the argument is non-nil and non-empty (if string or list) -func (v *Validation) Required(obj interface{}, key string) *Result { - return v.apply(Required{key}, obj) -} - -// Min Test that the obj is greater than min if obj's type is int -func (v *Validation) Min(obj interface{}, min int, key string) *Result { - return v.apply(Min{min, key}, obj) -} - -// Max Test that the obj is less than max if obj's type is int -func (v *Validation) Max(obj interface{}, max int, key string) *Result { - return v.apply(Max{max, key}, obj) -} - -// Range Test that the obj is between mni and max if obj's type is int -func (v *Validation) Range(obj interface{}, min, max int, key string) *Result { - return v.apply(Range{Min{Min: min}, Max{Max: max}, key}, obj) -} - -// MinSize Test that the obj is longer than min size if type is string or slice -func (v *Validation) MinSize(obj interface{}, min int, key string) *Result { - return v.apply(MinSize{min, key}, obj) -} - -// MaxSize Test that the obj is shorter than max size if type is string or slice -func (v *Validation) MaxSize(obj interface{}, max int, key string) *Result { - return v.apply(MaxSize{max, key}, obj) -} - -// Length Test that the obj is same length to n if type is string or slice -func (v *Validation) Length(obj interface{}, n int, key string) *Result { - return v.apply(Length{n, key}, obj) -} - -// Alpha Test that the obj is [a-zA-Z] if type is string -func (v *Validation) Alpha(obj interface{}, key string) *Result { - return v.apply(Alpha{key}, obj) -} - -// Numeric Test that the obj is [0-9] if type is string -func (v *Validation) Numeric(obj interface{}, key string) *Result { - return v.apply(Numeric{key}, obj) -} - -// AlphaNumeric Test that the obj is [0-9a-zA-Z] if type is string -func (v *Validation) AlphaNumeric(obj interface{}, key string) *Result { - return v.apply(AlphaNumeric{key}, obj) -} - -// Match Test that the obj matches regexp if type is string -func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *Result { - return v.apply(Match{regex, key}, obj) -} - -// NoMatch Test that the obj doesn't match regexp if type is string -func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *Result { - return v.apply(NoMatch{Match{Regexp: regex}, key}, obj) -} - -// AlphaDash Test that the obj is [0-9a-zA-Z_-] if type is string -func (v *Validation) AlphaDash(obj interface{}, key string) *Result { - return v.apply(AlphaDash{NoMatch{Match: Match{Regexp: alphaDashPattern}}, key}, obj) -} - -// Email Test that the obj is email address if type is string -func (v *Validation) Email(obj interface{}, key string) *Result { - return v.apply(Email{Match{Regexp: emailPattern}, key}, obj) -} - -// IP Test that the obj is IP address if type is string -func (v *Validation) IP(obj interface{}, key string) *Result { - return v.apply(IP{Match{Regexp: ipPattern}, key}, obj) -} - -// Base64 Test that the obj is base64 encoded if type is string -func (v *Validation) Base64(obj interface{}, key string) *Result { - return v.apply(Base64{Match{Regexp: base64Pattern}, key}, obj) -} - -// Mobile Test that the obj is chinese mobile number if type is string -func (v *Validation) Mobile(obj interface{}, key string) *Result { - return v.apply(Mobile{Match{Regexp: mobilePattern}, key}, obj) -} - -// Tel Test that the obj is chinese telephone number if type is string -func (v *Validation) Tel(obj interface{}, key string) *Result { - return v.apply(Tel{Match{Regexp: telPattern}, key}, obj) -} - -// Phone Test that the obj is chinese mobile or telephone number if type is string -func (v *Validation) Phone(obj interface{}, key string) *Result { - return v.apply(Phone{Mobile{Match: Match{Regexp: mobilePattern}}, - Tel{Match: Match{Regexp: telPattern}}, key}, obj) -} - -// ZipCode Test that the obj is chinese zip code if type is string -func (v *Validation) ZipCode(obj interface{}, key string) *Result { - return v.apply(ZipCode{Match{Regexp: zipCodePattern}, key}, obj) -} - -func (v *Validation) apply(chk Validator, obj interface{}) *Result { - if nil == obj { - if chk.IsSatisfied(obj) { - return &Result{Ok: true} - } - } else if reflect.TypeOf(obj).Kind() == reflect.Ptr { - if reflect.ValueOf(obj).IsNil() { - if chk.IsSatisfied(nil) { - return &Result{Ok: true} - } - } else { - if chk.IsSatisfied(reflect.ValueOf(obj).Elem().Interface()) { - return &Result{Ok: true} - } - } - } else if chk.IsSatisfied(obj) { - return &Result{Ok: true} - } - - // Add the error to the validation context. - key := chk.GetKey() - Name := key - Field := "" - Label := "" - parts := strings.Split(key, ".") - if len(parts) == 3 { - Field = parts[0] - Name = parts[1] - Label = parts[2] - if len(Label) == 0 { - Label = Field - } - } - - err := &Error{ - Message: Label + " " + chk.DefaultMessage(), - Key: key, - Name: Name, - Field: Field, - Value: obj, - Tmpl: MessageTmpls[Name], - LimitValue: chk.GetLimitValue(), - } - v.setError(err) - - // Also return it in the result. - return &Result{ - Ok: false, - Error: err, - } -} - -// key must like aa.bb.cc or aa.bb. -// AddError adds independent error message for the provided key -func (v *Validation) AddError(key, message string) { - Name := key - Field := "" - - Label := "" - parts := strings.Split(key, ".") - if len(parts) == 3 { - Field = parts[0] - Name = parts[1] - Label = parts[2] - if len(Label) == 0 { - Label = Field - } - } - - err := &Error{ - Message: Label + " " + message, - Key: key, - Name: Name, - Field: Field, - } - v.setError(err) -} - -func (v *Validation) setError(err *Error) { - v.Errors = append(v.Errors, err) - if v.ErrorsMap == nil { - v.ErrorsMap = make(map[string][]*Error) - } - if _, ok := v.ErrorsMap[err.Field]; !ok { - v.ErrorsMap[err.Field] = []*Error{} - } - v.ErrorsMap[err.Field] = append(v.ErrorsMap[err.Field], err) -} - -// SetError Set error message for one field in ValidationError -func (v *Validation) SetError(fieldName string, errMsg string) *Error { - err := &Error{Key: fieldName, Field: fieldName, Tmpl: errMsg, Message: errMsg} - v.setError(err) - return err -} - -// Check Apply a group of validators to a field, in order, and return the -// ValidationResult from the first one that fails, or the last one that -// succeeds. -func (v *Validation) Check(obj interface{}, checks ...Validator) *Result { - var result *Result - for _, check := range checks { - result = v.apply(check, obj) - if !result.Ok { - return result - } - } - return result -} - -// Valid Validate a struct. -// the obj parameter must be a struct or a struct pointer -func (v *Validation) Valid(obj interface{}) (b bool, err error) { - objT := reflect.TypeOf(obj) - objV := reflect.ValueOf(obj) - switch { - case isStruct(objT): - case isStructPtr(objT): - objT = objT.Elem() - objV = objV.Elem() - default: - err = fmt.Errorf("%v must be a struct or a struct pointer", obj) - return - } - - for i := 0; i < objT.NumField(); i++ { - var vfs []ValidFunc - if vfs, err = getValidFuncs(objT.Field(i)); err != nil { - return - } - - var hasRequired bool - for _, vf := range vfs { - if vf.Name == "Required" { - hasRequired = true - } - - currentField := objV.Field(i).Interface() - if objV.Field(i).Kind() == reflect.Ptr { - if objV.Field(i).IsNil() { - currentField = "" - } else { - currentField = objV.Field(i).Elem().Interface() - } - } - - chk := Required{""}.IsSatisfied(currentField) - if !hasRequired && v.RequiredFirst && !chk { - if _, ok := CanSkipFuncs[vf.Name]; ok { - continue - } - } - - if _, err = funcs.Call(vf.Name, - mergeParam(v, objV.Field(i).Interface(), vf.Params)...); err != nil { - return - } - } - } - - if !v.HasErrors() { - if form, ok := obj.(ValidFormer); ok { - form.Valid(v) - } - } - - return !v.HasErrors(), nil -} - -// RecursiveValid Recursively validate a struct. -// Step1: Validate by v.Valid -// Step2: If pass on step1, then reflect obj's fields -// Step3: Do the Recursively validation to all struct or struct pointer fields -func (v *Validation) RecursiveValid(objc interface{}) (bool, error) { - //Step 1: validate obj itself firstly - // fails if objc is not struct - pass, err := v.Valid(objc) - if err != nil || !pass { - return pass, err // Stop recursive validation - } - // Step 2: Validate struct's struct fields - objT := reflect.TypeOf(objc) - objV := reflect.ValueOf(objc) - - if isStructPtr(objT) { - objT = objT.Elem() - objV = objV.Elem() - } - - for i := 0; i < objT.NumField(); i++ { - - t := objT.Field(i).Type - - // Recursive applies to struct or pointer to structs fields - if isStruct(t) || isStructPtr(t) { - // Step 3: do the recursive validation - // Only valid the Public field recursively - if objV.Field(i).CanInterface() { - pass, err = v.RecursiveValid(objV.Field(i).Interface()) - } - } - } - return pass, err -} - -func (v *Validation) CanSkipAlso(skipFunc string) { - if _, ok := CanSkipFuncs[skipFunc]; !ok { - CanSkipFuncs[skipFunc] = struct{}{} - } -} diff --git a/validation/validation_test.go b/validation/validation_test.go deleted file mode 100644 index b4b5b1b6..00000000 --- a/validation/validation_test.go +++ /dev/null @@ -1,609 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package validation - -import ( - "regexp" - "testing" - "time" -) - -func TestRequired(t *testing.T) { - valid := Validation{} - - if valid.Required(nil, "nil").Ok { - t.Error("nil object should be false") - } - if !valid.Required(true, "bool").Ok { - t.Error("Bool value should always return true") - } - if !valid.Required(false, "bool").Ok { - t.Error("Bool value should always return true") - } - if valid.Required("", "string").Ok { - t.Error("\"'\" string should be false") - } - if valid.Required(" ", "string").Ok { - t.Error("\" \" string should be false") // For #2361 - } - if valid.Required("\n", "string").Ok { - t.Error("new line string should be false") // For #2361 - } - if !valid.Required("astaxie", "string").Ok { - t.Error("string should be true") - } - if valid.Required(0, "zero").Ok { - t.Error("Integer should not be equal 0") - } - if !valid.Required(1, "int").Ok { - t.Error("Integer except 0 should be true") - } - if !valid.Required(time.Now(), "time").Ok { - t.Error("time should be true") - } - if valid.Required([]string{}, "emptySlice").Ok { - t.Error("empty slice should be false") - } - if !valid.Required([]interface{}{"ok"}, "slice").Ok { - t.Error("slice should be true") - } -} - -func TestMin(t *testing.T) { - valid := Validation{} - - if valid.Min(-1, 0, "min0").Ok { - t.Error("-1 is less than the minimum value of 0 should be false") - } - if !valid.Min(1, 0, "min0").Ok { - t.Error("1 is greater or equal than the minimum value of 0 should be true") - } -} - -func TestMax(t *testing.T) { - valid := Validation{} - - if valid.Max(1, 0, "max0").Ok { - t.Error("1 is greater than the minimum value of 0 should be false") - } - if !valid.Max(-1, 0, "max0").Ok { - t.Error("-1 is less or equal than the maximum value of 0 should be true") - } -} - -func TestRange(t *testing.T) { - valid := Validation{} - - if valid.Range(-1, 0, 1, "range0_1").Ok { - t.Error("-1 is between 0 and 1 should be false") - } - if !valid.Range(1, 0, 1, "range0_1").Ok { - t.Error("1 is between 0 and 1 should be true") - } -} - -func TestMinSize(t *testing.T) { - valid := Validation{} - - if valid.MinSize("", 1, "minSize1").Ok { - t.Error("the length of \"\" is less than the minimum value of 1 should be false") - } - if !valid.MinSize("ok", 1, "minSize1").Ok { - t.Error("the length of \"ok\" is greater or equal than the minimum value of 1 should be true") - } - if valid.MinSize([]string{}, 1, "minSize1").Ok { - t.Error("the length of empty slice is less than the minimum value of 1 should be false") - } - if !valid.MinSize([]interface{}{"ok"}, 1, "minSize1").Ok { - t.Error("the length of [\"ok\"] is greater or equal than the minimum value of 1 should be true") - } -} - -func TestMaxSize(t *testing.T) { - valid := Validation{} - - if valid.MaxSize("ok", 1, "maxSize1").Ok { - t.Error("the length of \"ok\" is greater than the maximum value of 1 should be false") - } - if !valid.MaxSize("", 1, "maxSize1").Ok { - t.Error("the length of \"\" is less or equal than the maximum value of 1 should be true") - } - if valid.MaxSize([]interface{}{"ok", false}, 1, "maxSize1").Ok { - t.Error("the length of [\"ok\", false] is greater than the maximum value of 1 should be false") - } - if !valid.MaxSize([]string{}, 1, "maxSize1").Ok { - t.Error("the length of empty slice is less or equal than the maximum value of 1 should be true") - } -} - -func TestLength(t *testing.T) { - valid := Validation{} - - if valid.Length("", 1, "length1").Ok { - t.Error("the length of \"\" must equal 1 should be false") - } - if !valid.Length("1", 1, "length1").Ok { - t.Error("the length of \"1\" must equal 1 should be true") - } - if valid.Length([]string{}, 1, "length1").Ok { - t.Error("the length of empty slice must equal 1 should be false") - } - if !valid.Length([]interface{}{"ok"}, 1, "length1").Ok { - t.Error("the length of [\"ok\"] must equal 1 should be true") - } -} - -func TestAlpha(t *testing.T) { - valid := Validation{} - - if valid.Alpha("a,1-@ $", "alpha").Ok { - t.Error("\"a,1-@ $\" are valid alpha characters should be false") - } - if !valid.Alpha("abCD", "alpha").Ok { - t.Error("\"abCD\" are valid alpha characters should be true") - } -} - -func TestNumeric(t *testing.T) { - valid := Validation{} - - if valid.Numeric("a,1-@ $", "numeric").Ok { - t.Error("\"a,1-@ $\" are valid numeric characters should be false") - } - if !valid.Numeric("1234", "numeric").Ok { - t.Error("\"1234\" are valid numeric characters should be true") - } -} - -func TestAlphaNumeric(t *testing.T) { - valid := Validation{} - - if valid.AlphaNumeric("a,1-@ $", "alphaNumeric").Ok { - t.Error("\"a,1-@ $\" are valid alpha or numeric characters should be false") - } - if !valid.AlphaNumeric("1234aB", "alphaNumeric").Ok { - t.Error("\"1234aB\" are valid alpha or numeric characters should be true") - } -} - -func TestMatch(t *testing.T) { - valid := Validation{} - - if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { - t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false") - } - if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { - t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true") - } -} - -func TestNoMatch(t *testing.T) { - valid := Validation{} - - if valid.NoMatch("123@gmail", regexp.MustCompile(`[^\w\d]`), "nomatch").Ok { - t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false") - } - if !valid.NoMatch("123gmail", regexp.MustCompile(`[^\w\d]`), "match").Ok { - t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true") - } -} - -func TestAlphaDash(t *testing.T) { - valid := Validation{} - - if valid.AlphaDash("a,1-@ $", "alphaDash").Ok { - t.Error("\"a,1-@ $\" are valid alpha or numeric or dash(-_) characters should be false") - } - if !valid.AlphaDash("1234aB-_", "alphaDash").Ok { - t.Error("\"1234aB\" are valid alpha or numeric or dash(-_) characters should be true") - } -} - -func TestEmail(t *testing.T) { - valid := Validation{} - - if valid.Email("not@a email", "email").Ok { - t.Error("\"not@a email\" is a valid email address should be false") - } - if !valid.Email("suchuangji@gmail.com", "email").Ok { - t.Error("\"suchuangji@gmail.com\" is a valid email address should be true") - } - if valid.Email("@suchuangji@gmail.com", "email").Ok { - t.Error("\"@suchuangji@gmail.com\" is a valid email address should be false") - } - if valid.Email("suchuangji@gmail.com ok", "email").Ok { - t.Error("\"suchuangji@gmail.com ok\" is a valid email address should be false") - } -} - -func TestIP(t *testing.T) { - valid := Validation{} - - if valid.IP("11.255.255.256", "IP").Ok { - t.Error("\"11.255.255.256\" is a valid ip address should be false") - } - if !valid.IP("01.11.11.11", "IP").Ok { - t.Error("\"suchuangji@gmail.com\" is a valid ip address should be true") - } -} - -func TestBase64(t *testing.T) { - valid := Validation{} - - if valid.Base64("suchuangji@gmail.com", "base64").Ok { - t.Error("\"suchuangji@gmail.com\" are a valid base64 characters should be false") - } - if !valid.Base64("c3VjaHVhbmdqaUBnbWFpbC5jb20=", "base64").Ok { - t.Error("\"c3VjaHVhbmdqaUBnbWFpbC5jb20=\" are a valid base64 characters should be true") - } -} - -func TestMobile(t *testing.T) { - valid := Validation{} - - validMobiles := []string{ - "19800008888", - "18800008888", - "18000008888", - "8618300008888", - "+8614700008888", - "17300008888", - "+8617100008888", - "8617500008888", - "8617400008888", - "16200008888", - "16500008888", - "16600008888", - "16700008888", - "13300008888", - "14900008888", - "15300008888", - "17300008888", - "17700008888", - "18000008888", - "18900008888", - "19100008888", - "19900008888", - "19300008888", - "13000008888", - "13100008888", - "13200008888", - "14500008888", - "15500008888", - "15600008888", - "16600008888", - "17100008888", - "17500008888", - "17600008888", - "18500008888", - "18600008888", - "13400008888", - "13500008888", - "13600008888", - "13700008888", - "13800008888", - "13900008888", - "14700008888", - "15000008888", - "15100008888", - "15200008888", - "15800008888", - "15900008888", - "17200008888", - "17800008888", - "18200008888", - "18300008888", - "18400008888", - "18700008888", - "18800008888", - "19800008888", - } - - for _, m := range validMobiles { - if !valid.Mobile(m, "mobile").Ok { - t.Error(m + " is a valid mobile phone number should be true") - } - } -} - -func TestTel(t *testing.T) { - valid := Validation{} - - if valid.Tel("222-00008888", "telephone").Ok { - t.Error("\"222-00008888\" is a valid telephone number should be false") - } - if !valid.Tel("022-70008888", "telephone").Ok { - t.Error("\"022-70008888\" is a valid telephone number should be true") - } - if !valid.Tel("02270008888", "telephone").Ok { - t.Error("\"02270008888\" is a valid telephone number should be true") - } - if !valid.Tel("70008888", "telephone").Ok { - t.Error("\"70008888\" is a valid telephone number should be true") - } -} - -func TestPhone(t *testing.T) { - valid := Validation{} - - if valid.Phone("222-00008888", "phone").Ok { - t.Error("\"222-00008888\" is a valid phone number should be false") - } - if !valid.Mobile("+8614700008888", "phone").Ok { - t.Error("\"+8614700008888\" is a valid phone number should be true") - } - if !valid.Tel("02270008888", "phone").Ok { - t.Error("\"02270008888\" is a valid phone number should be true") - } -} - -func TestZipCode(t *testing.T) { - valid := Validation{} - - if valid.ZipCode("", "zipcode").Ok { - t.Error("\"00008888\" is a valid zipcode should be false") - } - if !valid.ZipCode("536000", "zipcode").Ok { - t.Error("\"536000\" is a valid zipcode should be true") - } -} - -func TestValid(t *testing.T) { - type user struct { - ID int - Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` - Age int `valid:"Required;Range(1, 140)"` - } - valid := Validation{} - - u := user{Name: "test@/test/;com", Age: 40} - b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Error("validation should be passed") - } - - uptr := &user{Name: "test", Age: 40} - valid.Clear() - b, err = valid.Valid(uptr) - if err != nil { - t.Fatal(err) - } - if b { - t.Error("validation should not be passed") - } - if len(valid.Errors) != 1 { - t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) - } - if valid.Errors[0].Key != "Name.Match" { - t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key) - } - - u = user{Name: "test@/test/;com", Age: 180} - valid.Clear() - b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Error("validation should not be passed") - } - if len(valid.Errors) != 1 { - t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) - } - if valid.Errors[0].Key != "Age.Range." { - t.Errorf("Message key should be `Age.Range` but got %s", valid.Errors[0].Key) - } -} - -func TestRecursiveValid(t *testing.T) { - type User struct { - ID int - Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` - Age int `valid:"Required;Range(1, 140)"` - } - - type AnonymouseUser struct { - ID2 int - Name2 string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` - Age2 int `valid:"Required;Range(1, 140)"` - } - - type Account struct { - Password string `valid:"Required"` - U User - AnonymouseUser - } - valid := Validation{} - - u := Account{Password: "abc123_", U: User{}} - b, err := valid.RecursiveValid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Error("validation should not be passed") - } -} - -func TestSkipValid(t *testing.T) { - type User struct { - ID int - - Email string `valid:"Email"` - ReqEmail string `valid:"Required;Email"` - - IP string `valid:"IP"` - ReqIP string `valid:"Required;IP"` - - Mobile string `valid:"Mobile"` - ReqMobile string `valid:"Required;Mobile"` - - Tel string `valid:"Tel"` - ReqTel string `valid:"Required;Tel"` - - Phone string `valid:"Phone"` - ReqPhone string `valid:"Required;Phone"` - - ZipCode string `valid:"ZipCode"` - ReqZipCode string `valid:"Required;ZipCode"` - } - - u := User{ - ReqEmail: "a@a.com", - ReqIP: "127.0.0.1", - ReqMobile: "18888888888", - ReqTel: "02088888888", - ReqPhone: "02088888888", - ReqZipCode: "510000", - } - - valid := Validation{} - b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } - - valid = Validation{RequiredFirst: true} - b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Fatal("validation should be passed") - } -} - -func TestPointer(t *testing.T) { - type User struct { - ID int - - Email *string `valid:"Email"` - ReqEmail *string `valid:"Required;Email"` - } - - u := User{ - ReqEmail: nil, - Email: nil, - } - - valid := Validation{} - b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } - - validEmail := "a@a.com" - u = User{ - ReqEmail: &validEmail, - Email: nil, - } - - valid = Validation{RequiredFirst: true} - b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Fatal("validation should be passed") - } - - u = User{ - ReqEmail: &validEmail, - Email: nil, - } - - valid = Validation{} - b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } - - invalidEmail := "a@a" - u = User{ - ReqEmail: &validEmail, - Email: &invalidEmail, - } - - valid = Validation{RequiredFirst: true} - b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } - - u = User{ - ReqEmail: &validEmail, - Email: &invalidEmail, - } - - valid = Validation{} - b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } -} - -func TestCanSkipAlso(t *testing.T) { - type User struct { - ID int - - Email string `valid:"Email"` - ReqEmail string `valid:"Required;Email"` - MatchRange int `valid:"Range(10, 20)"` - } - - u := User{ - ReqEmail: "a@a.com", - Email: "", - MatchRange: 0, - } - - valid := Validation{RequiredFirst: true} - b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } - - valid = Validation{RequiredFirst: true} - valid.CanSkipAlso("Range") - b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Fatal("validation should be passed") - } - -} diff --git a/validation/validators.go b/validation/validators.go deleted file mode 100644 index 38b6f1aa..00000000 --- a/validation/validators.go +++ /dev/null @@ -1,738 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package validation - -import ( - "fmt" - "github.com/astaxie/beego/logs" - "reflect" - "regexp" - "strings" - "sync" - "time" - "unicode/utf8" -) - -// CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty -var CanSkipFuncs = map[string]struct{}{ - "Email": {}, - "IP": {}, - "Mobile": {}, - "Tel": {}, - "Phone": {}, - "ZipCode": {}, -} - -// MessageTmpls store commond validate template -var MessageTmpls = map[string]string{ - "Required": "Can not be empty", - "Min": "Minimum is %d", - "Max": "Maximum is %d", - "Range": "Range is %d to %d", - "MinSize": "Minimum size is %d", - "MaxSize": "Maximum size is %d", - "Length": "Required length is %d", - "Alpha": "Must be valid alpha characters", - "Numeric": "Must be valid numeric characters", - "AlphaNumeric": "Must be valid alpha or numeric characters", - "Match": "Must match %s", - "NoMatch": "Must not match %s", - "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", - "Email": "Must be a valid email address", - "IP": "Must be a valid ip address", - "Base64": "Must be valid base64 characters", - "Mobile": "Must be valid mobile number", - "Tel": "Must be valid telephone number", - "Phone": "Must be valid telephone or mobile phone number", - "ZipCode": "Must be valid zipcode", -} - -var once sync.Once - -// SetDefaultMessage set default messages -// if not set, the default messages are -// "Required": "Can not be empty", -// "Min": "Minimum is %d", -// "Max": "Maximum is %d", -// "Range": "Range is %d to %d", -// "MinSize": "Minimum size is %d", -// "MaxSize": "Maximum size is %d", -// "Length": "Required length is %d", -// "Alpha": "Must be valid alpha characters", -// "Numeric": "Must be valid numeric characters", -// "AlphaNumeric": "Must be valid alpha or numeric characters", -// "Match": "Must match %s", -// "NoMatch": "Must not match %s", -// "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", -// "Email": "Must be a valid email address", -// "IP": "Must be a valid ip address", -// "Base64": "Must be valid base64 characters", -// "Mobile": "Must be valid mobile number", -// "Tel": "Must be valid telephone number", -// "Phone": "Must be valid telephone or mobile phone number", -// "ZipCode": "Must be valid zipcode", -func SetDefaultMessage(msg map[string]string) { - if len(msg) == 0 { - return - } - - once.Do(func() { - for name := range msg { - MessageTmpls[name] = msg[name] - } - }) - logs.Warn(`you must SetDefaultMessage at once`) -} - -// Validator interface -type Validator interface { - IsSatisfied(interface{}) bool - DefaultMessage() string - GetKey() string - GetLimitValue() interface{} -} - -// Required struct -type Required struct { - Key string -} - -// IsSatisfied judge whether obj has value -func (r Required) IsSatisfied(obj interface{}) bool { - if obj == nil { - return false - } - - if str, ok := obj.(string); ok { - return len(strings.TrimSpace(str)) > 0 - } - if _, ok := obj.(bool); ok { - return true - } - if i, ok := obj.(int); ok { - return i != 0 - } - if i, ok := obj.(uint); ok { - return i != 0 - } - if i, ok := obj.(int8); ok { - return i != 0 - } - if i, ok := obj.(uint8); ok { - return i != 0 - } - if i, ok := obj.(int16); ok { - return i != 0 - } - if i, ok := obj.(uint16); ok { - return i != 0 - } - if i, ok := obj.(uint32); ok { - return i != 0 - } - if i, ok := obj.(int32); ok { - return i != 0 - } - if i, ok := obj.(int64); ok { - return i != 0 - } - if i, ok := obj.(uint64); ok { - return i != 0 - } - if t, ok := obj.(time.Time); ok { - return !t.IsZero() - } - v := reflect.ValueOf(obj) - if v.Kind() == reflect.Slice { - return v.Len() > 0 - } - return true -} - -// DefaultMessage return the default error message -func (r Required) DefaultMessage() string { - return MessageTmpls["Required"] -} - -// GetKey return the r.Key -func (r Required) GetKey() string { - return r.Key -} - -// GetLimitValue return nil now -func (r Required) GetLimitValue() interface{} { - return nil -} - -// Min check struct -type Min struct { - Min int - Key string -} - -// IsSatisfied judge whether obj is valid -// not support int64 on 32-bit platform -func (m Min) IsSatisfied(obj interface{}) bool { - var v int - switch obj.(type) { - case int64: - if wordsize == 32 { - return false - } - v = int(obj.(int64)) - case int: - v = obj.(int) - case int32: - v = int(obj.(int32)) - case int16: - v = int(obj.(int16)) - case int8: - v = int(obj.(int8)) - default: - return false - } - - return v >= m.Min -} - -// DefaultMessage return the default min error message -func (m Min) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["Min"], m.Min) -} - -// GetKey return the m.Key -func (m Min) GetKey() string { - return m.Key -} - -// GetLimitValue return the limit value, Min -func (m Min) GetLimitValue() interface{} { - return m.Min -} - -// Max validate struct -type Max struct { - Max int - Key string -} - -// IsSatisfied judge whether obj is valid -// not support int64 on 32-bit platform -func (m Max) IsSatisfied(obj interface{}) bool { - var v int - switch obj.(type) { - case int64: - if wordsize == 32 { - return false - } - v = int(obj.(int64)) - case int: - v = obj.(int) - case int32: - v = int(obj.(int32)) - case int16: - v = int(obj.(int16)) - case int8: - v = int(obj.(int8)) - default: - return false - } - - return v <= m.Max -} - -// DefaultMessage return the default max error message -func (m Max) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["Max"], m.Max) -} - -// GetKey return the m.Key -func (m Max) GetKey() string { - return m.Key -} - -// GetLimitValue return the limit value, Max -func (m Max) GetLimitValue() interface{} { - return m.Max -} - -// Range Requires an integer to be within Min, Max inclusive. -type Range struct { - Min - Max - Key string -} - -// IsSatisfied judge whether obj is valid -// not support int64 on 32-bit platform -func (r Range) IsSatisfied(obj interface{}) bool { - return r.Min.IsSatisfied(obj) && r.Max.IsSatisfied(obj) -} - -// DefaultMessage return the default Range error message -func (r Range) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["Range"], r.Min.Min, r.Max.Max) -} - -// GetKey return the m.Key -func (r Range) GetKey() string { - return r.Key -} - -// GetLimitValue return the limit value, Max -func (r Range) GetLimitValue() interface{} { - return []int{r.Min.Min, r.Max.Max} -} - -// MinSize Requires an array or string to be at least a given length. -type MinSize struct { - Min int - Key string -} - -// IsSatisfied judge whether obj is valid -func (m MinSize) IsSatisfied(obj interface{}) bool { - if str, ok := obj.(string); ok { - return utf8.RuneCountInString(str) >= m.Min - } - v := reflect.ValueOf(obj) - if v.Kind() == reflect.Slice { - return v.Len() >= m.Min - } - return false -} - -// DefaultMessage return the default MinSize error message -func (m MinSize) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["MinSize"], m.Min) -} - -// GetKey return the m.Key -func (m MinSize) GetKey() string { - return m.Key -} - -// GetLimitValue return the limit value -func (m MinSize) GetLimitValue() interface{} { - return m.Min -} - -// MaxSize Requires an array or string to be at most a given length. -type MaxSize struct { - Max int - Key string -} - -// IsSatisfied judge whether obj is valid -func (m MaxSize) IsSatisfied(obj interface{}) bool { - if str, ok := obj.(string); ok { - return utf8.RuneCountInString(str) <= m.Max - } - v := reflect.ValueOf(obj) - if v.Kind() == reflect.Slice { - return v.Len() <= m.Max - } - return false -} - -// DefaultMessage return the default MaxSize error message -func (m MaxSize) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["MaxSize"], m.Max) -} - -// GetKey return the m.Key -func (m MaxSize) GetKey() string { - return m.Key -} - -// GetLimitValue return the limit value -func (m MaxSize) GetLimitValue() interface{} { - return m.Max -} - -// Length Requires an array or string to be exactly a given length. -type Length struct { - N int - Key string -} - -// IsSatisfied judge whether obj is valid -func (l Length) IsSatisfied(obj interface{}) bool { - if str, ok := obj.(string); ok { - return utf8.RuneCountInString(str) == l.N - } - v := reflect.ValueOf(obj) - if v.Kind() == reflect.Slice { - return v.Len() == l.N - } - return false -} - -// DefaultMessage return the default Length error message -func (l Length) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["Length"], l.N) -} - -// GetKey return the m.Key -func (l Length) GetKey() string { - return l.Key -} - -// GetLimitValue return the limit value -func (l Length) GetLimitValue() interface{} { - return l.N -} - -// Alpha check the alpha -type Alpha struct { - Key string -} - -// IsSatisfied judge whether obj is valid -func (a Alpha) IsSatisfied(obj interface{}) bool { - if str, ok := obj.(string); ok { - for _, v := range str { - if ('Z' < v || v < 'A') && ('z' < v || v < 'a') { - return false - } - } - return true - } - return false -} - -// DefaultMessage return the default Length error message -func (a Alpha) DefaultMessage() string { - return MessageTmpls["Alpha"] -} - -// GetKey return the m.Key -func (a Alpha) GetKey() string { - return a.Key -} - -// GetLimitValue return the limit value -func (a Alpha) GetLimitValue() interface{} { - return nil -} - -// Numeric check number -type Numeric struct { - Key string -} - -// IsSatisfied judge whether obj is valid -func (n Numeric) IsSatisfied(obj interface{}) bool { - if str, ok := obj.(string); ok { - for _, v := range str { - if '9' < v || v < '0' { - return false - } - } - return true - } - return false -} - -// DefaultMessage return the default Length error message -func (n Numeric) DefaultMessage() string { - return MessageTmpls["Numeric"] -} - -// GetKey return the n.Key -func (n Numeric) GetKey() string { - return n.Key -} - -// GetLimitValue return the limit value -func (n Numeric) GetLimitValue() interface{} { - return nil -} - -// AlphaNumeric check alpha and number -type AlphaNumeric struct { - Key string -} - -// IsSatisfied judge whether obj is valid -func (a AlphaNumeric) IsSatisfied(obj interface{}) bool { - if str, ok := obj.(string); ok { - for _, v := range str { - if ('Z' < v || v < 'A') && ('z' < v || v < 'a') && ('9' < v || v < '0') { - return false - } - } - return true - } - return false -} - -// DefaultMessage return the default Length error message -func (a AlphaNumeric) DefaultMessage() string { - return MessageTmpls["AlphaNumeric"] -} - -// GetKey return the a.Key -func (a AlphaNumeric) GetKey() string { - return a.Key -} - -// GetLimitValue return the limit value -func (a AlphaNumeric) GetLimitValue() interface{} { - return nil -} - -// Match Requires a string to match a given regex. -type Match struct { - Regexp *regexp.Regexp - Key string -} - -// IsSatisfied judge whether obj is valid -func (m Match) IsSatisfied(obj interface{}) bool { - return m.Regexp.MatchString(fmt.Sprintf("%v", obj)) -} - -// DefaultMessage return the default Match error message -func (m Match) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["Match"], m.Regexp.String()) -} - -// GetKey return the m.Key -func (m Match) GetKey() string { - return m.Key -} - -// GetLimitValue return the limit value -func (m Match) GetLimitValue() interface{} { - return m.Regexp.String() -} - -// NoMatch Requires a string to not match a given regex. -type NoMatch struct { - Match - Key string -} - -// IsSatisfied judge whether obj is valid -func (n NoMatch) IsSatisfied(obj interface{}) bool { - return !n.Match.IsSatisfied(obj) -} - -// DefaultMessage return the default NoMatch error message -func (n NoMatch) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["NoMatch"], n.Regexp.String()) -} - -// GetKey return the n.Key -func (n NoMatch) GetKey() string { - return n.Key -} - -// GetLimitValue return the limit value -func (n NoMatch) GetLimitValue() interface{} { - return n.Regexp.String() -} - -var alphaDashPattern = regexp.MustCompile(`[^\d\w-_]`) - -// AlphaDash check not Alpha -type AlphaDash struct { - NoMatch - Key string -} - -// DefaultMessage return the default AlphaDash error message -func (a AlphaDash) DefaultMessage() string { - return MessageTmpls["AlphaDash"] -} - -// GetKey return the n.Key -func (a AlphaDash) GetKey() string { - return a.Key -} - -// GetLimitValue return the limit value -func (a AlphaDash) GetLimitValue() interface{} { - return nil -} - -var emailPattern = regexp.MustCompile(`^[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+(?:\.[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+)*@(?:[\w](?:[\w-]*[\w])?\.)+[a-zA-Z0-9](?:[\w-]*[\w])?$`) - -// Email check struct -type Email struct { - Match - Key string -} - -// DefaultMessage return the default Email error message -func (e Email) DefaultMessage() string { - return MessageTmpls["Email"] -} - -// GetKey return the n.Key -func (e Email) GetKey() string { - return e.Key -} - -// GetLimitValue return the limit value -func (e Email) GetLimitValue() interface{} { - return nil -} - -var ipPattern = regexp.MustCompile(`^((2[0-4]\d|25[0-5]|[01]?\d\d?)\.){3}(2[0-4]\d|25[0-5]|[01]?\d\d?)$`) - -// IP check struct -type IP struct { - Match - Key string -} - -// DefaultMessage return the default IP error message -func (i IP) DefaultMessage() string { - return MessageTmpls["IP"] -} - -// GetKey return the i.Key -func (i IP) GetKey() string { - return i.Key -} - -// GetLimitValue return the limit value -func (i IP) GetLimitValue() interface{} { - return nil -} - -var base64Pattern = regexp.MustCompile(`^(?:[A-Za-z0-99+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$`) - -// Base64 check struct -type Base64 struct { - Match - Key string -} - -// DefaultMessage return the default Base64 error message -func (b Base64) DefaultMessage() string { - return MessageTmpls["Base64"] -} - -// GetKey return the b.Key -func (b Base64) GetKey() string { - return b.Key -} - -// GetLimitValue return the limit value -func (b Base64) GetLimitValue() interface{} { - return nil -} - -// just for chinese mobile phone number -var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?1([356789][0-9]|4[579]|6[67]|7[0135678]|9[189])[0-9]{8}$`) - -// Mobile check struct -type Mobile struct { - Match - Key string -} - -// DefaultMessage return the default Mobile error message -func (m Mobile) DefaultMessage() string { - return MessageTmpls["Mobile"] -} - -// GetKey return the m.Key -func (m Mobile) GetKey() string { - return m.Key -} - -// GetLimitValue return the limit value -func (m Mobile) GetLimitValue() interface{} { - return nil -} - -// just for chinese telephone number -var telPattern = regexp.MustCompile(`^(0\d{2,3}(\-)?)?\d{7,8}$`) - -// Tel check telephone struct -type Tel struct { - Match - Key string -} - -// DefaultMessage return the default Tel error message -func (t Tel) DefaultMessage() string { - return MessageTmpls["Tel"] -} - -// GetKey return the t.Key -func (t Tel) GetKey() string { - return t.Key -} - -// GetLimitValue return the limit value -func (t Tel) GetLimitValue() interface{} { - return nil -} - -// Phone just for chinese telephone or mobile phone number -type Phone struct { - Mobile - Tel - Key string -} - -// IsSatisfied judge whether obj is valid -func (p Phone) IsSatisfied(obj interface{}) bool { - return p.Mobile.IsSatisfied(obj) || p.Tel.IsSatisfied(obj) -} - -// DefaultMessage return the default Phone error message -func (p Phone) DefaultMessage() string { - return MessageTmpls["Phone"] -} - -// GetKey return the p.Key -func (p Phone) GetKey() string { - return p.Key -} - -// GetLimitValue return the limit value -func (p Phone) GetLimitValue() interface{} { - return nil -} - -// just for chinese zipcode -var zipCodePattern = regexp.MustCompile(`^[1-9]\d{5}$`) - -// ZipCode check the zip struct -type ZipCode struct { - Match - Key string -} - -// DefaultMessage return the default Zip error message -func (z ZipCode) DefaultMessage() string { - return MessageTmpls["ZipCode"] -} - -// GetKey return the z.Key -func (z ZipCode) GetKey() string { - return z.Key -} - -// GetLimitValue return the limit value -func (z ZipCode) GetLimitValue() interface{} { - return nil -} From 4db256c9fbbc1ae44a49651e0e205876650ede0b Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 18 Aug 2020 20:55:11 +0800 Subject: [PATCH 099/207] Add git hooks --- CONTRIBUTING.md | 16 ++++++++++++++++ scripts/gobuild.sh => build/gobuild-sample.sh | 0 {scripts => build}/report_build_info.sh | 0 githook/pre-commit | 7 +++++++ 4 files changed, 23 insertions(+) rename scripts/gobuild.sh => build/gobuild-sample.sh (100%) rename {scripts => build}/report_build_info.sh (100%) create mode 100755 githook/pre-commit diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 77adfb65..ee7e0b5a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -7,6 +7,22 @@ It is the work of hundreds of contributors. We appreciate your help! Here are instructions to get you started. They are probably not perfect, please let us know if anything feels wrong or incomplete. +## Prepare environment + +Firstly, install some tools. Execute those commands **outside** the project. Or those command will modify go.mod file. + +```shell script +go get -u golang.org/x/tools/cmd/goimports + +go get -u github.com/gordonklaus/ineffassign +``` + +And the go into project directory, run : +```shell script +cp ./githook/pre-commit ./.git/hooks/pre-commit +``` +This will add git hooks into .git/hooks. Or you can add it manually. + ## Contribution guidelines ### Pull requests diff --git a/scripts/gobuild.sh b/build/gobuild-sample.sh similarity index 100% rename from scripts/gobuild.sh rename to build/gobuild-sample.sh diff --git a/scripts/report_build_info.sh b/build/report_build_info.sh similarity index 100% rename from scripts/report_build_info.sh rename to build/report_build_info.sh diff --git a/githook/pre-commit b/githook/pre-commit new file mode 100755 index 00000000..594f1edd --- /dev/null +++ b/githook/pre-commit @@ -0,0 +1,7 @@ + +goimports -l pkg +goimports -l examples + +ineffassign . + +staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" ./pkg \ No newline at end of file From 7fe4eaef50e89a16d0f6e890e9330bdf73192c1d Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 18 Aug 2020 14:31:06 +0000 Subject: [PATCH 100/207] Refactor orm filter --- githook/pre-commit | 4 +- pkg/bean/context.go | 3 +- pkg/bean/doc.go | 2 +- pkg/bean/factory.go | 4 +- pkg/bean/metadata.go | 2 +- pkg/bean/tag_auto_wire_bean_factory.go | 2 +- pkg/bean/tag_auto_wire_bean_factory_test.go | 2 +- pkg/bean/time_type_adapter.go | 5 +- pkg/bean/time_type_adapter_test.go | 2 +- pkg/bean/type_adapter.go | 2 +- pkg/context/param/parsers_test.go | 8 +- pkg/controller_test.go | 3 +- pkg/orm/db.go | 9 +- pkg/orm/db_alias.go | 3 +- pkg/orm/db_alias_test.go | 3 +- pkg/orm/db_mysql.go | 2 +- pkg/orm/db_oracle.go | 5 +- pkg/orm/db_postgres.go | 1 - pkg/orm/db_sqlite.go | 4 +- pkg/orm/db_tables.go | 2 +- pkg/orm/do_nothing_orm.go | 6 +- pkg/orm/filter.go | 6 +- pkg/orm/filter/bean/default_value_filter.go | 6 +- .../filter/bean/default_value_filter_test.go | 2 +- pkg/orm/filter/opentracing/filter.go | 8 +- pkg/orm/filter/opentracing/filter_test.go | 3 +- pkg/orm/filter/prometheus/filter.go | 5 +- pkg/orm/filter/prometheus/filter_test.go | 3 +- pkg/orm/filter_orm_decorator.go | 258 +++++++++--------- pkg/orm/filter_orm_decorator_test.go | 71 +++-- pkg/orm/filter_test.go | 4 +- pkg/orm/hints/db_hints.go | 3 +- pkg/orm/hints/db_hints_test.go | 5 +- pkg/orm/invocation.go | 6 +- pkg/orm/models_test.go | 5 +- pkg/orm/orm.go | 5 +- pkg/orm/orm_queryset.go | 1 + pkg/orm/orm_raw.go | 3 +- pkg/orm/orm_test.go | 5 +- pkg/orm/types.go | 5 +- pkg/parser.go | 3 +- pkg/plugins/auth/basic.go | 2 +- pkg/plugins/authz/authz.go | 5 +- pkg/plugins/authz/authz_test.go | 9 +- pkg/plugins/cors/cors.go | 2 +- pkg/plugins/cors/cors_test.go | 2 +- pkg/session/redis/sess_redis.go | 2 +- pkg/session/redis_cluster/redis_cluster.go | 5 +- .../redis_sentinel/sess_redis_sentinel.go | 5 +- pkg/staticfile.go | 2 +- pkg/template_test.go | 2 +- pkg/utils/captcha/captcha.go | 2 +- pkg/validation/validators.go | 3 +- test/bindata.go | 3 +- 54 files changed, 269 insertions(+), 256 deletions(-) diff --git a/githook/pre-commit b/githook/pre-commit index 594f1edd..95b1009b 100755 --- a/githook/pre-commit +++ b/githook/pre-commit @@ -1,6 +1,6 @@ -goimports -l pkg -goimports -l examples +goimports -w -format-only pkg +goimports -w -format-only examples ineffassign . diff --git a/pkg/bean/context.go b/pkg/bean/context.go index 93261628..7cee2c7e 100644 --- a/pkg/bean/context.go +++ b/pkg/bean/context.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,5 +17,4 @@ package bean // ApplicationContext define for future // when we decide to support DI, IoC, this will be core API type ApplicationContext interface { - } diff --git a/pkg/bean/doc.go b/pkg/bean/doc.go index 212e8aaf..f806a081 100644 --- a/pkg/bean/doc.go +++ b/pkg/bean/doc.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/bean/factory.go b/pkg/bean/factory.go index 698474c4..1097604c 100644 --- a/pkg/bean/factory.go +++ b/pkg/bean/factory.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,4 +22,4 @@ import ( type AutoWireBeanFactory interface { // AutoWire will wire the bean. AutoWire(ctx context.Context, appCtx ApplicationContext, bean interface{}) error -} \ No newline at end of file +} diff --git a/pkg/bean/metadata.go b/pkg/bean/metadata.go index 8c423692..e2e34f55 100644 --- a/pkg/bean/metadata.go +++ b/pkg/bean/metadata.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/bean/tag_auto_wire_bean_factory.go b/pkg/bean/tag_auto_wire_bean_factory.go index ea8fd907..569ffb0d 100644 --- a/pkg/bean/tag_auto_wire_bean_factory.go +++ b/pkg/bean/tag_auto_wire_bean_factory.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/bean/tag_auto_wire_bean_factory_test.go b/pkg/bean/tag_auto_wire_bean_factory_test.go index 2d83c537..bcdada67 100644 --- a/pkg/bean/tag_auto_wire_bean_factory_test.go +++ b/pkg/bean/tag_auto_wire_bean_factory_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/bean/time_type_adapter.go b/pkg/bean/time_type_adapter.go index 846eb694..b0e99896 100644 --- a/pkg/bean/time_type_adapter.go +++ b/pkg/bean/time_type_adapter.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ package bean import ( "context" "time" - ) // TimeTypeAdapter process the time.Time @@ -29,7 +28,7 @@ type TimeTypeAdapter struct { // and if the DftValue == now // time.Now() is returned func (t *TimeTypeAdapter) DefaultValue(ctx context.Context, dftValue string) (interface{}, error) { - if dftValue == "now"{ + if dftValue == "now" { return time.Now(), nil } return time.Parse(t.Layout, dftValue) diff --git a/pkg/bean/time_type_adapter_test.go b/pkg/bean/time_type_adapter_test.go index 9c097048..140ef5a6 100644 --- a/pkg/bean/time_type_adapter_test.go +++ b/pkg/bean/time_type_adapter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/bean/type_adapter.go b/pkg/bean/type_adapter.go index ba675b64..5869032d 100644 --- a/pkg/bean/type_adapter.go +++ b/pkg/bean/type_adapter.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/context/param/parsers_test.go b/pkg/context/param/parsers_test.go index 7065a28e..81a821f1 100644 --- a/pkg/context/param/parsers_test.go +++ b/pkg/context/param/parsers_test.go @@ -1,8 +1,10 @@ package param -import "testing" -import "reflect" -import "time" +import ( + "reflect" + "testing" + "time" +) type testDefinition struct { strValue string diff --git a/pkg/controller_test.go b/pkg/controller_test.go index e30f7211..97f1e964 100644 --- a/pkg/controller_test.go +++ b/pkg/controller_test.go @@ -21,9 +21,10 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/context" "os" "path/filepath" + + "github.com/astaxie/beego/pkg/context" ) func TestGetInt(t *testing.T) { diff --git a/pkg/orm/db.go b/pkg/orm/db.go index 0b6d8ac1..905c8189 100644 --- a/pkg/orm/db.go +++ b/pkg/orm/db.go @@ -18,10 +18,11 @@ import ( "database/sql" "errors" "fmt" - "github.com/astaxie/beego/pkg/orm/hints" "reflect" "strings" "time" + + "github.com/astaxie/beego/pkg/orm/hints" ) const ( @@ -490,7 +491,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s if err != nil { DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) return lastInsertId, ErrLastInsertIdUnavailable - }else{ + } else { return lastInsertId, nil } } @@ -598,7 +599,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a if err != nil { DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) return lastInsertId, ErrLastInsertIdUnavailable - }else{ + } else { return lastInsertId, nil } } @@ -1954,5 +1955,3 @@ func (d *dbBase) GenerateSpecifyIndex(tableName string, useIndex int, indexes [] return fmt.Sprintf(` %s INDEX(%s) `, useWay, strings.Join(s, `,`)) } - - diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index 9f12bf2a..0a53ad31 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -18,10 +18,11 @@ import ( "context" "database/sql" "fmt" - "github.com/astaxie/beego/pkg/orm/hints" "sync" "time" + "github.com/astaxie/beego/pkg/orm/hints" + lru "github.com/hashicorp/golang-lru" "github.com/astaxie/beego/pkg/common" diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index 36e65fc8..4a561a27 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -15,10 +15,11 @@ package orm import ( - "github.com/astaxie/beego/pkg/orm/hints" "testing" "time" + "github.com/astaxie/beego/pkg/orm/hints" + "github.com/stretchr/testify/assert" ) diff --git a/pkg/orm/db_mysql.go b/pkg/orm/db_mysql.go index efa5a50b..d934d842 100644 --- a/pkg/orm/db_mysql.go +++ b/pkg/orm/db_mysql.go @@ -169,7 +169,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val if err != nil { DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) return lastInsertId, ErrLastInsertIdUnavailable - }else{ + } else { return lastInsertId, nil } } diff --git a/pkg/orm/db_oracle.go b/pkg/orm/db_oracle.go index d384d33e..66246ec4 100644 --- a/pkg/orm/db_oracle.go +++ b/pkg/orm/db_oracle.go @@ -16,8 +16,9 @@ package orm import ( "fmt" - "github.com/astaxie/beego/pkg/orm/hints" "strings" + + "github.com/astaxie/beego/pkg/orm/hints" ) // oracle operators. @@ -155,7 +156,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam if err != nil { DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) return lastInsertId, ErrLastInsertIdUnavailable - }else{ + } else { return lastInsertId, nil } } diff --git a/pkg/orm/db_postgres.go b/pkg/orm/db_postgres.go index cf1a3413..35471ddc 100644 --- a/pkg/orm/db_postgres.go +++ b/pkg/orm/db_postgres.go @@ -92,7 +92,6 @@ func (d *dbBasePostgres) MaxLimit() uint64 { return 0 } - // postgresql quote is ". func (d *dbBasePostgres) TableQuote() string { return `"` diff --git a/pkg/orm/db_sqlite.go b/pkg/orm/db_sqlite.go index 244aae7a..f9d379ce 100644 --- a/pkg/orm/db_sqlite.go +++ b/pkg/orm/db_sqlite.go @@ -17,10 +17,11 @@ package orm import ( "database/sql" "fmt" - "github.com/astaxie/beego/pkg/orm/hints" "reflect" "strings" "time" + + "github.com/astaxie/beego/pkg/orm/hints" ) // sqlite operators. @@ -173,7 +174,6 @@ func (d *dbBaseSqlite) GenerateSpecifyIndex(tableName string, useIndex int, inde } } - // create new sqlite dbBaser. func newdbBaseSqlite() dbBaser { b := new(dbBaseSqlite) diff --git a/pkg/orm/db_tables.go b/pkg/orm/db_tables.go index d7e99639..5fd472d1 100644 --- a/pkg/orm/db_tables.go +++ b/pkg/orm/db_tables.go @@ -473,7 +473,7 @@ func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits } // getIndexSql generate index sql. -func (t *dbTables) getIndexSql(tableName string,useIndex int, indexes []string) (clause string) { +func (t *dbTables) getIndexSql(tableName string, useIndex int, indexes []string) (clause string) { if len(indexes) == 0 { return } diff --git a/pkg/orm/do_nothing_orm.go b/pkg/orm/do_nothing_orm.go index 357428f2..afc428f2 100644 --- a/pkg/orm/do_nothing_orm.go +++ b/pkg/orm/do_nothing_orm.go @@ -17,6 +17,7 @@ package orm import ( "context" "database/sql" + "github.com/astaxie/beego/pkg/common" ) @@ -27,7 +28,6 @@ import ( var _ Ormer = new(DoNothingOrm) type DoNothingOrm struct { - } func (d *DoNothingOrm) Read(md interface{}, cols ...string) error { @@ -54,11 +54,11 @@ func (d *DoNothingOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{}, return false, 0, nil } -func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) { +func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) { return 0, nil } -func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { +func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { return 0, nil } diff --git a/pkg/orm/filter.go b/pkg/orm/filter.go index 03a30022..bc13c3fa 100644 --- a/pkg/orm/filter.go +++ b/pkg/orm/filter.go @@ -24,7 +24,11 @@ type FilterChain func(next Filter) Filter // Filter's behavior is a little big strange. // it's only be called when users call methods of Ormer -type Filter func(ctx context.Context, inv *Invocation) +// return value is an array. it's a little bit hard to understand, +// for example, the Ormer's Read method only return error +// so the filter processing this method should return an array whose first element is error +// and, Ormer's ReadOrCreateWithCtx return three values, so the Filter's result should contains three values +type Filter func(ctx context.Context, inv *Invocation) []interface{} var globalFilterChains = make([]FilterChain, 0, 4) diff --git a/pkg/orm/filter/bean/default_value_filter.go b/pkg/orm/filter/bean/default_value_filter.go index 80aef43d..b3ef7415 100644 --- a/pkg/orm/filter/bean/default_value_filter.go +++ b/pkg/orm/filter/bean/default_value_filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -76,7 +76,7 @@ func NewDefaultValueFilterChainBuilder(typeAdapters map[string]bean.TypeAdapter, } func (d *DefaultValueFilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { - return func(ctx context.Context, inv *orm.Invocation) { + return func(ctx context.Context, inv *orm.Invocation) []interface{} { switch inv.Method { case "Insert", "InsertWithCtx": d.handleInsert(ctx, inv) @@ -88,7 +88,7 @@ func (d *DefaultValueFilterChainBuilder) FilterChain(next orm.Filter) orm.Filter d.handleInsertMulti(ctx, inv) break } - next(ctx, inv) + return next(ctx, inv) } } diff --git a/pkg/orm/filter/bean/default_value_filter_test.go b/pkg/orm/filter/bean/default_value_filter_test.go index b939698f..2c754a3e 100644 --- a/pkg/orm/filter/bean/default_value_filter_test.go +++ b/pkg/orm/filter/bean/default_value_filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/orm/filter/opentracing/filter.go b/pkg/orm/filter/opentracing/filter.go index 405e39ea..0b6968b7 100644 --- a/pkg/orm/filter/opentracing/filter.go +++ b/pkg/orm/filter/opentracing/filter.go @@ -36,17 +36,17 @@ type FilterChainBuilder struct { } func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { - return func(ctx context.Context, inv *orm.Invocation) { + return func(ctx context.Context, inv *orm.Invocation) []interface{} { operationName := builder.operationName(ctx, inv) if strings.HasPrefix(inv.Method, "Begin") || inv.Method == "Commit" || inv.Method == "Rollback" { - next(ctx, inv) - return + return next(ctx, inv) } span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName) defer span.Finish() - next(spanCtx, inv) + res := next(spanCtx, inv) builder.buildSpan(span, spanCtx, inv) + return res } } diff --git a/pkg/orm/filter/opentracing/filter_test.go b/pkg/orm/filter/opentracing/filter_test.go index 7df12a92..8f6d4807 100644 --- a/pkg/orm/filter/opentracing/filter_test.go +++ b/pkg/orm/filter/opentracing/filter_test.go @@ -25,8 +25,9 @@ import ( ) func TestFilterChainBuilder_FilterChain(t *testing.T) { - next := func(ctx context.Context, inv *orm.Invocation) { + next := func(ctx context.Context, inv *orm.Invocation) []interface{} { inv.TxName = "Hello" + return []interface{}{} } builder := &FilterChainBuilder{ diff --git a/pkg/orm/filter/prometheus/filter.go b/pkg/orm/filter/prometheus/filter.go index 2e67d85c..fb2b473d 100644 --- a/pkg/orm/filter/prometheus/filter.go +++ b/pkg/orm/filter/prometheus/filter.go @@ -56,15 +56,16 @@ func NewFilterChainBuilder() *FilterChainBuilder { } func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { - return func(ctx context.Context, inv *orm.Invocation) { + return func(ctx context.Context, inv *orm.Invocation) []interface{} { startTime := time.Now() - next(ctx, inv) + res := next(ctx, inv) endTime := time.Now() dur := (endTime.Sub(startTime)) / time.Millisecond // if the TPS is too large, here may be some problem // thinking about using goroutine pool go builder.report(ctx, inv, dur) + return res } } diff --git a/pkg/orm/filter/prometheus/filter_test.go b/pkg/orm/filter/prometheus/filter_test.go index 34766fb4..aad62c18 100644 --- a/pkg/orm/filter/prometheus/filter_test.go +++ b/pkg/orm/filter/prometheus/filter_test.go @@ -28,8 +28,9 @@ func TestFilterChainBuilder_FilterChain(t *testing.T) { builder := NewFilterChainBuilder() assert.NotNil(t, builder.summaryVec) - filter := builder.FilterChain(func(ctx context.Context, inv *orm.Invocation) { + filter := builder.FilterChain(func(ctx context.Context, inv *orm.Invocation) []interface{} { inv.Method = "coming" + return []interface{}{} }) assert.NotNil(t, filter) diff --git a/pkg/orm/filter_orm_decorator.go b/pkg/orm/filter_orm_decorator.go index e5b81472..97221a07 100644 --- a/pkg/orm/filter_orm_decorator.go +++ b/pkg/orm/filter_orm_decorator.go @@ -17,13 +17,14 @@ package orm import ( "context" "database/sql" - "github.com/astaxie/beego/pkg/common" "reflect" "time" + + "github.com/astaxie/beego/pkg/common" ) const ( - TxNameKey = "TxName" + TxNameKey = "TxName" ) var _ Ormer = new(filterOrmDecorator) @@ -45,8 +46,8 @@ func NewFilterOrmDecorator(delegate Ormer, filterChains ...FilterChain) Ormer { res := &filterOrmDecorator{ ormer: delegate, TxBeginner: delegate, - root: func(ctx context.Context, inv *Invocation) { - inv.execute(ctx) + root: func(ctx context.Context, inv *Invocation) []interface{} { + return inv.execute(ctx) }, } @@ -73,7 +74,7 @@ func (f *filterOrmDecorator) Read(md interface{}, cols ...string) error { return f.ReadWithCtx(context.Background(), md, cols...) } -func (f *filterOrmDecorator) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) (err error) { +func (f *filterOrmDecorator) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, _ := modelCache.getByMd(md) inv := &Invocation{ Method: "ReadWithCtx", @@ -82,12 +83,13 @@ func (f *filterOrmDecorator) ReadWithCtx(ctx context.Context, md interface{}, co mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func(c context.Context) { - err = f.ormer.ReadWithCtx(c, md, cols...) + f: func(c context.Context) []interface{} { + err := f.ormer.ReadWithCtx(c, md, cols...) + return []interface{}{err} }, } - f.root(ctx, inv) - return err + res := f.root(ctx, inv) + return f.convertError(res[0]) } func (f *filterOrmDecorator) ReadForUpdate(md interface{}, cols ...string) error { @@ -95,7 +97,6 @@ func (f *filterOrmDecorator) ReadForUpdate(md interface{}, cols ...string) error } func (f *filterOrmDecorator) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { - var err error mi, _ := modelCache.getByMd(md) inv := &Invocation{ Method: "ReadForUpdateWithCtx", @@ -104,12 +105,13 @@ func (f *filterOrmDecorator) ReadForUpdateWithCtx(ctx context.Context, md interf mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func(c context.Context) { - err = f.ormer.ReadForUpdateWithCtx(c, md, cols...) + f: func(c context.Context) []interface{} { + err := f.ormer.ReadForUpdateWithCtx(c, md, cols...) + return []interface{}{err} }, } - f.root(ctx, inv) - return err + res := f.root(ctx, inv) + return f.convertError(res[0]) } func (f *filterOrmDecorator) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { @@ -117,11 +119,6 @@ func (f *filterOrmDecorator) ReadOrCreate(md interface{}, col1 string, cols ...s } func (f *filterOrmDecorator) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { - var ( - ok bool - res int64 - err error - ) mi, _ := modelCache.getByMd(md) inv := &Invocation{ @@ -131,12 +128,13 @@ func (f *filterOrmDecorator) ReadOrCreateWithCtx(ctx context.Context, md interfa mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func(c context.Context) { - ok, res, err = f.ormer.ReadOrCreateWithCtx(c, md, col1, cols...) + f: func(c context.Context) []interface{} { + ok, res, err := f.ormer.ReadOrCreateWithCtx(c, md, col1, cols...) + return []interface{}{ok, res, err} }, } - f.root(ctx, inv) - return ok, res, err + res := f.root(ctx, inv) + return res[0].(bool), res[1].(int64), f.convertError(res[2]) } func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) { @@ -144,10 +142,6 @@ func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...co } func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { - var ( - res int64 - err error - ) mi, _ := modelCache.getByMd(md) inv := &Invocation{ @@ -157,12 +151,13 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func(c context.Context) { - res, err = f.ormer.LoadRelatedWithCtx(c, md, name, args...) + f: func(c context.Context) []interface{} { + res, err := f.ormer.LoadRelatedWithCtx(c, md, name, args...) + return []interface{}{res, err} }, } - f.root(ctx, inv) - return res, err + res := f.root(ctx, inv) + return res[0].(int64), f.convertError(res[1]) } func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer { @@ -170,9 +165,6 @@ func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer { } func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { - var ( - res QueryM2Mer - ) mi, _ := modelCache.getByMd(md) inv := &Invocation{ @@ -182,12 +174,16 @@ func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{} mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func(c context.Context) { - res = f.ormer.QueryM2MWithCtx(c, md, name) + f: func(c context.Context) []interface{} { + res := f.ormer.QueryM2MWithCtx(c, md, name) + return []interface{}{res} }, } - f.root(ctx, inv) - return res + res := f.root(ctx, inv) + if res[0] == nil { + return nil + } + return res[0].(QueryM2Mer) } func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { @@ -196,7 +192,6 @@ func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QueryS func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { var ( - res QuerySeter name string md interface{} mi *modelInfo @@ -220,28 +215,36 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT TxStartTime: f.txStartTime, Md: md, mi: mi, - f: func(c context.Context) { - res = f.ormer.QueryTableWithCtx(c, ptrStructOrTableName) + f: func(c context.Context) []interface{} { + res := f.ormer.QueryTableWithCtx(c, ptrStructOrTableName) + return []interface{}{res} }, } - f.root(ctx, inv) - return res + res := f.root(ctx, inv) + + if res[0] == nil { + return nil + } + return res[0].(QuerySeter) } func (f *filterOrmDecorator) DBStats() *sql.DBStats { - var ( - res *sql.DBStats - ) inv := &Invocation{ Method: "DBStats", InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func(c context.Context) { - res = f.ormer.DBStats() + f: func(c context.Context) []interface{} { + res := f.ormer.DBStats() + return []interface{}{res} }, } - f.root(context.Background(), inv) - return res + res := f.root(context.Background(), inv) + + if res[0] == nil { + return nil + } + + return res[0].(*sql.DBStats) } func (f *filterOrmDecorator) Insert(md interface{}) (int64, error) { @@ -249,10 +252,6 @@ func (f *filterOrmDecorator) Insert(md interface{}) (int64, error) { } func (f *filterOrmDecorator) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { - var ( - res int64 - err error - ) mi, _ := modelCache.getByMd(md) inv := &Invocation{ Method: "InsertWithCtx", @@ -261,12 +260,13 @@ func (f *filterOrmDecorator) InsertWithCtx(ctx context.Context, md interface{}) mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func(c context.Context) { - res, err = f.ormer.InsertWithCtx(c, md) + f: func(c context.Context) []interface{} { + res, err := f.ormer.InsertWithCtx(c, md) + return []interface{}{res, err} }, } - f.root(ctx, inv) - return res, err + res := f.root(ctx, inv) + return res[0].(int64), f.convertError(res[1]) } func (f *filterOrmDecorator) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { @@ -274,10 +274,6 @@ func (f *filterOrmDecorator) InsertOrUpdate(md interface{}, colConflitAndArgs .. } func (f *filterOrmDecorator) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { - var ( - res int64 - err error - ) mi, _ := modelCache.getByMd(md) inv := &Invocation{ Method: "InsertOrUpdateWithCtx", @@ -286,12 +282,13 @@ func (f *filterOrmDecorator) InsertOrUpdateWithCtx(ctx context.Context, md inter mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func(c context.Context) { - res, err = f.ormer.InsertOrUpdateWithCtx(c, md, colConflitAndArgs...) + f: func(c context.Context) []interface{} { + res, err := f.ormer.InsertOrUpdateWithCtx(c, md, colConflitAndArgs...) + return []interface{}{res, err} }, } - f.root(ctx, inv) - return res, err + res := f.root(ctx, inv) + return res[0].(int64), f.convertError(res[1]) } func (f *filterOrmDecorator) InsertMulti(bulk int, mds interface{}) (int64, error) { @@ -301,10 +298,8 @@ func (f *filterOrmDecorator) InsertMulti(bulk int, mds interface{}) (int64, erro // InsertMultiWithCtx uses the first element's model info func (f *filterOrmDecorator) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { var ( - res int64 - err error - md interface{} - mi *modelInfo + md interface{} + mi *modelInfo ) sind := reflect.Indirect(reflect.ValueOf(mds)) @@ -322,12 +317,13 @@ func (f *filterOrmDecorator) InsertMultiWithCtx(ctx context.Context, bulk int, m mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func(c context.Context) { - res, err = f.ormer.InsertMultiWithCtx(c, bulk, mds) + f: func(c context.Context) []interface{} { + res, err := f.ormer.InsertMultiWithCtx(c, bulk, mds) + return []interface{}{res, err} }, } - f.root(ctx, inv) - return res, err + res := f.root(ctx, inv) + return res[0].(int64), f.convertError(res[1]) } func (f *filterOrmDecorator) Update(md interface{}, cols ...string) (int64, error) { @@ -335,10 +331,6 @@ func (f *filterOrmDecorator) Update(md interface{}, cols ...string) (int64, erro } func (f *filterOrmDecorator) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { - var ( - res int64 - err error - ) mi, _ := modelCache.getByMd(md) inv := &Invocation{ Method: "UpdateWithCtx", @@ -347,12 +339,13 @@ func (f *filterOrmDecorator) UpdateWithCtx(ctx context.Context, md interface{}, mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func(c context.Context) { - res, err = f.ormer.UpdateWithCtx(c, md, cols...) + f: func(c context.Context) []interface{} { + res, err := f.ormer.UpdateWithCtx(c, md, cols...) + return []interface{}{res, err} }, } - f.root(ctx, inv) - return res, err + res := f.root(ctx, inv) + return res[0].(int64), f.convertError(res[1]) } func (f *filterOrmDecorator) Delete(md interface{}, cols ...string) (int64, error) { @@ -360,10 +353,6 @@ func (f *filterOrmDecorator) Delete(md interface{}, cols ...string) (int64, erro } func (f *filterOrmDecorator) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { - var ( - res int64 - err error - ) mi, _ := modelCache.getByMd(md) inv := &Invocation{ Method: "DeleteWithCtx", @@ -372,12 +361,13 @@ func (f *filterOrmDecorator) DeleteWithCtx(ctx context.Context, md interface{}, mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func(c context.Context) { - res, err = f.ormer.DeleteWithCtx(c, md, cols...) + f: func(c context.Context) []interface{} { + res, err := f.ormer.DeleteWithCtx(c, md, cols...) + return []interface{}{res, err} }, } - f.root(ctx, inv) - return res, err + res := f.root(ctx, inv) + return res[0].(int64), f.convertError(res[1]) } func (f *filterOrmDecorator) Raw(query string, args ...interface{}) RawSeter { @@ -385,36 +375,39 @@ func (f *filterOrmDecorator) Raw(query string, args ...interface{}) RawSeter { } func (f *filterOrmDecorator) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter { - var ( - res RawSeter - ) inv := &Invocation{ Method: "RawWithCtx", Args: []interface{}{query, args}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func(c context.Context) { - res = f.ormer.RawWithCtx(c, query, args...) + f: func(c context.Context) []interface{} { + res := f.ormer.RawWithCtx(c, query, args...) + return []interface{}{res} }, } - f.root(ctx, inv) - return res + res := f.root(ctx, inv) + + if res[0] == nil { + return nil + } + return res[0].(RawSeter) } func (f *filterOrmDecorator) Driver() Driver { - var ( - res Driver - ) inv := &Invocation{ Method: "Driver", InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func(c context.Context) { - res = f.ormer.Driver() + f: func(c context.Context) []interface{} { + res := f.ormer.Driver() + return []interface{}{res} }, } - f.root(context.Background(), inv) - return res + res := f.root(context.Background(), inv) + if res[0] == nil { + return nil + } + return res[0].(Driver) } func (f *filterOrmDecorator) Begin() (TxOrmer, error) { @@ -430,22 +423,19 @@ func (f *filterOrmDecorator) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) } func (f *filterOrmDecorator) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) { - var ( - res TxOrmer - err error - ) inv := &Invocation{ Method: "BeginWithCtxAndOpts", Args: []interface{}{opts}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func(c context.Context) { - res, err = f.TxBeginner.BeginWithCtxAndOpts(c, opts) + f: func(c context.Context) []interface{} { + res, err := f.TxBeginner.BeginWithCtxAndOpts(c, opts) res = NewFilterTxOrmDecorator(res, f.root, getTxNameFromCtx(c)) + return []interface{}{res, err} }, } - f.root(ctx, inv) - return res, err + res := f.root(ctx, inv) + return res[0].(TxOrmer), f.convertError(res[1]) } func (f *filterOrmDecorator) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error { @@ -461,58 +451,58 @@ func (f *filterOrmDecorator) DoTxWithOpts(opts *sql.TxOptions, task func(ctx con } func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { - var ( - err error - ) - inv := &Invocation{ Method: "DoTxWithCtxAndOpts", Args: []interface{}{opts, task}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, TxName: getTxNameFromCtx(ctx), - f: func(c context.Context) { - err = doTxTemplate(f, c, opts, task) + f: func(c context.Context) []interface{} { + err := doTxTemplate(f, c, opts, task) + return []interface{}{err} }, } - f.root(ctx, inv) - return err + res := f.root(ctx, inv) + return f.convertError(res[0]) } func (f *filterOrmDecorator) Commit() error { - var ( - err error - ) inv := &Invocation{ Method: "Commit", Args: []interface{}{}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, TxName: f.txName, - f: func(c context.Context) { - err = f.TxCommitter.Commit() + f: func(c context.Context) []interface{} { + err := f.TxCommitter.Commit() + return []interface{}{err} }, } - f.root(context.Background(), inv) - return err + res := f.root(context.Background(), inv) + return f.convertError(res[0]) } func (f *filterOrmDecorator) Rollback() error { - var ( - err error - ) inv := &Invocation{ Method: "Rollback", Args: []interface{}{}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, TxName: f.txName, - f: func(c context.Context) { - err = f.TxCommitter.Rollback() + f: func(c context.Context) []interface{} { + err := f.TxCommitter.Rollback() + return []interface{}{err} }, } - f.root(context.Background(), inv) - return err + res := f.root(context.Background(), inv) + return f.convertError(res[0]) +} + +func (f *filterOrmDecorator) convertError(v interface{}) error { + if v == nil { + return nil + } + return v.(error) } func getTxNameFromCtx(ctx context.Context) string { diff --git a/pkg/orm/filter_orm_decorator_test.go b/pkg/orm/filter_orm_decorator_test.go index b4570a04..d52ce27b 100644 --- a/pkg/orm/filter_orm_decorator_test.go +++ b/pkg/orm/filter_orm_decorator_test.go @@ -18,10 +18,11 @@ import ( "context" "database/sql" "errors" - "github.com/astaxie/beego/pkg/common" "sync" "testing" + "github.com/astaxie/beego/pkg/common" + "github.com/stretchr/testify/assert" ) @@ -31,11 +32,11 @@ func TestFilterOrmDecorator_Read(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { assert.Equal(t, "ReadWithCtx", inv.Method) assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) - next(ctx, inv) + return next(ctx, inv) } }) @@ -50,7 +51,7 @@ func TestFilterOrmDecorator_BeginTx(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { if inv.Method == "BeginWithCtxAndOpts" { assert.Equal(t, 1, len(inv.Args)) assert.Equal(t, "", inv.GetTableName()) @@ -69,7 +70,7 @@ func TestFilterOrmDecorator_BeginTx(t *testing.T) { t.Fail() } - next(ctx, inv) + return next(ctx, inv) } }) to, err := od.Begin() @@ -98,11 +99,11 @@ func TestFilterOrmDecorator_BeginTx(t *testing.T) { func TestFilterOrmDecorator_DBStats(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { assert.Equal(t, "DBStats", inv.Method) assert.Equal(t, 0, len(inv.Args)) assert.Equal(t, "", inv.GetTableName()) - next(ctx, inv) + return next(ctx, inv) } }) res := od.DBStats() @@ -114,11 +115,11 @@ func TestFilterOrmDecorator_Delete(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { assert.Equal(t, "DeleteWithCtx", inv.Method) assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) - next(ctx, inv) + return next(ctx, inv) } }) res, err := od.Delete(&FilterTestEntity{}) @@ -130,14 +131,13 @@ func TestFilterOrmDecorator_Delete(t *testing.T) { func TestFilterOrmDecorator_DoTx(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { if inv.Method == "DoTxWithCtxAndOpts" { assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "", inv.GetTableName()) assert.False(t, inv.InsideTx) } - - next(ctx, inv) + return next(ctx, inv) } }) @@ -156,16 +156,15 @@ func TestFilterOrmDecorator_DoTx(t *testing.T) { }) assert.NotNil(t, err) - od = NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { if inv.Method == "DoTxWithCtxAndOpts" { assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "", inv.GetTableName()) assert.Equal(t, "do tx name", inv.TxName) assert.False(t, inv.InsideTx) } - next(ctx, inv) + return next(ctx, inv) } }) @@ -179,12 +178,12 @@ func TestFilterOrmDecorator_DoTx(t *testing.T) { func TestFilterOrmDecorator_Driver(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { assert.Equal(t, "Driver", inv.Method) assert.Equal(t, 0, len(inv.Args)) assert.Equal(t, "", inv.GetTableName()) assert.False(t, inv.InsideTx) - next(ctx, inv) + return next(ctx, inv) } }) res := od.Driver() @@ -195,12 +194,12 @@ func TestFilterOrmDecorator_Insert(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { assert.Equal(t, "InsertWithCtx", inv.Method) assert.Equal(t, 1, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) - next(ctx, inv) + return next(ctx, inv) } }) @@ -214,12 +213,12 @@ func TestFilterOrmDecorator_InsertMulti(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { assert.Equal(t, "InsertMultiWithCtx", inv.Method) assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) - next(ctx, inv) + return next(ctx, inv) } }) @@ -234,12 +233,12 @@ func TestFilterOrmDecorator_InsertOrUpdate(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { assert.Equal(t, "InsertOrUpdateWithCtx", inv.Method) assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) - next(ctx, inv) + return next(ctx, inv) } }) i, err := od.InsertOrUpdate(&FilterTestEntity{}) @@ -251,12 +250,12 @@ func TestFilterOrmDecorator_InsertOrUpdate(t *testing.T) { func TestFilterOrmDecorator_LoadRelated(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { assert.Equal(t, "LoadRelatedWithCtx", inv.Method) assert.Equal(t, 3, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) - next(ctx, inv) + return next(ctx, inv) } }) i, err := od.LoadRelated(&FilterTestEntity{}, "hello") @@ -268,12 +267,12 @@ func TestFilterOrmDecorator_LoadRelated(t *testing.T) { func TestFilterOrmDecorator_QueryM2M(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { assert.Equal(t, "QueryM2MWithCtx", inv.Method) assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) - next(ctx, inv) + return next(ctx, inv) } }) res := od.QueryM2M(&FilterTestEntity{}, "hello") @@ -284,12 +283,12 @@ func TestFilterOrmDecorator_QueryTable(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { assert.Equal(t, "QueryTableWithCtx", inv.Method) assert.Equal(t, 1, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) - next(ctx, inv) + return next(ctx, inv) } }) res := od.QueryTable(&FilterTestEntity{}) @@ -300,28 +299,28 @@ func TestFilterOrmDecorator_Raw(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { assert.Equal(t, "RawWithCtx", inv.Method) assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "", inv.GetTableName()) assert.False(t, inv.InsideTx) - next(ctx, inv) + return next(ctx, inv) } }) res := od.Raw("hh") assert.Nil(t, res) } -func TestFilterOrmDecorator_ReadForUpdate(t *testing.T) { +func TestFilterOrmDecorator_ReadForUpdate(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { assert.Equal(t, "ReadForUpdateWithCtx", inv.Method) assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) - next(ctx, inv) + return next(ctx, inv) } }) err := od.ReadForUpdate(&FilterTestEntity{}) @@ -333,12 +332,12 @@ func TestFilterOrmDecorator_ReadOrCreate(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { + return func(ctx context.Context, inv *Invocation) []interface{} { assert.Equal(t, "ReadOrCreateWithCtx", inv.Method) assert.Equal(t, 3, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) - next(ctx, inv) + return next(ctx, inv) } }) ok, i, err := od.ReadOrCreate(&FilterTestEntity{}, "name") diff --git a/pkg/orm/filter_test.go b/pkg/orm/filter_test.go index b2ca4ae1..f9c86039 100644 --- a/pkg/orm/filter_test.go +++ b/pkg/orm/filter_test.go @@ -23,8 +23,8 @@ import ( func TestAddGlobalFilterChain(t *testing.T) { AddGlobalFilterChain(func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) { - + return func(ctx context.Context, inv *Invocation) []interface{} { + return next(ctx, inv) } }) assert.Equal(t, 1, len(globalFilterChains)) diff --git a/pkg/orm/hints/db_hints.go b/pkg/orm/hints/db_hints.go index f708f310..0649ab9f 100644 --- a/pkg/orm/hints/db_hints.go +++ b/pkg/orm/hints/db_hints.go @@ -15,8 +15,9 @@ package hints import ( - "github.com/astaxie/beego/pkg/common" "time" + + "github.com/astaxie/beego/pkg/common" ) const ( diff --git a/pkg/orm/hints/db_hints_test.go b/pkg/orm/hints/db_hints_test.go index 5ab44b08..4e962a8f 100644 --- a/pkg/orm/hints/db_hints_test.go +++ b/pkg/orm/hints/db_hints_test.go @@ -15,9 +15,10 @@ package hints import ( - "github.com/stretchr/testify/assert" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestNewHint_time(t *testing.T) { @@ -151,4 +152,4 @@ func TestOrderBy(t *testing.T) { hint := OrderBy(`-ID`) assert.Equal(t, hint.GetValue(), `-ID`) assert.Equal(t, hint.GetKey(), KeyOrderBy) -} \ No newline at end of file +} diff --git a/pkg/orm/invocation.go b/pkg/orm/invocation.go index 704b13c7..9e7c1974 100644 --- a/pkg/orm/invocation.go +++ b/pkg/orm/invocation.go @@ -29,7 +29,7 @@ type Invocation struct { mi *modelInfo // f is the Orm operation - f func(ctx context.Context) + f func(ctx context.Context) []interface{} // insideTx indicates whether this is inside a transaction InsideTx bool @@ -44,8 +44,8 @@ func (inv *Invocation) GetTableName() string { return "" } -func (inv *Invocation) execute(ctx context.Context) { - inv.f(ctx) +func (inv *Invocation) execute(ctx context.Context) []interface{} { + return inv.f(ctx) } // GetPkFieldName return the primary key of this table diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index 7fba89b1..8fcd2e06 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -18,11 +18,12 @@ import ( "database/sql" "encoding/json" "fmt" - "github.com/astaxie/beego/pkg/orm/hints" "os" "strings" "time" + "github.com/astaxie/beego/pkg/orm/hints" + _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" @@ -430,7 +431,7 @@ type PtrPk struct { } type StrPk struct { - Id string `orm:"column(id);size(64);pk"` + Id string `orm:"column(id);size(64);pk"` Value string } diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 5d81c764..d7dc3915 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -58,12 +58,13 @@ import ( "database/sql" "errors" "fmt" - "github.com/astaxie/beego/pkg/common" - "github.com/astaxie/beego/pkg/orm/hints" "os" "reflect" "time" + "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/orm/hints" + "github.com/astaxie/beego/pkg/logs" ) diff --git a/pkg/orm/orm_queryset.go b/pkg/orm/orm_queryset.go index 734fc738..3a50fcae 100644 --- a/pkg/orm/orm_queryset.go +++ b/pkg/orm/orm_queryset.go @@ -17,6 +17,7 @@ package orm import ( "context" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" ) diff --git a/pkg/orm/orm_raw.go b/pkg/orm/orm_raw.go index 687f7099..92410eb2 100644 --- a/pkg/orm/orm_raw.go +++ b/pkg/orm/orm_raw.go @@ -17,9 +17,10 @@ package orm import ( "database/sql" "fmt" - "github.com/pkg/errors" "reflect" "time" + + "github.com/pkg/errors" ) // raw sql string prepared statement diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index 7cbce13d..cbe5c9a1 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -21,7 +21,6 @@ import ( "context" "database/sql" "fmt" - "github.com/astaxie/beego/pkg/orm/hints" "io/ioutil" "math" "os" @@ -32,6 +31,8 @@ import ( "testing" "time" + "github.com/astaxie/beego/pkg/orm/hints" + "github.com/stretchr/testify/assert" ) @@ -2565,7 +2566,7 @@ func TestStrPkInsert(t *testing.T) { Id: pk, Value: value2, } - + _, err = dORM.InsertOrUpdate(strPkForUpsert, `id`) if err != nil { fmt.Println(err) diff --git a/pkg/orm/types.go b/pkg/orm/types.go index d2b58604..06ba12f2 100644 --- a/pkg/orm/types.go +++ b/pkg/orm/types.go @@ -17,9 +17,10 @@ package orm import ( "context" "database/sql" - "github.com/astaxie/beego/pkg/common" "reflect" "time" + + "github.com/astaxie/beego/pkg/common" ) // TableNaming is usually used by model @@ -579,5 +580,5 @@ type dbBaser interface { collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) setval(dbQuerier, *modelInfo, []string) error - GenerateSpecifyIndex(tableName string,useIndex int ,indexes []string) string + GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string } diff --git a/pkg/parser.go b/pkg/parser.go index d7ab45f0..bee45d7b 100644 --- a/pkg/parser.go +++ b/pkg/parser.go @@ -19,7 +19,6 @@ import ( "errors" "fmt" "go/ast" - "golang.org/x/tools/go/packages" "io/ioutil" "os" "path/filepath" @@ -29,6 +28,8 @@ import ( "strings" "unicode" + "golang.org/x/tools/go/packages" + "github.com/astaxie/beego/pkg/context/param" "github.com/astaxie/beego/pkg/logs" "github.com/astaxie/beego/pkg/utils" diff --git a/pkg/plugins/auth/basic.go b/pkg/plugins/auth/basic.go index aa548f1a..d84b8df2 100644 --- a/pkg/plugins/auth/basic.go +++ b/pkg/plugins/auth/basic.go @@ -40,7 +40,7 @@ import ( "net/http" "strings" - "github.com/astaxie/beego/pkg" + beego "github.com/astaxie/beego/pkg" "github.com/astaxie/beego/pkg/context" ) diff --git a/pkg/plugins/authz/authz.go b/pkg/plugins/authz/authz.go index a375c593..47a20c8a 100644 --- a/pkg/plugins/authz/authz.go +++ b/pkg/plugins/authz/authz.go @@ -40,10 +40,11 @@ package authz import ( - "github.com/astaxie/beego/pkg" + "net/http" + + beego "github.com/astaxie/beego/pkg" "github.com/astaxie/beego/pkg/context" "github.com/casbin/casbin" - "net/http" ) // NewAuthorizer returns the authorizer. diff --git a/pkg/plugins/authz/authz_test.go b/pkg/plugins/authz/authz_test.go index 53e2652a..6cc081f3 100644 --- a/pkg/plugins/authz/authz_test.go +++ b/pkg/plugins/authz/authz_test.go @@ -15,13 +15,14 @@ package authz import ( - "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/context" - "github.com/astaxie/beego/pkg/plugins/auth" - "github.com/casbin/casbin" "net/http" "net/http/httptest" "testing" + + beego "github.com/astaxie/beego/pkg" + "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/plugins/auth" + "github.com/casbin/casbin" ) func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { diff --git a/pkg/plugins/cors/cors.go b/pkg/plugins/cors/cors.go index a4fb3b39..18bf2bec 100644 --- a/pkg/plugins/cors/cors.go +++ b/pkg/plugins/cors/cors.go @@ -42,7 +42,7 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg" + beego "github.com/astaxie/beego/pkg" "github.com/astaxie/beego/pkg/context" ) diff --git a/pkg/plugins/cors/cors_test.go b/pkg/plugins/cors/cors_test.go index 9757a32b..664d35a7 100644 --- a/pkg/plugins/cors/cors_test.go +++ b/pkg/plugins/cors/cors_test.go @@ -21,7 +21,7 @@ import ( "testing" "time" - "github.com/astaxie/beego/pkg" + beego "github.com/astaxie/beego/pkg" "github.com/astaxie/beego/pkg/context" ) diff --git a/pkg/session/redis/sess_redis.go b/pkg/session/redis/sess_redis.go index 6e1fbae6..b68ee012 100644 --- a/pkg/session/redis/sess_redis.go +++ b/pkg/session/redis/sess_redis.go @@ -224,7 +224,7 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Do(c.Context(), "SET", sid, "", "EX", rp.maxlifetime) } else { c.Rename(oldsid, sid) - c.Expire(sid, time.Duration(rp.maxlifetime) * time.Second) + c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) } return rp.SessionRead(sid) } diff --git a/pkg/session/redis_cluster/redis_cluster.go b/pkg/session/redis_cluster/redis_cluster.go index d6f051c1..dcdfae85 100644 --- a/pkg/session/redis_cluster/redis_cluster.go +++ b/pkg/session/redis_cluster/redis_cluster.go @@ -33,13 +33,14 @@ package redis_cluster import ( - "github.com/astaxie/beego/pkg/session" - rediss "github.com/go-redis/redis/v7" "net/http" "strconv" "strings" "sync" "time" + + "github.com/astaxie/beego/pkg/session" + rediss "github.com/go-redis/redis/v7" ) var redispder = &Provider{} diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel.go b/pkg/session/redis_sentinel/sess_redis_sentinel.go index 67790096..6721539a 100644 --- a/pkg/session/redis_sentinel/sess_redis_sentinel.go +++ b/pkg/session/redis_sentinel/sess_redis_sentinel.go @@ -33,13 +33,14 @@ package redis_sentinel import ( - "github.com/astaxie/beego/pkg/session" - "github.com/go-redis/redis/v7" "net/http" "strconv" "strings" "sync" "time" + + "github.com/astaxie/beego/pkg/session" + "github.com/go-redis/redis/v7" ) var redispder = &Provider{} diff --git a/pkg/staticfile.go b/pkg/staticfile.go index 27e83395..f8b17fc5 100644 --- a/pkg/staticfile.go +++ b/pkg/staticfile.go @@ -28,7 +28,7 @@ import ( "github.com/astaxie/beego/pkg/context" "github.com/astaxie/beego/pkg/logs" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" ) var errNotStaticRequest = errors.New("request not a static file request") diff --git a/pkg/template_test.go b/pkg/template_test.go index 6e4a27fc..134c2cb2 100644 --- a/pkg/template_test.go +++ b/pkg/template_test.go @@ -21,7 +21,7 @@ import ( "path/filepath" "testing" - "github.com/elazarl/go-bindata-assetfs" + assetfs "github.com/elazarl/go-bindata-assetfs" "github.com/stretchr/testify/assert" "github.com/astaxie/beego/test" diff --git a/pkg/utils/captcha/captcha.go b/pkg/utils/captcha/captcha.go index f0c37058..62fc26cf 100644 --- a/pkg/utils/captcha/captcha.go +++ b/pkg/utils/captcha/captcha.go @@ -66,7 +66,7 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg" + beego "github.com/astaxie/beego/pkg" "github.com/astaxie/beego/pkg/cache" "github.com/astaxie/beego/pkg/context" "github.com/astaxie/beego/pkg/logs" diff --git a/pkg/validation/validators.go b/pkg/validation/validators.go index 87c83ccd..534a371e 100644 --- a/pkg/validation/validators.go +++ b/pkg/validation/validators.go @@ -16,13 +16,14 @@ package validation import ( "fmt" - "github.com/astaxie/beego/pkg/logs" "reflect" "regexp" "strings" "sync" "time" "unicode/utf8" + + "github.com/astaxie/beego/pkg/logs" ) // CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty diff --git a/test/bindata.go b/test/bindata.go index 9fda5075..196ea95c 100644 --- a/test/bindata.go +++ b/test/bindata.go @@ -11,13 +11,14 @@ import ( "bytes" "compress/gzip" "fmt" - "github.com/elazarl/go-bindata-assetfs" "io" "io/ioutil" "os" "path/filepath" "strings" "time" + + assetfs "github.com/elazarl/go-bindata-assetfs" ) func bindataRead(data []byte, name string) ([]byte, error) { From 6c002a3124dce992ee9bc31c9c9389631c1d8728 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Tue, 18 Aug 2020 21:30:11 +0100 Subject: [PATCH 101/207] Update WriteMsg signatures for custom log formatting update --- pkg/logs/accesslog.go | 7 +- pkg/logs/alils/alils.go | 6 +- pkg/logs/conn.go | 7 +- pkg/logs/console.go | 9 +- pkg/logs/file.go | 24 +++--- pkg/logs/jianliao.go | 7 +- pkg/logs/log.go | 181 ++++++++++++++++++++++++++++++++-------- pkg/logs/logger.go | 6 +- pkg/logs/multifile.go | 9 +- pkg/logs/slack.go | 7 +- pkg/logs/smtp.go | 7 +- 11 files changed, 187 insertions(+), 83 deletions(-) diff --git a/pkg/logs/accesslog.go b/pkg/logs/accesslog.go index e380c54a..1be711d8 100644 --- a/pkg/logs/accesslog.go +++ b/pkg/logs/accesslog.go @@ -79,5 +79,10 @@ func AccessLog(r *AccessLogRecord, format string) { msg = string(jsonData) } } - beeLogger.writeMsg(levelLoggerImpl, strings.TrimSpace(msg)) + lm := &LogMsg{ + Msg: strings.TrimSpace(msg), + When: time.Now(), + Level: levelLoggerImpl, + } + beeLogger.writeMsg(lm) } diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go index fd1a4e28..1bd6b653 100644 --- a/pkg/logs/alils/alils.go +++ b/pkg/logs/alils/alils.go @@ -4,7 +4,6 @@ import ( "encoding/json" "strings" "sync" - "time" "github.com/astaxie/beego/pkg/logs" "github.com/gogo/protobuf/proto" @@ -103,9 +102,8 @@ func (c *aliLSWriter) Init(jsonConfig string) (err error) { // WriteMsg writes a message in connection. // If connection is down, try to re-connect. -func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) { - - if level > c.Level { +func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { + if lm.Level > c.Level { return nil } diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go index 8b55bde7..e0560fd9 100644 --- a/pkg/logs/conn.go +++ b/pkg/logs/conn.go @@ -18,7 +18,6 @@ import ( "encoding/json" "io" "net" - "time" ) // connWriter implements LoggerInterface. @@ -48,8 +47,8 @@ func (c *connWriter) Init(jsonConfig string) error { // WriteMsg writes message in connection. // If connection is down, try to re-connect. -func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > c.Level { +func (c *connWriter) WriteMsg(lm *LogMsg) error { + if lm.Level > c.Level { return nil } if c.needToConnectOnMsg() { @@ -63,7 +62,7 @@ func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { defer c.innerWriter.Close() } - _, err := c.lg.writeln(when, msg) + _, err := c.lg.writeln(lm) if err != nil { return err } diff --git a/pkg/logs/console.go b/pkg/logs/console.go index b2cc2907..024152aa 100644 --- a/pkg/logs/console.go +++ b/pkg/logs/console.go @@ -18,7 +18,6 @@ import ( "encoding/json" "os" "strings" - "time" "github.com/shiena/ansicolor" ) @@ -73,14 +72,14 @@ func (c *consoleWriter) Init(jsonConfig string) error { } // WriteMsg writes message in console. -func (c *consoleWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > c.Level { +func (c *consoleWriter) WriteMsg(lm *LogMsg) error { + if lm.Level > c.Level { return nil } if c.Colorful { - msg = strings.Replace(msg, levelPrefix[level], colors[level](levelPrefix[level]), 1) + lm.Msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) } - c.lg.writeln(when, msg) + c.lg.writeln(lm) return nil } diff --git a/pkg/logs/file.go b/pkg/logs/file.go index fbe10b55..23ea4b09 100644 --- a/pkg/logs/file.go +++ b/pkg/logs/file.go @@ -144,28 +144,28 @@ func (w *fileLogWriter) needRotateHourly(size int, hour int) bool { } // WriteMsg writes logger message into file. -func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > w.Level { +func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { + if lm.Level > w.Level { return nil } - hd, d, h := formatTimeHeader(when) - msg = string(hd) + msg + "\n" + hd, d, h := formatTimeHeader(lm.When) + lm.Msg = string(hd) + lm.Msg + "\n" if w.Rotate { w.RLock() - if w.needRotateHourly(len(msg), h) { + if w.needRotateHourly(len(lm.Msg), h) { w.RUnlock() w.Lock() - if w.needRotateHourly(len(msg), h) { - if err := w.doRotate(when); err != nil { + if w.needRotateHourly(len(lm.Msg), h) { + if err := w.doRotate(lm.When); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } } w.Unlock() - } else if w.needRotateDaily(len(msg), d) { + } else if w.needRotateDaily(len(lm.Msg), d) { w.RUnlock() w.Lock() - if w.needRotateDaily(len(msg), d) { - if err := w.doRotate(when); err != nil { + if w.needRotateDaily(len(lm.Msg), d) { + if err := w.doRotate(lm.When); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } } @@ -176,10 +176,10 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { } w.Lock() - _, err := w.fileWriter.Write([]byte(msg)) + _, err := w.fileWriter.Write([]byte(lm.Msg)) if err == nil { w.maxLinesCurLines++ - w.maxSizeCurSize += len(msg) + w.maxSizeCurSize += len(lm.Msg) } w.Unlock() return err diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go index 71e7e2bf..0e7cfab4 100644 --- a/pkg/logs/jianliao.go +++ b/pkg/logs/jianliao.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "net/url" - "time" ) // JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook @@ -30,12 +29,12 @@ func (s *JLWriter) Init(jsonconfig string) error { // WriteMsg writes message in smtp writer. // Sends an email with subject and only this message. -func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > s.Level { +func (s *JLWriter) WriteMsg(lm *LogMsg) error { + if lm.Level > s.Level { return nil } - text := fmt.Sprintf("%s %s", when.Format("2006-01-02 15:04:05"), msg) + text := fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), lm.Msg) form := url.Values{} form.Add("authorName", s.AuthorName) diff --git a/pkg/logs/log.go b/pkg/logs/log.go index 4824918b..3a117327 100644 --- a/pkg/logs/log.go +++ b/pkg/logs/log.go @@ -86,7 +86,7 @@ type newLoggerFunc func() Logger // Logger defines the behavior of a log provider. type Logger interface { Init(config string) error - WriteMsg(when time.Time, msg string, level int) error + WriteMsg(lm *LogMsg) error Destroy() Flush() } @@ -118,7 +118,7 @@ type BeeLogger struct { asynchronous bool prefix string msgChanLen int64 - msgChan chan *logMsg + msgChan chan *LogMsg signalChan chan string wg sync.WaitGroup outputs []*nameLogger @@ -131,10 +131,12 @@ type nameLogger struct { name string } -type logMsg struct { - level int - msg string - when time.Time +type LogMsg struct { + Level int + Msg string + When time.Time + FilePath string + LineNumber int } var logMsgPool *sync.Pool @@ -166,10 +168,10 @@ func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { if len(msgLen) > 0 && msgLen[0] > 0 { bl.msgChanLen = msgLen[0] } - bl.msgChan = make(chan *logMsg, bl.msgChanLen) + bl.msgChan = make(chan *LogMsg, bl.msgChanLen) logMsgPool = &sync.Pool{ New: func() interface{} { - return &logMsg{} + return &LogMsg{} }, } bl.wg.Add(1) @@ -233,9 +235,9 @@ func (bl *BeeLogger) DelLogger(adapterName string) error { return nil } -func (bl *BeeLogger) writeToLoggers(when time.Time, msg string, level int) { +func (bl *BeeLogger) writeToLoggers(lm *LogMsg) { for _, l := range bl.outputs { - err := l.WriteMsg(when, msg, level) + err := l.WriteMsg(lm) if err != nil { fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err) } @@ -250,15 +252,20 @@ func (bl *BeeLogger) Write(p []byte) (n int, err error) { if p[len(p)-1] == '\n' { p = p[0 : len(p)-1] } + lm := &LogMsg{ + Msg: string(p), + Level: levelLoggerImpl, + } + // set levelLoggerImpl to ensure all log message will be write out - err = bl.writeMsg(levelLoggerImpl, string(p)) + err = bl.writeMsg(lm) if err == nil { return len(p), err } return 0, err } -func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error { +func (bl *BeeLogger) writeMsg(lm *LogMsg, v ...interface{}) error { if !bl.init { bl.lock.Lock() bl.setLogger(AdapterConsole) @@ -266,12 +273,11 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error } if len(v) > 0 { - msg = fmt.Sprintf(msg, v...) + lm.Msg = fmt.Sprintf(lm.Msg, v...) } - msg = bl.prefix + " " + msg + lm.Msg = bl.prefix + " " + lm.Msg - when := time.Now() if bl.enableFuncCallDepth { _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) if !ok { @@ -279,29 +285,29 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error line = 0 } _, filename := path.Split(file) - msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + msg + lm.Msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + lm.Msg } //set level info in front of filename info - if logLevel == levelLoggerImpl { + if lm.Level == levelLoggerImpl { // set to emergency to ensure all log will be print out correctly - logLevel = LevelEmergency + lm.Level = LevelEmergency } else { - msg = levelPrefix[logLevel] + " " + msg + lm.Msg = levelPrefix[lm.Level] + " " + lm.Msg } if bl.asynchronous { - lm := logMsgPool.Get().(*logMsg) - lm.level = logLevel - lm.msg = msg - lm.when = when + logM := logMsgPool.Get().(*LogMsg) + logM.Level = lm.Level + logM.Msg = lm.Msg + logM.When = lm.When if bl.outputs != nil { bl.msgChan <- lm } else { logMsgPool.Put(lm) } } else { - bl.writeToLoggers(when, msg, logLevel) + bl.writeToLoggers(lm) } return nil } @@ -345,7 +351,7 @@ func (bl *BeeLogger) startLogger() { for { select { case bm := <-bl.msgChan: - bl.writeToLoggers(bm.when, bm.msg, bm.level) + bl.writeToLoggers(bm) logMsgPool.Put(bm) case sg := <-bl.signalChan: // Now should only send "flush" or "close" to bl.signalChan @@ -370,7 +376,17 @@ func (bl *BeeLogger) Emergency(format string, v ...interface{}) { if LevelEmergency > bl.level { return } - bl.writeMsg(LevelEmergency, format, v...) + + lm := &LogMsg{ + Level: LevelEmergency, + Msg: format, + When: time.Now(), + } + if len(v) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, v...) + } + + bl.writeMsg(lm) } // Alert Log ALERT level message. @@ -378,7 +394,17 @@ func (bl *BeeLogger) Alert(format string, v ...interface{}) { if LevelAlert > bl.level { return } - bl.writeMsg(LevelAlert, format, v...) + + lm := &LogMsg{ + Level: LevelAlert, + Msg: format, + When: time.Now(), + } + if len(v) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, v...) + } + + bl.writeMsg(lm) } // Critical Log CRITICAL level message. @@ -386,7 +412,16 @@ func (bl *BeeLogger) Critical(format string, v ...interface{}) { if LevelCritical > bl.level { return } - bl.writeMsg(LevelCritical, format, v...) + lm := &LogMsg{ + Level: LevelCritical, + Msg: format, + When: time.Now(), + } + if len(v) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, v...) + } + + bl.writeMsg(lm) } // Error Log ERROR level message. @@ -394,7 +429,16 @@ func (bl *BeeLogger) Error(format string, v ...interface{}) { if LevelError > bl.level { return } - bl.writeMsg(LevelError, format, v...) + lm := &LogMsg{ + Level: LevelError, + Msg: format, + When: time.Now(), + } + if len(v) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, v...) + } + + bl.writeMsg(lm) } // Warning Log WARNING level message. @@ -402,7 +446,16 @@ func (bl *BeeLogger) Warning(format string, v ...interface{}) { if LevelWarn > bl.level { return } - bl.writeMsg(LevelWarn, format, v...) + lm := &LogMsg{ + Level: LevelWarn, + Msg: format, + When: time.Now(), + } + if len(v) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, v...) + } + + bl.writeMsg(lm) } // Notice Log NOTICE level message. @@ -410,7 +463,16 @@ func (bl *BeeLogger) Notice(format string, v ...interface{}) { if LevelNotice > bl.level { return } - bl.writeMsg(LevelNotice, format, v...) + lm := &LogMsg{ + Level: LevelNotice, + Msg: format, + When: time.Now(), + } + if len(v) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, v...) + } + + bl.writeMsg(lm) } // Informational Log INFORMATIONAL level message. @@ -418,7 +480,16 @@ func (bl *BeeLogger) Informational(format string, v ...interface{}) { if LevelInfo > bl.level { return } - bl.writeMsg(LevelInfo, format, v...) + lm := &LogMsg{ + Level: LevelInfo, + Msg: format, + When: time.Now(), + } + if len(v) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, v...) + } + + bl.writeMsg(lm) } // Debug Log DEBUG level message. @@ -426,7 +497,16 @@ func (bl *BeeLogger) Debug(format string, v ...interface{}) { if LevelDebug > bl.level { return } - bl.writeMsg(LevelDebug, format, v...) + lm := &LogMsg{ + Level: LevelDebug, + Msg: format, + When: time.Now(), + } + if len(v) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, v...) + } + + bl.writeMsg(lm) } // Warn Log WARN level message. @@ -435,7 +515,16 @@ func (bl *BeeLogger) Warn(format string, v ...interface{}) { if LevelWarn > bl.level { return } - bl.writeMsg(LevelWarn, format, v...) + lm := &LogMsg{ + Level: LevelWarn, + Msg: format, + When: time.Now(), + } + if len(v) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, v...) + } + + bl.writeMsg(lm) } // Info Log INFO level message. @@ -444,7 +533,16 @@ func (bl *BeeLogger) Info(format string, v ...interface{}) { if LevelInfo > bl.level { return } - bl.writeMsg(LevelInfo, format, v...) + lm := &LogMsg{ + Level: LevelInfo, + Msg: format, + When: time.Now(), + } + if len(v) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, v...) + } + + bl.writeMsg(lm) } // Trace Log TRACE level message. @@ -453,7 +551,16 @@ func (bl *BeeLogger) Trace(format string, v ...interface{}) { if LevelDebug > bl.level { return } - bl.writeMsg(LevelDebug, format, v...) + lm := &LogMsg{ + Level: LevelDebug, + Msg: format, + When: time.Now(), + } + if len(v) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, v...) + } + + bl.writeMsg(lm) } // Flush flush all chan data. @@ -497,7 +604,7 @@ func (bl *BeeLogger) flush() { for { if len(bl.msgChan) > 0 { bm := <-bl.msgChan - bl.writeToLoggers(bm.when, bm.msg, bm.level) + bl.writeToLoggers(bm) logMsgPool.Put(bm) continue } diff --git a/pkg/logs/logger.go b/pkg/logs/logger.go index a28bff6f..721c8dc1 100644 --- a/pkg/logs/logger.go +++ b/pkg/logs/logger.go @@ -30,10 +30,10 @@ func newLogWriter(wr io.Writer) *logWriter { return &logWriter{writer: wr} } -func (lg *logWriter) writeln(when time.Time, msg string) (int, error) { +func (lg *logWriter) writeln(lm *LogMsg) (int, error) { lg.Lock() - h, _, _ := formatTimeHeader(when) - n, err := lg.writer.Write(append(append(h, msg...), '\n')) + h, _, _ := formatTimeHeader(lm.When) + n, err := lg.writer.Write(append(append(h, lm.Msg...), '\n')) lg.Unlock() return n, err } diff --git a/pkg/logs/multifile.go b/pkg/logs/multifile.go index 90168274..1cd9e9f8 100644 --- a/pkg/logs/multifile.go +++ b/pkg/logs/multifile.go @@ -16,7 +16,6 @@ package logs import ( "encoding/json" - "time" ) // A filesLogWriter manages several fileLogWriter @@ -87,14 +86,14 @@ func (f *multiFileLogWriter) Destroy() { } } -func (f *multiFileLogWriter) WriteMsg(when time.Time, msg string, level int) error { +func (f *multiFileLogWriter) WriteMsg(lm *LogMsg) error { if f.fullLogWriter != nil { - f.fullLogWriter.WriteMsg(when, msg, level) + f.fullLogWriter.WriteMsg(lm) } for i := 0; i < len(f.writers)-1; i++ { if f.writers[i] != nil { - if level == f.writers[i].Level { - f.writers[i].WriteMsg(when, msg, level) + if lm.Level == f.writers[i].Level { + f.writers[i].WriteMsg(lm) } } } diff --git a/pkg/logs/slack.go b/pkg/logs/slack.go index e78eeab6..dad4f4ea 100644 --- a/pkg/logs/slack.go +++ b/pkg/logs/slack.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "net/url" - "time" ) // SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook @@ -26,12 +25,12 @@ func (s *SLACKWriter) Init(jsonconfig string) error { // WriteMsg write message in smtp writer. // Sends an email with subject and only this message. -func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > s.Level { +func (s *SLACKWriter) WriteMsg(lm *LogMsg) error { + if lm.Level > s.Level { return nil } - text := fmt.Sprintf("{\"text\": \"%s %s\"}", when.Format("2006-01-02 15:04:05"), msg) + text := fmt.Sprintf("{\"text\": \"%s %s\"}", lm.When.Format("2006-01-02 15:04:05"), lm.Msg) form := url.Values{} form.Add("payload", text) diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go index 720c2d25..0d2b3c29 100644 --- a/pkg/logs/smtp.go +++ b/pkg/logs/smtp.go @@ -21,7 +21,6 @@ import ( "net" "net/smtp" "strings" - "time" ) // SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server. @@ -117,8 +116,8 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd // WriteMsg writes message in smtp writer. // Sends an email with subject and only this message. -func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > s.Level { +func (s *SMTPWriter) WriteMsg(lm *LogMsg) error { + if lm.Level > s.Level { return nil } @@ -131,7 +130,7 @@ func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error { // and send the email all in one step. contentType := "Content-Type: text/plain" + "; charset=UTF-8" mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress + - ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", when.Format("2006-01-02 15:04:05")) + msg) + ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", lm.When.Format("2006-01-02 15:04:05")) + lm.Msg) return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg) } From fe56de06b55934e39274f04515df27666405c0ac Mon Sep 17 00:00:00 2001 From: IamCathal Date: Tue, 18 Aug 2020 21:30:39 +0100 Subject: [PATCH 102/207] Add enableFullFilePath field to BeeLogger --- pkg/logs/log.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pkg/logs/log.go b/pkg/logs/log.go index 3a117327..3965eabb 100644 --- a/pkg/logs/log.go +++ b/pkg/logs/log.go @@ -115,6 +115,7 @@ type BeeLogger struct { init bool enableFuncCallDepth bool loggerFuncCallDepth int + enableFullFilePath bool asynchronous bool prefix string msgChanLen int64 @@ -654,6 +655,12 @@ func GetLogger(prefixes ...string) *log.Logger { return l } +// EnableFullFilePath enables full file path logging. Disabled by default +// e.g "/home/Documents/GitHub/beego/mainapp/" instead of "mainapp" +func EnableFullFilePath(b bool) { + beeLogger.enableFullFilePath = b +} + // Reset will remove all the adapter func Reset() { beeLogger.Reset() From ac3a549187f63da7e4f5b39543f901174362680d Mon Sep 17 00:00:00 2001 From: IamCathal Date: Wed, 19 Aug 2020 14:21:29 +0100 Subject: [PATCH 103/207] Fix test with new parameters --- pkg/logs/file_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pkg/logs/file_test.go b/pkg/logs/file_test.go index 385eac43..7f2a3590 100644 --- a/pkg/logs/file_test.go +++ b/pkg/logs/file_test.go @@ -280,8 +280,13 @@ func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) fw.hourlyOpenDate = fw.hourlyOpenTime.Day() } + lm := &LogMsg{ + Msg: "Test message", + Level: LevelDebug, + When: time.Now(), + } - fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) + fw.WriteMsg(lm) for _, file := range []string{fn1, fn2} { _, err := os.Stat(file) From 77ddc3338f02111dde74a4e6bf118095b37e3b83 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Wed, 19 Aug 2020 15:07:46 +0100 Subject: [PATCH 104/207] Fix file path logging for enableFullFilePath --- pkg/logs/log.go | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/pkg/logs/log.go b/pkg/logs/log.go index 3965eabb..37421625 100644 --- a/pkg/logs/log.go +++ b/pkg/logs/log.go @@ -39,7 +39,6 @@ import ( "os" "path" "runtime" - "strconv" "strings" "sync" "time" @@ -279,14 +278,25 @@ func (bl *BeeLogger) writeMsg(lm *LogMsg, v ...interface{}) error { lm.Msg = bl.prefix + " " + lm.Msg + var ( + file string + line int + ok bool + ) + if bl.enableFuncCallDepth { - _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) + _, file, line, ok = runtime.Caller(bl.loggerFuncCallDepth) if !ok { file = "???" line = 0 } - _, filename := path.Split(file) - lm.Msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + lm.Msg + + if !bl.enableFullFilePath { + _, file = path.Split(file) + } + lm.FilePath = file + lm.LineNumber = line + lm.Msg = fmt.Sprintf("[%s:%d] %s", lm.FilePath, lm.LineNumber, lm.Msg) } //set level info in front of filename info From 2c16c7b917e5e699da5d9eb4699313a0cc38f012 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 19 Aug 2020 22:08:29 +0800 Subject: [PATCH 105/207] Add more methods to Configer --- pkg/config.go | 3 +- pkg/config/base_config_test.go | 71 ++++++++++++++ pkg/config/config.go | 135 +++++++++++++++++++++++++-- pkg/config/fake.go | 1 + pkg/config/ini.go | 1 + pkg/config/json/json.go | 1 + pkg/config/xml/xml.go | 1 + pkg/config/yaml/yaml.go | 1 + pkg/orm/filter_orm_decorator_test.go | 4 +- 9 files changed, 208 insertions(+), 10 deletions(-) create mode 100644 pkg/config/base_config_test.go diff --git a/pkg/config.go b/pkg/config.go index 0cfb7a4c..e8bde705 100644 --- a/pkg/config.go +++ b/pkg/config.go @@ -411,6 +411,7 @@ func LoadAppConfig(adapterName, configPath string) error { } type beegoAppConfig struct { + config.BaseConfiger innerConfig config.Configer } @@ -419,7 +420,7 @@ func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, err if err != nil { return nil, err } - return &beegoAppConfig{ac}, nil + return &beegoAppConfig{innerConfig: ac}, nil } func (b *beegoAppConfig) Set(key, val string) error { diff --git a/pkg/config/base_config_test.go b/pkg/config/base_config_test.go new file mode 100644 index 00000000..3d37bc91 --- /dev/null +++ b/pkg/config/base_config_test.go @@ -0,0 +1,71 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBaseConfiger_DefaultBool(t *testing.T) { + bc := newBaseConfier("true") + assert.True(t, bc.DefaultBool("key1", false)) + assert.True(t, bc.DefaultBool("key2", true)) +} + +func TestBaseConfiger_DefaultFloat(t *testing.T) { + bc := newBaseConfier("12.3") + assert.Equal(t, 12.3, bc.DefaultFloat("key1", 0.1)) + assert.Equal(t, 0.1, bc.DefaultFloat("key2", 0.1)) +} + +func TestBaseConfiger_DefaultInt(t *testing.T) { + bc := newBaseConfier("10") + assert.Equal(t, 10, bc.DefaultInt("key1", 8)) + assert.Equal(t, 8, bc.DefaultInt("key2", 8)) +} + +func TestBaseConfiger_DefaultInt64(t *testing.T) { + bc := newBaseConfier("64") + assert.Equal(t, int64(64), bc.DefaultInt64("key1", int64(8))) + assert.Equal(t, int64(8), bc.DefaultInt64("key2", int64(8))) +} + +func TestBaseConfiger_DefaultString(t *testing.T) { + bc := newBaseConfier("Hello") + assert.Equal(t, "Hello", bc.DefaultString("key1", "world")) + assert.Equal(t, "world", bc.DefaultString("key2", "world")) +} + +func TestBaseConfiger_DefaultStrings(t *testing.T) { + bc := newBaseConfier("Hello;world") + assert.Equal(t, []string{"Hello", "world"}, bc.DefaultStrings("key1", []string{"world"})) + assert.Equal(t, []string{"world"}, bc.DefaultStrings("key2", []string{"world"})) +} + +func newBaseConfier(str1 string) *BaseConfiger { + return &BaseConfiger{ + reader: func(key string) (string, error) { + if key == "key1" { + return str1, nil + } else { + return "", errors.New("mock error") + } + + }, + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index bfd79e85..b17f6208 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -15,7 +15,7 @@ // Package config is used to parse config. // Usage: // import "github.com/astaxie/beego/config" -//Examples. +// Examples. // // cnf, err := config.NewConfig("ini", "config.conf") // @@ -37,36 +37,157 @@ // cnf.DIY(key string) (interface{}, error) // cnf.GetSection(section string) (map[string]string, error) // cnf.SaveConfigFile(filename string) error -//More docs http://beego.me/docs/module/config.md +// More docs http://beego.me/docs/module/config.md package config import ( + "errors" "fmt" "os" "reflect" + "strconv" + "strings" "time" ) // Configer defines how to get and set value from configuration raw data. type Configer interface { - Set(key, val string) error //support section::key type in given key when using ini type. - String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - Strings(key string) []string //get string slice + // support section::key type in given key when using ini type. + Set(key, val string) error + + // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + String(key string) string + // get string slice + Strings(key string) []string Int(key string) (int, error) Int64(key string) (int64, error) Bool(key string) (bool, error) Float(key string) (float64, error) - DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - DefaultStrings(key string, defaultVal []string) []string //get string slice + // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + DefaultString(key string, defaultVal string) string + // get string slice + DefaultStrings(key string, defaultVal []string) []string DefaultInt(key string, defaultVal int) int DefaultInt64(key string, defaultVal int64) int64 DefaultBool(key string, defaultVal bool) bool DefaultFloat(key string, defaultVal float64) float64 DIY(key string) (interface{}, error) GetSection(section string) (map[string]string, error) + + Unmarshaler(obj interface{}) error + Sub(key string) (Configer, error) + OnChange(fn func(cfg Configer)) + // GetByPrefix(prefix string) ([]byte, error) + // GetSerializer() Serializer SaveConfigFile(filename string) error } +type BaseConfiger struct { + // The reader should support key like "a.b.c" + reader func(key string) (string, error) +} + +func (c *BaseConfiger) Int(key string) (int, error) { + res, err := c.reader(key) + if err != nil { + return 0, err + } + return strconv.Atoi(res) +} + +func (c *BaseConfiger) Int64(key string) (int64, error) { + res, err := c.reader(key) + if err != nil { + return 0, err + } + return strconv.ParseInt(res, 10, 64) +} + +func (c *BaseConfiger) Bool(key string) (bool, error) { + res, err := c.reader(key) + if err != nil { + return false, err + } + return strconv.ParseBool(res) +} + +func (c *BaseConfiger) Float(key string) (float64, error) { + res, err := c.reader(key) + if err != nil { + return 0, err + } + return strconv.ParseFloat(res, 64) +} + +func (c *BaseConfiger) DefaultString(key string, defaultVal string) string { + if res := c.String(key); res != "" { + return res + } + return defaultVal +} + +func (c *BaseConfiger) DefaultStrings(key string, defaultVal []string) []string { + if res := c.Strings(key); len(res) > 0 { + return res + } + return defaultVal +} + +func (c *BaseConfiger) DefaultInt(key string, defaultVal int) int { + if res, err := c.Int(key); err == nil { + return res + } + return defaultVal +} + +func (c *BaseConfiger) DefaultInt64(key string, defaultVal int64) int64 { + if res, err := c.Int64(key); err == nil { + return res + } + return defaultVal +} + +func (c *BaseConfiger) DefaultBool(key string, defaultVal bool) bool { + if res, err := c.Bool(key); err == nil { + return res + } + return defaultVal +} +func (c *BaseConfiger) DefaultFloat(key string, defaultVal float64) float64 { + if res, err := c.Float(key); err == nil { + return res + } + return defaultVal +} + +func (c *BaseConfiger) String(key string) string { + res, _ := c.reader(key) + return res +} + +func (c *BaseConfiger) Strings(key string) []string { + res, err := c.reader(key) + if err != nil || res == "" { + return nil + } + return strings.Split(res, ";") +} + +// TODO remove this before release v2.0.0 +func (c *BaseConfiger) Unmarshaler(obj interface{}) error { + return errors.New("unsupported operation") +} + +// TODO remove this before release v2.0.0 +func (c *BaseConfiger) Sub(key string) (Configer, error) { + return nil, errors.New("unsupported operation") +} + +// TODO remove this before release v2.0.0 +func (c *BaseConfiger) OnChange(fn func(cfg Configer)) { + // do nothing +} + // Config is the adapter interface for parsing config file to get raw data to Configer. type Config interface { Parse(key string) (Configer, error) diff --git a/pkg/config/fake.go b/pkg/config/fake.go index d21ab820..ddbc99b8 100644 --- a/pkg/config/fake.go +++ b/pkg/config/fake.go @@ -21,6 +21,7 @@ import ( ) type fakeConfigContainer struct { + BaseConfiger data map[string]string } diff --git a/pkg/config/ini.go b/pkg/config/ini.go index f5921308..0bef67d4 100644 --- a/pkg/config/ini.go +++ b/pkg/config/ini.go @@ -225,6 +225,7 @@ func (ini *IniConfig) ParseData(data []byte) (Configer, error) { // IniConfigContainer is a config which represents the ini configuration. // When set and get value, support key as section:name type. type IniConfigContainer struct { + BaseConfiger data map[string]map[string]string // section=> key:val sectionComment map[string]string // section : comment keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment. diff --git a/pkg/config/json/json.go b/pkg/config/json/json.go index ede3cce5..876077e1 100644 --- a/pkg/config/json/json.go +++ b/pkg/config/json/json.go @@ -69,6 +69,7 @@ func (js *JSONConfig) ParseData(data []byte) (config.Configer, error) { // JSONConfigContainer is a config which represents the json configuration. // Only when get value, support key as section:name type. type JSONConfigContainer struct { + config.BaseConfiger data map[string]interface{} sync.RWMutex } diff --git a/pkg/config/xml/xml.go b/pkg/config/xml/xml.go index d8c018e6..9b5ec791 100644 --- a/pkg/config/xml/xml.go +++ b/pkg/config/xml/xml.go @@ -74,6 +74,7 @@ func (xc *Config) ParseData(data []byte) (config.Configer, error) { // ConfigContainer is a Config which represents the xml configuration. type ConfigContainer struct { + config.BaseConfiger data map[string]interface{} sync.Mutex } diff --git a/pkg/config/yaml/yaml.go b/pkg/config/yaml/yaml.go index 63a30208..5c77e88f 100644 --- a/pkg/config/yaml/yaml.go +++ b/pkg/config/yaml/yaml.go @@ -118,6 +118,7 @@ func parseYML(buf []byte) (cnf map[string]interface{}, err error) { // ConfigContainer is a config which represents the yaml configuration. type ConfigContainer struct { + config.BaseConfiger data map[string]interface{} sync.RWMutex } diff --git a/pkg/orm/filter_orm_decorator_test.go b/pkg/orm/filter_orm_decorator_test.go index d52ce27b..47f20854 100644 --- a/pkg/orm/filter_orm_decorator_test.go +++ b/pkg/orm/filter_orm_decorator_test.go @@ -115,7 +115,7 @@ func TestFilterOrmDecorator_Delete(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { - return func(ctx context.Context, inv *Invocation) []interface{} { + return func(ctx context.Context, inv *Invocation) []interface{} { assert.Equal(t, "DeleteWithCtx", inv.Method) assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) @@ -311,7 +311,7 @@ func TestFilterOrmDecorator_Raw(t *testing.T) { assert.Nil(t, res) } -func TestFilterOrmDecorator_ReadForUpdate(t *testing.T) { +func TestFilterOrmDecorator_ReadForUpdate(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { From ff5ac3adf409e746376e5b14e10a4e105869bda7 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Wed, 19 Aug 2020 16:13:42 +0100 Subject: [PATCH 106/207] Update signature of WriteMsg in es.go --- pkg/logs/alils/alils.go | 8 ++++---- pkg/logs/es/es.go | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go index 1bd6b653..6c1464f2 100644 --- a/pkg/logs/alils/alils.go +++ b/pkg/logs/alils/alils.go @@ -113,7 +113,7 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { if c.withMap { // Topic,LogGroup - strs := strings.SplitN(msg, Delimiter, 2) + strs := strings.SplitN(lm.Msg, Delimiter, 2) if len(strs) == 2 { pos := strings.LastIndex(strs[0], " ") topic = strs[0][pos+1 : len(strs[0])] @@ -123,11 +123,11 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { // send to empty Topic if lg == nil { - content = msg + content = lm.Msg lg = c.group[0] } } else { - content = msg + content = lm.Msg lg = c.group[0] } @@ -137,7 +137,7 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { } l := &Log{ - Time: proto.Uint32(uint32(when.Unix())), + Time: proto.Uint32(uint32(lm.When.Unix())), Contents: []*LogContent{ c1, }, diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go index 7542b577..5c91b2ed 100644 --- a/pkg/logs/es/es.go +++ b/pkg/logs/es/es.go @@ -60,14 +60,14 @@ func (el *esLogger) Init(jsonconfig string) error { } // WriteMsg writes the msg and level into es -func (el *esLogger) WriteMsg(when time.Time, msg string, level int) error { - if level > el.Level { +func (el *esLogger) WriteMsg(lm *logs.LogMsg) error { + if lm.Level > el.Level { return nil } idx := LogDocument{ - Timestamp: when.Format(time.RFC3339), - Msg: msg, + Timestamp: lm.When.Format(time.RFC3339), + Msg: lm.Msg, } body, err := json.Marshal(idx) @@ -75,7 +75,7 @@ func (el *esLogger) WriteMsg(when time.Time, msg string, level int) error { return err } req := esapi.IndexRequest{ - Index: fmt.Sprintf("%04d.%02d.%02d", when.Year(), when.Month(), when.Day()), + Index: fmt.Sprintf("%04d.%02d.%02d", lm.When.Year(), lm.When.Month(), lm.When.Day()), DocumentType: "logs", Body: strings.NewReader(string(body)), } From 9fe353dd0bfd5fdbcf3c30a22d6cc8e5df60396a Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Wed, 19 Aug 2020 15:40:52 +0800 Subject: [PATCH 107/207] Fix issue 3886 --- pkg/orm/orm_raw.go | 36 +++++++++++++++++++++++++----------- pkg/orm/orm_test.go | 18 ++++++++++++++++++ 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/pkg/orm/orm_raw.go b/pkg/orm/orm_raw.go index 92410eb2..c2539147 100644 --- a/pkg/orm/orm_raw.go +++ b/pkg/orm/orm_raw.go @@ -383,19 +383,33 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { } } } else { - for i := 0; i < ind.NumField(); i++ { - f := ind.Field(i) - fe := ind.Type().Field(i) - _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) - var col string - if col = tags["column"]; col == "" { - col = nameStrategyMap[nameStrategy](fe.Name) - } - if v, ok := columnsMp[col]; ok { - value := reflect.ValueOf(v).Elem().Interface() - o.setFieldValue(f, value) + // define recursive function + var recursiveSetField func(rv reflect.Value) + recursiveSetField = func(rv reflect.Value) { + for i := 0; i < rv.NumField(); i++ { + f := rv.Field(i) + fe := rv.Type().Field(i) + + // check if the field is a Struct + // recursive the Struct type + if fe.Type.Kind() == reflect.Struct { + recursiveSetField(f) + } + + _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + var col string + if col = tags["column"]; col == "" { + col = nameStrategyMap[nameStrategy](fe.Name) + } + if v, ok := columnsMp[col]; ok { + value := reflect.ValueOf(v).Elem().Interface() + o.setFieldValue(f, value) + } } } + + // init call the recursive function + recursiveSetField(ind) } } else { diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index cbe5c9a1..40314ab4 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -1742,6 +1742,24 @@ func TestRawQueryRow(t *testing.T) { throwFail(t, AssertIs(*status, 3)) throwFail(t, AssertIs(pid, nil)) + type Embeded struct { + Email string + } + type queryRowNoModelTest struct { + Id int + EmbedField Embeded + } + + cols = []string{ + "id", "email", + } + var row queryRowNoModelTest + query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q) + err = dORM.Raw(query, 4).QueryRow(&row) + throwFail(t, err) + throwFail(t, AssertIs(row.Id, 4)) + throwFail(t, AssertIs(row.EmbedField.Email, "nobody@gmail.com")) + // test for sql.Null* fields nData := &DataNull{ NullString: sql.NullString{String: "test sql.null", Valid: true}, From 6bdedff45714b42f5bca6e4959b4771ad031fa9b Mon Sep 17 00:00:00 2001 From: IamCathal Date: Thu, 20 Aug 2020 19:00:35 +0100 Subject: [PATCH 108/207] LogFormatter Implementation --- pkg/logs/conn.go | 7 ++++++- pkg/logs/console.go | 18 +++++++++++++++++- pkg/logs/es/es.go | 4 ++++ pkg/logs/file.go | 4 ++++ pkg/logs/jianliao.go | 4 ++++ pkg/logs/log.go | 12 ++++++++++++ pkg/logs/logger.go | 6 +++--- pkg/logs/multifile.go | 4 ++++ pkg/logs/slack.go | 4 ++++ pkg/logs/smtp.go | 4 ++++ 10 files changed, 62 insertions(+), 5 deletions(-) diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go index e0560fd9..79ab410c 100644 --- a/pkg/logs/conn.go +++ b/pkg/logs/conn.go @@ -39,6 +39,10 @@ func NewConn() Logger { return conn } +func (c *connWriter) Format(lm *LogMsg) string { + return lm.Msg +} + // Init initializes a connection writer with json config. // json config only needs they "level" key func (c *connWriter) Init(jsonConfig string) error { @@ -62,7 +66,8 @@ func (c *connWriter) WriteMsg(lm *LogMsg) error { defer c.innerWriter.Close() } - _, err := c.lg.writeln(lm) + msg := c.Format(lm) + _, err := c.lg.writeln(msg) if err != nil { return err } diff --git a/pkg/logs/console.go b/pkg/logs/console.go index 024152aa..86db6178 100644 --- a/pkg/logs/console.go +++ b/pkg/logs/console.go @@ -52,6 +52,20 @@ type consoleWriter struct { Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color } +func (c *consoleWriter) Format(lm *LogMsg) string { + msg := lm.Msg + + if c.Colorful { + msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) + } + + h, _, _ := formatTimeHeader(lm.When) + bytes := append(append(h, msg...), '\n') + + return "eee" + string(bytes) + +} + // NewConsole creates ConsoleWriter returning as LoggerInterface. func NewConsole() Logger { cw := &consoleWriter{ @@ -76,10 +90,12 @@ func (c *consoleWriter) WriteMsg(lm *LogMsg) error { if lm.Level > c.Level { return nil } + // fmt.Printf("Formatted: %s\n\n", c.fmtter.Format(lm)) if c.Colorful { lm.Msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) } - c.lg.writeln(lm) + msg := c.Format(lm) + c.lg.writeln(msg) return nil } diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go index 5c91b2ed..b70e5cf3 100644 --- a/pkg/logs/es/es.go +++ b/pkg/logs/es/es.go @@ -35,6 +35,10 @@ type esLogger struct { Level int `json:"level"` } +func (el *esLogger) Format(lm *logs.LogMsg) string { + return lm.Msg +} + // {"dsn":"http://localhost:9200/","level":1} func (el *esLogger) Init(jsonconfig string) error { err := json.Unmarshal([]byte(jsonconfig), el) diff --git a/pkg/logs/file.go b/pkg/logs/file.go index 23ea4b09..6b33ebb1 100644 --- a/pkg/logs/file.go +++ b/pkg/logs/file.go @@ -89,6 +89,10 @@ func newFileWriter() Logger { return w } +func (w *fileLogWriter) Format(lm *LogMsg) string { + return lm.Msg +} + // Init file logger with json config. // jsonConfig like: // { diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go index 0e7cfab4..a108342c 100644 --- a/pkg/logs/jianliao.go +++ b/pkg/logs/jianliao.go @@ -27,6 +27,10 @@ func (s *JLWriter) Init(jsonconfig string) error { return json.Unmarshal([]byte(jsonconfig), s) } +func (s *JLWriter) Format(lm *LogMsg) string { + return lm.Msg +} + // WriteMsg writes message in smtp writer. // Sends an email with subject and only this message. func (s *JLWriter) WriteMsg(lm *LogMsg) error { diff --git a/pkg/logs/log.go b/pkg/logs/log.go index 37421625..d47173e5 100644 --- a/pkg/logs/log.go +++ b/pkg/logs/log.go @@ -86,6 +86,7 @@ type newLoggerFunc func() Logger type Logger interface { Init(config string) error WriteMsg(lm *LogMsg) error + Format(lm *LogMsg) string Destroy() Flush() } @@ -128,6 +129,8 @@ const defaultAsyncMsgLen = 1e3 type nameLogger struct { Logger + // Formatter func(*LogMsg) string + LogFormatter name string } @@ -139,6 +142,10 @@ type LogMsg struct { LineNumber int } +type LogFormatter interface { + Format(lm *LogMsg) string +} + var logMsgPool *sync.Pool // NewLogger returns a new BeeLogger. @@ -179,6 +186,10 @@ func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { return bl } +func Format(lm *LogMsg) string { + return lm.Msg +} + // SetLogger provides a given logger adapter into BeeLogger with config string. // config must in in JSON format like {"interval":360}} func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { @@ -237,6 +248,7 @@ func (bl *BeeLogger) DelLogger(adapterName string) error { func (bl *BeeLogger) writeToLoggers(lm *LogMsg) { for _, l := range bl.outputs { + // fmt.Println("Formatted: ", l.Format(lm)) err := l.WriteMsg(lm) if err != nil { fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err) diff --git a/pkg/logs/logger.go b/pkg/logs/logger.go index 721c8dc1..d8b334d4 100644 --- a/pkg/logs/logger.go +++ b/pkg/logs/logger.go @@ -30,10 +30,10 @@ func newLogWriter(wr io.Writer) *logWriter { return &logWriter{writer: wr} } -func (lg *logWriter) writeln(lm *LogMsg) (int, error) { +func (lg *logWriter) writeln(msg string) (int, error) { lg.Lock() - h, _, _ := formatTimeHeader(lm.When) - n, err := lg.writer.Write(append(append(h, lm.Msg...), '\n')) + msg += "\n" + n, err := lg.writer.Write([]byte(msg)) lg.Unlock() return n, err } diff --git a/pkg/logs/multifile.go b/pkg/logs/multifile.go index 1cd9e9f8..0650c99d 100644 --- a/pkg/logs/multifile.go +++ b/pkg/logs/multifile.go @@ -78,6 +78,10 @@ func (f *multiFileLogWriter) Init(config string) error { return nil } +func (f *multiFileLogWriter) Format(lm *LogMsg) string { + return lm.Msg +} + func (f *multiFileLogWriter) Destroy() { for i := 0; i < len(f.writers); i++ { if f.writers[i] != nil { diff --git a/pkg/logs/slack.go b/pkg/logs/slack.go index dad4f4ea..c31f9330 100644 --- a/pkg/logs/slack.go +++ b/pkg/logs/slack.go @@ -18,6 +18,10 @@ func newSLACKWriter() Logger { return &SLACKWriter{Level: LevelTrace} } +func (s *SLACKWriter) Format(lm *LogMsg) string { + return lm.Msg +} + // Init SLACKWriter with json config string func (s *SLACKWriter) Init(jsonconfig string) error { return json.Unmarshal([]byte(jsonconfig), s) diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go index 0d2b3c29..beadb0d7 100644 --- a/pkg/logs/smtp.go +++ b/pkg/logs/smtp.go @@ -114,6 +114,10 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd return client.Quit() } +func (s *SMTPWriter) Format(lm *LogMsg) string { + return lm.Msg +} + // WriteMsg writes message in smtp writer. // Sends an email with subject and only this message. func (s *SMTPWriter) WriteMsg(lm *LogMsg) error { From 705e091593a49a09904c75896aec1f85aa3c8862 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Thu, 20 Aug 2020 19:06:51 +0100 Subject: [PATCH 109/207] Add format call before logging --- pkg/logs/es/es.go | 2 +- pkg/logs/file.go | 7 ++++--- pkg/logs/jianliao.go | 3 +-- pkg/logs/slack.go | 4 ++-- pkg/logs/smtp.go | 4 +++- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go index b70e5cf3..06dfece1 100644 --- a/pkg/logs/es/es.go +++ b/pkg/logs/es/es.go @@ -71,7 +71,7 @@ func (el *esLogger) WriteMsg(lm *logs.LogMsg) error { idx := LogDocument{ Timestamp: lm.When.Format(time.RFC3339), - Msg: lm.Msg, + Msg: el.Format(lm), } body, err := json.Marshal(idx) diff --git a/pkg/logs/file.go b/pkg/logs/file.go index 6b33ebb1..366fbcf2 100644 --- a/pkg/logs/file.go +++ b/pkg/logs/file.go @@ -153,7 +153,8 @@ func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { return nil } hd, d, h := formatTimeHeader(lm.When) - lm.Msg = string(hd) + lm.Msg + "\n" + msg := w.Format(lm) + msg = fmt.Sprintf("%s %s\n", string(hd), msg) if w.Rotate { w.RLock() if w.needRotateHourly(len(lm.Msg), h) { @@ -180,10 +181,10 @@ func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { } w.Lock() - _, err := w.fileWriter.Write([]byte(lm.Msg)) + _, err := w.fileWriter.Write([]byte(msg)) if err == nil { w.maxLinesCurLines++ - w.maxSizeCurSize += len(lm.Msg) + w.maxSizeCurSize += len(msg) } w.Unlock() return err diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go index a108342c..6830bade 100644 --- a/pkg/logs/jianliao.go +++ b/pkg/logs/jianliao.go @@ -38,8 +38,7 @@ func (s *JLWriter) WriteMsg(lm *LogMsg) error { return nil } - text := fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), lm.Msg) - + text := fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), s.Format(lm)) form := url.Values{} form.Add("authorName", s.AuthorName) form.Add("title", s.Title) diff --git a/pkg/logs/slack.go b/pkg/logs/slack.go index c31f9330..c0584f72 100644 --- a/pkg/logs/slack.go +++ b/pkg/logs/slack.go @@ -33,8 +33,8 @@ func (s *SLACKWriter) WriteMsg(lm *LogMsg) error { if lm.Level > s.Level { return nil } - - text := fmt.Sprintf("{\"text\": \"%s %s\"}", lm.When.Format("2006-01-02 15:04:05"), lm.Msg) + msg := s.Format(lm) + text := fmt.Sprintf("{\"text\": \"%s %s\"}", lm.When.Format("2006-01-02 15:04:05"), msg) form := url.Values{} form.Add("payload", text) diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go index beadb0d7..d992b279 100644 --- a/pkg/logs/smtp.go +++ b/pkg/logs/smtp.go @@ -130,11 +130,13 @@ func (s *SMTPWriter) WriteMsg(lm *LogMsg) error { // Set up authentication information. auth := s.getSMTPAuth(hp[0]) + msg := s.Format(lm) + // Connect to the server, authenticate, set the sender and recipient, // and send the email all in one step. contentType := "Content-Type: text/plain" + "; charset=UTF-8" mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress + - ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", lm.When.Format("2006-01-02 15:04:05")) + lm.Msg) + ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", lm.When.Format("2006-01-02 15:04:05")) + msg) return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg) } From e1da804b2ba54572dc44c76b963d63068a92bad8 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Thu, 20 Aug 2020 19:15:27 +0100 Subject: [PATCH 110/207] Add format func to alils --- pkg/logs/alils/alils.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go index 6c1464f2..2c83e4ee 100644 --- a/pkg/logs/alils/alils.go +++ b/pkg/logs/alils/alils.go @@ -100,6 +100,10 @@ func (c *aliLSWriter) Init(jsonConfig string) (err error) { return nil } +func (c *aliLSWriter) Format(lm *logs.LogMsg) string { + return lm.Msg +} + // WriteMsg writes a message in connection. // If connection is down, try to re-connect. func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { From 7a94996e22fe3081dda665a701e292ce9a87ba4d Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Mon, 24 Aug 2020 20:23:54 +0800 Subject: [PATCH 112/207] Feature: implement the time precison for time.Time type --- pkg/client/orm/cmd_utils.go | 7 +++++- pkg/client/orm/db_mysql.go | 37 ++++++++++++++--------------- pkg/client/orm/db_oracle.go | 35 ++++++++++++++-------------- pkg/client/orm/db_postgres.go | 41 +++++++++++++++++---------------- pkg/client/orm/models_info_f.go | 15 +++++++++++- pkg/client/orm/models_utils.go | 1 + 6 files changed, 79 insertions(+), 57 deletions(-) diff --git a/pkg/client/orm/cmd_utils.go b/pkg/client/orm/cmd_utils.go index 692a079f..105c0b69 100644 --- a/pkg/client/orm/cmd_utils.go +++ b/pkg/client/orm/cmd_utils.go @@ -66,7 +66,12 @@ checkColumn: case TypeDateField: col = T["time.Time-date"] case TypeDateTimeField: - col = T["time.Time"] + if fi.timePrecision == nil { + col = T["time.Time"] + } else { + s := T["time.Time-precision"] + col = fmt.Sprintf(s, *fi.timePrecision) + } case TypeBitField: col = T["int8"] case TypeSmallIntegerField: diff --git a/pkg/client/orm/db_mysql.go b/pkg/client/orm/db_mysql.go index d934d842..f602fd0a 100644 --- a/pkg/client/orm/db_mysql.go +++ b/pkg/client/orm/db_mysql.go @@ -42,24 +42,25 @@ var mysqlOperators = map[string]string{ // mysql column field types. var mysqlTypes = map[string]string{ - "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "char(%d)", - "string-text": "longtext", - "time.Time-date": "date", - "time.Time": "datetime", - "int8": "tinyint", - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": "tinyint unsigned", - "uint16": "smallint unsigned", - "uint32": "integer unsigned", - "uint64": "bigint unsigned", - "float64": "double precision", - "float64-decimal": "numeric(%d, %d)", + "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "char(%d)", + "string-text": "longtext", + "time.Time-date": "date", + "time.Time": "datetime", + "int8": "tinyint", + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": "tinyint unsigned", + "uint16": "smallint unsigned", + "uint32": "integer unsigned", + "uint64": "bigint unsigned", + "float64": "double precision", + "float64-decimal": "numeric(%d, %d)", + "time.Time-precision": "datetime(%d)", } // mysql dbBaser implementation. diff --git a/pkg/client/orm/db_oracle.go b/pkg/client/orm/db_oracle.go index d8d8c6c1..7c1bf1b3 100644 --- a/pkg/client/orm/db_oracle.go +++ b/pkg/client/orm/db_oracle.go @@ -33,23 +33,24 @@ var oracleOperators = map[string]string{ // oracle column field types. var oracleTypes = map[string]string{ - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "VARCHAR2(%d)", - "string-char": "CHAR(%d)", - "string-text": "VARCHAR2(%d)", - "time.Time-date": "DATE", - "time.Time": "TIMESTAMP", - "int8": "INTEGER", - "int16": "INTEGER", - "int32": "INTEGER", - "int64": "INTEGER", - "uint8": "INTEGER", - "uint16": "INTEGER", - "uint32": "INTEGER", - "uint64": "INTEGER", - "float64": "NUMBER", - "float64-decimal": "NUMBER(%d, %d)", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "VARCHAR2(%d)", + "string-char": "CHAR(%d)", + "string-text": "VARCHAR2(%d)", + "time.Time-date": "DATE", + "time.Time": "TIMESTAMP", + "int8": "INTEGER", + "int16": "INTEGER", + "int32": "INTEGER", + "int64": "INTEGER", + "uint8": "INTEGER", + "uint16": "INTEGER", + "uint32": "INTEGER", + "uint64": "INTEGER", + "float64": "NUMBER", + "float64-decimal": "NUMBER(%d, %d)", + "time.Time-precision": "TIMESTAMP(%d)", } // oracle dbBaser diff --git a/pkg/client/orm/db_postgres.go b/pkg/client/orm/db_postgres.go index 35471ddc..12431d6e 100644 --- a/pkg/client/orm/db_postgres.go +++ b/pkg/client/orm/db_postgres.go @@ -39,26 +39,27 @@ var postgresOperators = map[string]string{ // postgresql column field types. var postgresTypes = map[string]string{ - "auto": "serial NOT NULL PRIMARY KEY", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "char(%d)", - "string-text": "text", - "time.Time-date": "date", - "time.Time": "timestamp with time zone", - "int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`, - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`, - "uint16": `integer CHECK("%COL%" >= 0)`, - "uint32": `bigint CHECK("%COL%" >= 0)`, - "uint64": `bigint CHECK("%COL%" >= 0)`, - "float64": "double precision", - "float64-decimal": "numeric(%d, %d)", - "json": "json", - "jsonb": "jsonb", + "auto": "serial NOT NULL PRIMARY KEY", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "char(%d)", + "string-text": "text", + "time.Time-date": "date", + "time.Time": "timestamp with time zone", + "int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`, + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`, + "uint16": `integer CHECK("%COL%" >= 0)`, + "uint32": `bigint CHECK("%COL%" >= 0)`, + "uint64": `bigint CHECK("%COL%" >= 0)`, + "float64": "double precision", + "float64-decimal": "numeric(%d, %d)", + "json": "json", + "jsonb": "jsonb", + "time.Time-precision": "timestamp(%d) with time zone", } // postgresql dbBaser. diff --git a/pkg/client/orm/models_info_f.go b/pkg/client/orm/models_info_f.go index 7044b0bd..7152fada 100644 --- a/pkg/client/orm/models_info_f.go +++ b/pkg/client/orm/models_info_f.go @@ -137,6 +137,7 @@ type fieldInfo struct { isFielder bool // implement Fielder interface onDelete string description string + timePrecision *int } // new field info @@ -177,7 +178,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN decimals := tags["decimals"] size := tags["size"] onDelete := tags["on_delete"] - + precision := tags["precision"] initial.Clear() if v, ok := tags["default"]; ok { initial.Set(v) @@ -377,6 +378,18 @@ checkType: fi.index = false fi.unique = false case TypeTimeField, TypeDateField, TypeDateTimeField: + if fieldType == TypeDateTimeField { + if precision != "" { + v, e := StrTo(precision).Int() + if e != nil { + err = fmt.Errorf("convert %s to int error:%v", precision, e) + } else { + fi.timePrecision = &v + } + } + + } + if attrs["auto_now"] { fi.autoNow = true } else if attrs["auto_now_add"] { diff --git a/pkg/client/orm/models_utils.go b/pkg/client/orm/models_utils.go index 71127a6b..6fca59a9 100644 --- a/pkg/client/orm/models_utils.go +++ b/pkg/client/orm/models_utils.go @@ -45,6 +45,7 @@ var supportTag = map[string]int{ "on_delete": 2, "type": 2, "description": 2, + "precision": 2, } // get reflect.Type name with package path. From ed1d2c7f6e2d8589daf69aedf9a9d6e7c5d76d86 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Mon, 24 Aug 2020 20:22:38 +0100 Subject: [PATCH 113/207] Add custom logging format functionality and global formatter functionality --- pkg/logs/conn.go | 25 +++++++++---- pkg/logs/console.go | 30 ++++++++++++--- pkg/logs/file.go | 20 +++++++++- pkg/logs/jianliao.go | 23 ++++++++---- pkg/logs/log.go | 85 ++++++++++++++++++++++++++++++++++++++++--- pkg/logs/multifile.go | 18 ++++++--- pkg/logs/slack.go | 15 ++++++-- pkg/logs/smtp.go | 11 +++++- 8 files changed, 190 insertions(+), 37 deletions(-) diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go index 79ab410c..e11909a0 100644 --- a/pkg/logs/conn.go +++ b/pkg/logs/conn.go @@ -23,13 +23,15 @@ import ( // connWriter implements LoggerInterface. // Writes messages in keep-live tcp connection. type connWriter struct { - lg *logWriter - innerWriter io.WriteCloser - ReconnectOnMsg bool `json:"reconnectOnMsg"` - Reconnect bool `json:"reconnect"` - Net string `json:"net"` - Addr string `json:"addr"` - Level int `json:"level"` + lg *logWriter + innerWriter io.WriteCloser + UseCustomFormatter bool + CustomFormatter func(*LogMsg) string + ReconnectOnMsg bool `json:"reconnectOnMsg"` + Reconnect bool `json:"reconnect"` + Net string `json:"net"` + Addr string `json:"addr"` + Level int `json:"level"` } // NewConn creates new ConnWrite returning as LoggerInterface. @@ -45,7 +47,14 @@ func (c *connWriter) Format(lm *LogMsg) string { // Init initializes a connection writer with json config. // json config only needs they "level" key -func (c *connWriter) Init(jsonConfig string) error { +func (c *connWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + c.UseCustomFormatter = true + c.CustomFormatter = elem + } + } + return json.Unmarshal([]byte(jsonConfig), c) } diff --git a/pkg/logs/console.go b/pkg/logs/console.go index 86db6178..a928de7d 100644 --- a/pkg/logs/console.go +++ b/pkg/logs/console.go @@ -47,9 +47,11 @@ var colors = []brush{ // consoleWriter implements LoggerInterface and writes messages to terminal. type consoleWriter struct { - lg *logWriter - Level int `json:"level"` - Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color + lg *logWriter + UseCustomFormatter bool + CustomFormatter func(*LogMsg) string + Level int `json:"level"` + Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color } func (c *consoleWriter) Format(lm *LogMsg) string { @@ -62,7 +64,7 @@ func (c *consoleWriter) Format(lm *LogMsg) string { h, _, _ := formatTimeHeader(lm.When) bytes := append(append(h, msg...), '\n') - return "eee" + string(bytes) + return string(bytes) } @@ -78,10 +80,18 @@ func NewConsole() Logger { // Init initianlizes the console logger. // jsonConfig must be in the format '{"level":LevelTrace}' -func (c *consoleWriter) Init(jsonConfig string) error { +func (c *consoleWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + c.UseCustomFormatter = true + c.CustomFormatter = elem + } + } + if len(jsonConfig) == 0 { return nil } + return json.Unmarshal([]byte(jsonConfig), c) } @@ -94,7 +104,15 @@ func (c *consoleWriter) WriteMsg(lm *LogMsg) error { if c.Colorful { lm.Msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) } - msg := c.Format(lm) + + msg := "" + + if c.UseCustomFormatter { + msg = c.CustomFormatter(lm) + } else { + msg = c.Format(lm) + } + c.lg.writeln(msg) return nil } diff --git a/pkg/logs/file.go b/pkg/logs/file.go index 366fbcf2..4576e19d 100644 --- a/pkg/logs/file.go +++ b/pkg/logs/file.go @@ -60,6 +60,9 @@ type fileLogWriter struct { hourlyOpenDate int hourlyOpenTime time.Time + UseCustomFormatter bool + CustomFormatter func(*LogMsg) string + Rotate bool `json:"rotate"` Level int `json:"level"` @@ -104,7 +107,14 @@ func (w *fileLogWriter) Format(lm *LogMsg) string { // "rotate":true, // "perm":"0600" // } -func (w *fileLogWriter) Init(jsonConfig string) error { +func (w *fileLogWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + w.UseCustomFormatter = true + w.CustomFormatter = elem + } + } + err := json.Unmarshal([]byte(jsonConfig), w) if err != nil { return err @@ -153,7 +163,13 @@ func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { return nil } hd, d, h := formatTimeHeader(lm.When) - msg := w.Format(lm) + msg := "" + if w.UseCustomFormatter { + msg = w.CustomFormatter(lm) + } else { + msg = w.Format(lm) + } + msg = fmt.Sprintf("%s %s\n", string(hd), msg) if w.Rotate { w.RLock() diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go index 6830bade..9877bed6 100644 --- a/pkg/logs/jianliao.go +++ b/pkg/logs/jianliao.go @@ -9,12 +9,14 @@ import ( // JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook type JLWriter struct { - AuthorName string `json:"authorname"` - Title string `json:"title"` - WebhookURL string `json:"webhookurl"` - RedirectURL string `json:"redirecturl,omitempty"` - ImageURL string `json:"imageurl,omitempty"` - Level int `json:"level"` + AuthorName string `json:"authorname"` + Title string `json:"title"` + WebhookURL string `json:"webhookurl"` + RedirectURL string `json:"redirecturl,omitempty"` + ImageURL string `json:"imageurl,omitempty"` + Level int `json:"level"` + UseCustomFormatter bool + CustomFormatter func(*LogMsg) string } // newJLWriter creates jiaoliao writer. @@ -23,7 +25,14 @@ func newJLWriter() Logger { } // Init JLWriter with json config string -func (s *JLWriter) Init(jsonconfig string) error { +func (s *JLWriter) Init(jsonconfig string, LogFormatter ...func(*LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + s.UseCustomFormatter = true + s.CustomFormatter = elem + } + } + return json.Unmarshal([]byte(jsonconfig), s) } diff --git a/pkg/logs/log.go b/pkg/logs/log.go index d47173e5..fd8fca63 100644 --- a/pkg/logs/log.go +++ b/pkg/logs/log.go @@ -84,7 +84,7 @@ type newLoggerFunc func() Logger // Logger defines the behavior of a log provider. type Logger interface { - Init(config string) error + Init(config string, LogFormatter ...func(*LogMsg) string) error WriteMsg(lm *LogMsg) error Format(lm *LogMsg) string Destroy() @@ -115,6 +115,7 @@ type BeeLogger struct { init bool enableFuncCallDepth bool loggerFuncCallDepth int + globalFormatter func(*LogMsg) string enableFullFilePath bool asynchronous bool prefix string @@ -129,8 +130,6 @@ const defaultAsyncMsgLen = 1e3 type nameLogger struct { Logger - // Formatter func(*LogMsg) string - LogFormatter name string } @@ -206,7 +205,16 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { } lg := logAdapter() - err := lg.Init(config) + var err error + + // Global formatter overrides the default set formatter + // but not adapter specific formatters set with logs.SetLoggerWithOpts() + if bl.globalFormatter != nil { + err = lg.Init(config, bl.globalFormatter) + } else { + err = lg.Init(config) + } + if err != nil { fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) return err @@ -248,7 +256,6 @@ func (bl *BeeLogger) DelLogger(adapterName string) error { func (bl *BeeLogger) writeToLoggers(lm *LogMsg) { for _, l := range bl.outputs { - // fmt.Println("Formatted: ", l.Format(lm)) err := l.WriteMsg(lm) if err != nil { fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err) @@ -394,6 +401,74 @@ func (bl *BeeLogger) startLogger() { } } +// SetLoggerWithOpts sets a log adapter with a user defined logging format. Config must be valid JSON +// such as: {"interval":360} +func (bl *BeeLogger) setLoggerWithOpts(adapterName string, formatterFunc func(*LogMsg) string, configs ...string) error { + config := append(configs, "{}")[0] + for _, l := range bl.outputs { + if l.name == adapterName { + return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName) + } + } + + logAdapter, ok := adapters[adapterName] + if !ok { + return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) + } + + if formatterFunc == nil { + return fmt.Errorf("No formatter set for %s log adapter", adapterName) + } + + lg := logAdapter() + err := lg.Init(config, formatterFunc) + if err != nil { + fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) + return err + } + + bl.outputs = append(bl.outputs, &nameLogger{ + name: adapterName, + Logger: lg, + }) + + return nil +} + +// SetLogger provides a given logger adapter into BeeLogger with config string. +func (bl *BeeLogger) SetLoggerWithOpts(adapterName string, formatterFunc func(*LogMsg) string, configs ...string) error { + bl.lock.Lock() + defer bl.lock.Unlock() + if !bl.init { + bl.outputs = []*nameLogger{} + bl.init = true + } + return bl.setLoggerWithOpts(adapterName, formatterFunc, configs...) +} + +// SetLoggerWIthOpts sets a given log adapter with a custom log adapter. +// Log Adapter must be given in the form common.SimpleKV{Key: "formatter": Value: struct.FormatFunc} +// where FormatFunc has the signature func(*LogMsg) string +func SetLoggerWithOpts(adapter string, config []string, formatterFunc func(*LogMsg) string) error { + err := beeLogger.SetLoggerWithOpts(adapter, formatterFunc, config...) + if err != nil { + log.Fatal(err) + } + return nil + +} + +func (bl *BeeLogger) setGlobalFormatter(fmtter func(*LogMsg) string) error { + bl.globalFormatter = fmtter + return nil +} + +// SetGlobalFormatter sets the global formatter for all log adapters +// This overrides and other individually set adapter +func SetGlobalFormatter(fmtter func(*LogMsg) string) error { + return beeLogger.setGlobalFormatter(fmtter) +} + // Emergency Log EMERGENCY level message. func (bl *BeeLogger) Emergency(format string, v ...interface{}) { if LevelEmergency > bl.level { diff --git a/pkg/logs/multifile.go b/pkg/logs/multifile.go index 0650c99d..bcd4dd4e 100644 --- a/pkg/logs/multifile.go +++ b/pkg/logs/multifile.go @@ -24,9 +24,11 @@ import ( // and write the error-level logs to project.error.log and write the debug-level logs to project.debug.log // the rotate attribute also acts like fileLogWriter type multiFileLogWriter struct { - writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter - fullLogWriter *fileLogWriter - Separate []string `json:"separate"` + writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter + fullLogWriter *fileLogWriter + Separate []string `json:"separate"` + UseCustomFormatter bool + CustomFormatter func(*LogMsg) string } var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"} @@ -44,7 +46,14 @@ var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning // "separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"], // } -func (f *multiFileLogWriter) Init(config string) error { +func (f *multiFileLogWriter) Init(config string, LogFormatter ...func(*LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + f.UseCustomFormatter = true + f.CustomFormatter = elem + } + } + writer := newFileWriter().(*fileLogWriter) err := writer.Init(config) if err != nil { @@ -74,7 +83,6 @@ func (f *multiFileLogWriter) Init(config string) error { } } } - return nil } diff --git a/pkg/logs/slack.go b/pkg/logs/slack.go index c0584f72..9407b48a 100644 --- a/pkg/logs/slack.go +++ b/pkg/logs/slack.go @@ -9,8 +9,10 @@ import ( // SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook type SLACKWriter struct { - WebhookURL string `json:"webhookurl"` - Level int `json:"level"` + WebhookURL string `json:"webhookurl"` + Level int `json:"level"` + UseCustomFormatter bool + CustomFormatter func(*LogMsg) string } // newSLACKWriter creates jiaoliao writer. @@ -23,7 +25,14 @@ func (s *SLACKWriter) Format(lm *LogMsg) string { } // Init SLACKWriter with json config string -func (s *SLACKWriter) Init(jsonconfig string) error { +func (s *SLACKWriter) Init(jsonconfig string, LogFormatter ...func(*LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + s.UseCustomFormatter = true + s.CustomFormatter = elem + } + } + return json.Unmarshal([]byte(jsonconfig), s) } diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go index d992b279..b81be68f 100644 --- a/pkg/logs/smtp.go +++ b/pkg/logs/smtp.go @@ -32,6 +32,8 @@ type SMTPWriter struct { FromAddress string `json:"fromAddress"` RecipientAddresses []string `json:"sendTos"` Level int `json:"level"` + UseCustomFormatter bool + CustomFormatter func(*LogMsg) string } // NewSMTPWriter creates the smtp writer. @@ -50,7 +52,14 @@ func newSMTPWriter() Logger { // "sendTos":["email1","email2"], // "level":LevelError // } -func (s *SMTPWriter) Init(jsonconfig string) error { +func (s *SMTPWriter) Init(jsonconfig string, LogFormatter ...func(*LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + s.UseCustomFormatter = true + s.CustomFormatter = elem + } + } + return json.Unmarshal([]byte(jsonconfig), s) } From 48a98ec1a5c7aeb7674b2b09885f3dd9d9e575d4 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Mon, 24 Aug 2020 20:39:53 +0100 Subject: [PATCH 114/207] Fix init for alils.go --- pkg/logs/alils/alils.go | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go index 2c83e4ee..2300f8f8 100644 --- a/pkg/logs/alils/alils.go +++ b/pkg/logs/alils/alils.go @@ -32,11 +32,13 @@ type Config struct { // aliLSWriter implements LoggerInterface. // Writes messages in keep-live tcp connection. type aliLSWriter struct { - store *LogStore - group []*LogGroup - withMap bool - groupMap map[string]*LogGroup - lock *sync.Mutex + store *LogStore + group []*LogGroup + withMap bool + groupMap map[string]*LogGroup + lock *sync.Mutex + UseCustomFormatter bool + CustomFormatter func(*logs.LogMsg) string Config } @@ -48,7 +50,14 @@ func NewAliLS() logs.Logger { } // Init parses config and initializes struct -func (c *aliLSWriter) Init(jsonConfig string) (err error) { +func (c *aliLSWriter) Init(jsonConfig string, LogFormatter ...func(*logs.LogMsg) string) (err error) { + + for _, elem := range LogFormatter { + if elem != nil { + c.UseCustomFormatter = true + c.CustomFormatter = elem + } + } json.Unmarshal([]byte(jsonConfig), c) @@ -135,6 +144,12 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { lg = c.group[0] } + if c.UseCustomFormatter { + content = c.CustomFormatter(lm) + } else { + content = c.Format(lm) + } + c1 := &LogContent{ Key: proto.String("msg"), Value: proto.String(content), From c5970766a35cbc588c3b59b21d626e296894ffe1 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Mon, 24 Aug 2020 20:41:39 +0100 Subject: [PATCH 115/207] Add init to es.go --- pkg/logs/es/es.go | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go index 06dfece1..4dfc4160 100644 --- a/pkg/logs/es/es.go +++ b/pkg/logs/es/es.go @@ -31,8 +31,10 @@ func NewES() logs.Logger { // import _ "github.com/astaxie/beego/logs/es" type esLogger struct { *elasticsearch.Client - DSN string `json:"dsn"` - Level int `json:"level"` + DSN string `json:"dsn"` + Level int `json:"level"` + UseCustomFormatter bool + CustomFormatter func(*logs.LogMsg) string } func (el *esLogger) Format(lm *logs.LogMsg) string { @@ -40,7 +42,14 @@ func (el *esLogger) Format(lm *logs.LogMsg) string { } // {"dsn":"http://localhost:9200/","level":1} -func (el *esLogger) Init(jsonconfig string) error { +func (el *esLogger) Init(jsonconfig string, LogFormatter ...func(*logs.LogMsg) string) error { + for _, elem := range LogFormatter { + if elem != nil { + el.UseCustomFormatter = true + el.CustomFormatter = elem + } + } + err := json.Unmarshal([]byte(jsonconfig), el) if err != nil { return err @@ -69,9 +78,16 @@ func (el *esLogger) WriteMsg(lm *logs.LogMsg) error { return nil } + msg := "" + if el.UseCustomFormatter { + msg = el.CustomFormatter(lm) + } else { + msg = el.Format(lm) + } + idx := LogDocument{ Timestamp: lm.When.Format(time.RFC3339), - Msg: el.Format(lm), + Msg: msg, } body, err := json.Marshal(idx) From c2471b22ad04bf1623aab8fc3dc8f2d5f6461a88 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Mon, 24 Aug 2020 20:54:55 +0100 Subject: [PATCH 116/207] Remove ineffectual assignments Removed 3 lines due to warning from test suite saying these lines had innefectual assignments --- pkg/logs/alils/alils.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go index 2300f8f8..183d9b24 100644 --- a/pkg/logs/alils/alils.go +++ b/pkg/logs/alils/alils.go @@ -130,17 +130,14 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { if len(strs) == 2 { pos := strings.LastIndex(strs[0], " ") topic = strs[0][pos+1 : len(strs[0])] - content = strs[0][0:pos] + strs[1] lg = c.groupMap[topic] } // send to empty Topic if lg == nil { - content = lm.Msg lg = c.group[0] } } else { - content = lm.Msg lg = c.group[0] } From 1cb0ff560d2c37287d33957b37348ca60070d463 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 24 Aug 2020 14:44:35 +0000 Subject: [PATCH 118/207] Support precision --- pkg/client/orm/db.go | 9 ++- pkg/client/orm/models_test.go | 114 +++++++++++++++++----------------- pkg/doc.go | 2 +- 3 files changed, 67 insertions(+), 58 deletions(-) diff --git a/pkg/client/orm/db.go b/pkg/client/orm/db.go index 5905325d..de56e229 100644 --- a/pkg/client/orm/db.go +++ b/pkg/client/orm/db.go @@ -1355,7 +1355,14 @@ setValue: t time.Time err error ) - if len(s) >= 19 { + + if fi.timePrecision != nil && len(s) >= (20+*fi.timePrecision) { + layout := formatDateTime + "." + for i := 0; i < *fi.timePrecision; i++ { + layout += "0" + } + t, err = time.ParseInLocation(layout, s, tz) + } else if len(s) >= 19 { s = s[:19] t, err = time.ParseInLocation(formatDateTime, s, tz) } else if len(s) >= 10 { diff --git a/pkg/client/orm/models_test.go b/pkg/client/orm/models_test.go index 8a60c36b..b217dde4 100644 --- a/pkg/client/orm/models_test.go +++ b/pkg/client/orm/models_test.go @@ -145,55 +145,56 @@ type Data struct { } type DataNull struct { - ID int `orm:"column(id)"` - Boolean bool `orm:"null"` - Char string `orm:"null;size(50)"` - Text string `orm:"null;type(text)"` - JSON string `orm:"type(json);null"` - Jsonb string `orm:"type(jsonb);null"` - Time time.Time `orm:"null;type(time)"` - Date time.Time `orm:"null;type(date)"` - DateTime time.Time `orm:"null;column(datetime)"` - Byte byte `orm:"null"` - Rune rune `orm:"null"` - Int int `orm:"null"` - Int8 int8 `orm:"null"` - Int16 int16 `orm:"null"` - Int32 int32 `orm:"null"` - Int64 int64 `orm:"null"` - Uint uint `orm:"null"` - Uint8 uint8 `orm:"null"` - Uint16 uint16 `orm:"null"` - Uint32 uint32 `orm:"null"` - Uint64 uint64 `orm:"null"` - Float32 float32 `orm:"null"` - Float64 float64 `orm:"null"` - Decimal float64 `orm:"digits(8);decimals(4);null"` - NullString sql.NullString `orm:"null"` - NullBool sql.NullBool `orm:"null"` - NullFloat64 sql.NullFloat64 `orm:"null"` - NullInt64 sql.NullInt64 `orm:"null"` - BooleanPtr *bool `orm:"null"` - CharPtr *string `orm:"null;size(50)"` - TextPtr *string `orm:"null;type(text)"` - BytePtr *byte `orm:"null"` - RunePtr *rune `orm:"null"` - IntPtr *int `orm:"null"` - Int8Ptr *int8 `orm:"null"` - Int16Ptr *int16 `orm:"null"` - Int32Ptr *int32 `orm:"null"` - Int64Ptr *int64 `orm:"null"` - UintPtr *uint `orm:"null"` - Uint8Ptr *uint8 `orm:"null"` - Uint16Ptr *uint16 `orm:"null"` - Uint32Ptr *uint32 `orm:"null"` - Uint64Ptr *uint64 `orm:"null"` - Float32Ptr *float32 `orm:"null"` - Float64Ptr *float64 `orm:"null"` - DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` - TimePtr *time.Time `orm:"null;type(time)"` - DatePtr *time.Time `orm:"null;type(date)"` - DateTimePtr *time.Time `orm:"null"` + ID int `orm:"column(id)"` + Boolean bool `orm:"null"` + Char string `orm:"null;size(50)"` + Text string `orm:"null;type(text)"` + JSON string `orm:"type(json);null"` + Jsonb string `orm:"type(jsonb);null"` + Time time.Time `orm:"null;type(time)"` + Date time.Time `orm:"null;type(date)"` + DateTime time.Time `orm:"null;column(datetime)"` + DateTimePrecision time.Time `orm:"null;type(datetime);precision(4)"` + Byte byte `orm:"null"` + Rune rune `orm:"null"` + Int int `orm:"null"` + Int8 int8 `orm:"null"` + Int16 int16 `orm:"null"` + Int32 int32 `orm:"null"` + Int64 int64 `orm:"null"` + Uint uint `orm:"null"` + Uint8 uint8 `orm:"null"` + Uint16 uint16 `orm:"null"` + Uint32 uint32 `orm:"null"` + Uint64 uint64 `orm:"null"` + Float32 float32 `orm:"null"` + Float64 float64 `orm:"null"` + Decimal float64 `orm:"digits(8);decimals(4);null"` + NullString sql.NullString `orm:"null"` + NullBool sql.NullBool `orm:"null"` + NullFloat64 sql.NullFloat64 `orm:"null"` + NullInt64 sql.NullInt64 `orm:"null"` + BooleanPtr *bool `orm:"null"` + CharPtr *string `orm:"null;size(50)"` + TextPtr *string `orm:"null;type(text)"` + BytePtr *byte `orm:"null"` + RunePtr *rune `orm:"null"` + IntPtr *int `orm:"null"` + Int8Ptr *int8 `orm:"null"` + Int16Ptr *int16 `orm:"null"` + Int32Ptr *int32 `orm:"null"` + Int64Ptr *int64 `orm:"null"` + UintPtr *uint `orm:"null"` + Uint8Ptr *uint8 `orm:"null"` + Uint16Ptr *uint16 `orm:"null"` + Uint32Ptr *uint32 `orm:"null"` + Uint64Ptr *uint64 `orm:"null"` + Float32Ptr *float32 `orm:"null"` + Float64Ptr *float64 `orm:"null"` + DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` + TimePtr *time.Time `orm:"null;type(time)"` + DatePtr *time.Time `orm:"null;type(date)"` + DateTimePtr *time.Time `orm:"null"` } type String string @@ -297,13 +298,14 @@ func NewProfile() *Profile { } type Post struct { - ID int `orm:"column(id)"` - User *User `orm:"rel(fk)"` - Title string `orm:"size(60)"` - Content string `orm:"type(text)"` - Created time.Time `orm:"auto_now_add"` - Updated time.Time `orm:"auto_now"` - Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/pkg/client/orm.PostTags)"` + ID int `orm:"column(id)"` + User *User `orm:"rel(fk)"` + Title string `orm:"size(60)"` + Content string `orm:"type(text)"` + Created time.Time `orm:"auto_now_add"` + Updated time.Time `orm:"auto_now"` + UpdatedPrecision time.Time `orm:"auto_now;type(datetime);precision(4)"` + Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/pkg/client/orm.PostTags)"` } func (u *Post) TableIndex() [][]string { diff --git a/pkg/doc.go b/pkg/doc.go index 2e4378c8..2d9c2bfe 100644 --- a/pkg/doc.go +++ b/pkg/doc.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. From b83094ac1e98bb296331023a794b3097c3cdce69 Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Wed, 26 Aug 2020 11:51:05 +0800 Subject: [PATCH 119/207] supplement datetime precision UT --- pkg/client/orm/cmd_utils.go | 6 ++++-- pkg/client/orm/models_test.go | 15 +++++++++++++++ pkg/client/orm/orm_test.go | 20 ++++++++++++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/pkg/client/orm/cmd_utils.go b/pkg/client/orm/cmd_utils.go index 105c0b69..f6b25e8d 100644 --- a/pkg/client/orm/cmd_utils.go +++ b/pkg/client/orm/cmd_utils.go @@ -66,12 +66,14 @@ checkColumn: case TypeDateField: col = T["time.Time-date"] case TypeDateTimeField: - if fi.timePrecision == nil { + // the precision of sqlite is not implemented + if al.Driver == 2 || fi.timePrecision == nil { col = T["time.Time"] - } else { + }else { s := T["time.Time-precision"] col = fmt.Sprintf(s, *fi.timePrecision) } + case TypeBitField: col = T["int8"] case TypeSmallIntegerField: diff --git a/pkg/client/orm/models_test.go b/pkg/client/orm/models_test.go index 8a60c36b..2f96db1b 100644 --- a/pkg/client/orm/models_test.go +++ b/pkg/client/orm/models_test.go @@ -241,6 +241,21 @@ type UserBig struct { Name string } +type TM struct { + ID int `orm:"column(id)"` + TMPrecision1 time.Time `orm:"type(datetime);precision(3)"` + TMPrecision2 time.Time `orm:"auto_now_add;type(datetime);precision(4)"` +} + +func (t *TM) TableName() string { + return "tm" +} + +func NewTM() *TM { + obj := new(TM) + return obj +} + type User struct { ID int `orm:"column(id)"` UserName string `orm:"size(30);unique"` diff --git a/pkg/client/orm/orm_test.go b/pkg/client/orm/orm_test.go index 92374e02..8c4bf55d 100644 --- a/pkg/client/orm/orm_test.go +++ b/pkg/client/orm/orm_test.go @@ -204,6 +204,7 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(PtrPk)) RegisterModel(new(Index)) RegisterModel(new(StrPk)) + RegisterModel(new(TM)) err := RunSyncdb("default", true, Debug) throwFail(t, err) @@ -230,6 +231,7 @@ func TestRegisterModels(t *testing.T) { RegisterModel(new(PtrPk)) RegisterModel(new(Index)) RegisterModel(new(StrPk)) + RegisterModel(new(TM)) BootStrap() @@ -313,6 +315,24 @@ func TestDataTypes(t *testing.T) { } } +func TestTM(t *testing.T) { + // The precision of sqlite is not implemented + if dORM.Driver().Type() == 2 { + return + } + var recTM TM + tm := NewTM() + tm.TMPrecision1 = time.Unix(1596766024, 123456789) + tm.TMPrecision2 = time.Unix(1596766024, 123456789) + _, err := dORM.Insert(tm) + throwFail(t, err) + + err = dORM.QueryTable("tm").One(&recTM) + throwFail(t, err) + throwFail(t, AssertIs(recTM.TMPrecision1.String(), "2020-08-07 02:07:04.123 +0000 UTC")) + throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC")) +} + func TestNullDataTypes(t *testing.T) { d := DataNull{} From 9472cba6c922e4219a4dc6847d7a32531b1f6067 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 26 Aug 2020 04:12:30 +0000 Subject: [PATCH 120/207] Fix UT --- pkg/client/orm/db.go | 2 +- pkg/client/orm/db_sqlite.go | 37 ++++++++++++++++++----------------- pkg/client/orm/models_test.go | 3 ++- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/pkg/client/orm/db.go b/pkg/client/orm/db.go index de56e229..820435ca 100644 --- a/pkg/client/orm/db.go +++ b/pkg/client/orm/db.go @@ -1361,7 +1361,7 @@ setValue: for i := 0; i < *fi.timePrecision; i++ { layout += "0" } - t, err = time.ParseInLocation(layout, s, tz) + t, err = time.ParseInLocation(layout, s[:20+*fi.timePrecision], tz) } else if len(s) >= 19 { s = s[:19] t, err = time.ParseInLocation(formatDateTime, s, tz) diff --git a/pkg/client/orm/db_sqlite.go b/pkg/client/orm/db_sqlite.go index 8cb936be..6d7a5617 100644 --- a/pkg/client/orm/db_sqlite.go +++ b/pkg/client/orm/db_sqlite.go @@ -44,24 +44,25 @@ var sqliteOperators = map[string]string{ // sqlite column types. var sqliteTypes = map[string]string{ - "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "character(%d)", - "string-text": "text", - "time.Time-date": "date", - "time.Time": "datetime", - "int8": "tinyint", - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": "tinyint unsigned", - "uint16": "smallint unsigned", - "uint32": "integer unsigned", - "uint64": "bigint unsigned", - "float64": "real", - "float64-decimal": "decimal", + "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "character(%d)", + "string-text": "text", + "time.Time-date": "date", + "time.Time": "datetime", + "time.Time-precision": "datetime(%d)", + "int8": "tinyint", + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": "tinyint unsigned", + "uint16": "smallint unsigned", + "uint32": "integer unsigned", + "uint64": "bigint unsigned", + "float64": "real", + "float64-decimal": "decimal", } // sqlite dbBaser. diff --git a/pkg/client/orm/models_test.go b/pkg/client/orm/models_test.go index b217dde4..e74f92bb 100644 --- a/pkg/client/orm/models_test.go +++ b/pkg/client/orm/models_test.go @@ -506,7 +506,8 @@ var ( ) func init() { - Debug, _ = StrTo(DBARGS.Debug).Bool() + // Debug, _ = StrTo(DBARGS.Debug).Bool() + Debug = true if DBARGS.Driver == "" || DBARGS.Source == "" { fmt.Println(helpinfo) From c2361170b30ec4f2060e59683ad4e32998de1605 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 26 Aug 2020 03:46:22 +0000 Subject: [PATCH 121/207] Support etcd --- .travis.yml | 36 ++- go.mod | 19 +- go.sum | 58 +++++ pkg/client/orm/cmd_utils.go | 2 +- pkg/client/orm/models_test.go | 2 +- pkg/infrastructure/config/base_config_test.go | 3 +- pkg/infrastructure/config/config.go | 68 ++++-- pkg/infrastructure/config/etcd/config.go | 219 ++++++++++++++++++ pkg/infrastructure/config/etcd/config_test.go | 123 ++++++++++ pkg/infrastructure/config/fake.go | 35 +-- pkg/infrastructure/config/ini.go | 111 +-------- pkg/infrastructure/config/ini_test.go | 7 +- pkg/infrastructure/config/json/json.go | 34 +-- pkg/infrastructure/config/json/json_test.go | 10 +- pkg/infrastructure/config/xml/xml.go | 35 +-- pkg/infrastructure/config/xml/xml_test.go | 10 +- pkg/infrastructure/config/yaml/yaml.go | 31 +-- pkg/infrastructure/config/yaml/yaml_test.go | 10 +- scripts/prepare_etcd.sh | 7 + scripts/test_docker_compose.yaml | 18 +- 20 files changed, 581 insertions(+), 257 deletions(-) create mode 100644 pkg/infrastructure/config/etcd/config.go create mode 100644 pkg/infrastructure/config/etcd/config_test.go create mode 100644 scripts/prepare_etcd.sh diff --git a/.travis.yml b/.travis.yml index 63b31c52..f3f1b576 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,24 +7,37 @@ services: - mysql - postgresql - memcached + - etcd env: global: - GO_REPO_FULLNAME="github.com/astaxie/beego" matrix: - ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db - ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" + - ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" before_install: - # link the local repo with ${GOPATH}/src// - - GO_REPO_NAMESPACE=${GO_REPO_FULLNAME%/*} - # relies on GOPATH to contain only one directory... - - mkdir -p ${GOPATH}/src/${GO_REPO_NAMESPACE} - - ln -sv ${TRAVIS_BUILD_DIR} ${GOPATH}/src/${GO_REPO_FULLNAME} - - cd ${GOPATH}/src/${GO_REPO_FULLNAME} - # get and build ssdb - - git clone git://github.com/ideawu/ssdb.git - - cd ssdb - - make - - cd .. + # link the local repo with ${GOPATH}/src// + - GO_REPO_NAMESPACE=${GO_REPO_FULLNAME%/*} + # relies on GOPATH to contain only one directory... + - mkdir -p ${GOPATH}/src/${GO_REPO_NAMESPACE} + - ln -sv ${TRAVIS_BUILD_DIR} ${GOPATH}/src/${GO_REPO_FULLNAME} + - cd ${GOPATH}/src/${GO_REPO_FULLNAME} + # get and build ssdb + - git clone git://github.com/ideawu/ssdb.git + - cd ssdb + - make + - cd .. + # - prepare for etcd unit tests + - git clone https://github.com/etcd-io/etcd.git + - cd etcd + - ./build + - ./bin/etcd + - ./bin/etcdctl put current.float 1.23 + - ./bin/etcdctl put current.bool true + - ./bin/etcdctl put current.int 11 + - ./bin/etcdctl put current.string hello + - ./bin/etcdctl put current.serialize.name test + - cd .. install: - go get github.com/lib/pq - go get github.com/go-sql-driver/mysql @@ -52,6 +65,7 @@ install: - go get -u github.com/go-redis/redis before_script: - psql --version + # - prepare for orm unit tests - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" - sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi" - sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi" diff --git a/go.mod b/go.mod index 91bd9aef..ab7f5e39 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,10 @@ require ( github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 github.com/casbin/casbin v1.7.0 github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 + github.com/coreos/etcd v3.3.25+incompatible + github.com/coreos/go-semver v0.3.0 // indirect + github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf // indirect + github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f // indirect github.com/couchbase/go-couchbase v0.0.0-20200519150804-63f3cdb75e0d github.com/couchbase/gomemcached v0.0.0-20200526233749-ec430f949808 // indirect github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a // indirect @@ -16,13 +20,17 @@ require ( github.com/go-redis/redis v6.14.2+incompatible github.com/go-redis/redis/v7 v7.4.0 github.com/go-sql-driver/mysql v1.5.0 - github.com/gogo/protobuf v1.1.1 + github.com/gogo/protobuf v1.3.1 github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect github.com/gomodule/redigo v2.0.0+incompatible + github.com/google/go-cmp v0.5.0 // indirect + github.com/google/uuid v1.1.1 // indirect + github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/hashicorp/golang-lru v0.5.4 github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6 github.com/lib/pq v1.0.0 github.com/mattn/go-sqlite3 v2.0.3+incompatible + github.com/mitchellh/mapstructure v1.3.3 github.com/opentracing/opentracing-go v1.2.0 github.com/pelletier/go-toml v1.2.0 // indirect github.com/pkg/errors v0.9.1 @@ -32,9 +40,14 @@ require ( github.com/stretchr/testify v1.4.0 github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c // indirect github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b // indirect - golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 + go.etcd.io/etcd v3.3.25+incompatible // indirect + go.uber.org/zap v1.15.0 // indirect + golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 + golang.org/x/net v0.0.0-20200822124328-c89045814202 // indirect + golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 // indirect + golang.org/x/text v0.3.3 // indirect golang.org/x/tools v0.0.0-20200117065230-39095c1d176c - google.golang.org/grpc v1.31.0 // indirect + google.golang.org/grpc v1.26.0 gopkg.in/yaml.v2 v2.2.8 honnef.co/go/tools v0.0.1-2020.1.5 // indirect ) diff --git a/go.sum b/go.sum index 95babc92..545dbae5 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,15 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 h1:F1EaeKL/ta07PY/k9Os/UFtwERei2/XzGemhpGnBKNg= github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/coreos/etcd v0.5.0-alpha.5 h1:0Qi6Jzjk2CDuuGlIeecpu+em2nrjhOgz2wsIwCmQHmc= +github.com/coreos/etcd v3.3.25+incompatible h1:0GQEw6h3YnuOVdtwygkIfJ+Omx0tZ8/QkVyXI4LkbeY= +github.com/coreos/etcd v3.3.25+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= +github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmfM= +github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= +github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f h1:lBNOc5arjvs8E5mO2tbpBpLoyyu8B6e44T7hJy6potg= +github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/couchbase/go-couchbase v0.0.0-20200519150804-63f3cdb75e0d h1:OMrhQqj1QCyDT2sxHCDjE+k8aMdn2ngTCGG7g4wrdLo= github.com/couchbase/go-couchbase v0.0.0-20200519150804-63f3cdb75e0d/go.mod h1:TWI8EKQMs5u5jLKW/tsb9VwauIrMIxQG1r5fMsswK5U= github.com/couchbase/gomemcached v0.0.0-20200526233749-ec430f949808 h1:8s2l8TVUwMXl6tZMe3+hPCRJ25nQXiA3d1x622JtOqc= @@ -46,6 +55,7 @@ github.com/elastic/go-elasticsearch/v6 v6.8.5/go.mod h1:UwaDJsD3rWLM5rKNFzv9hgox github.com/elazarl/go-bindata-assetfs v1.0.0 h1:G/bYguwHIzWq9ZoyUQqrjTmJbbYn3j3CKKpKinvZLFk= github.com/elazarl/go-bindata-assetfs v1.0.0/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= @@ -69,6 +79,8 @@ github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d h1:xy93KVe+KrIIwWDEAf github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= +github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -80,6 +92,7 @@ github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:x github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= @@ -92,8 +105,13 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92BcuyuQ/YW4NSIpoGtfXNho= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= @@ -101,6 +119,7 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj2vyii2bbUNDw3kt9VxK2EY= @@ -117,6 +136,8 @@ github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJK github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/mitchellh/mapstructure v1.3.3 h1:SzB1nHZ2Xi+17FP0zVQBHIZqvwRN9408fJO8h+eeNA8= +github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -187,6 +208,16 @@ github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b h1:0Ve0/CCjiAiyKddUM github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b/go.mod h1:Q12BUT7DqIlHRmgv3RskH+UCM/4eqVMgI0EMmlSpAXc= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/gopher-lua v0.0.0-20171031051903-609c9cd26973/go.mod h1:aEV29XrmTYFr3CiRxZeGHpkvbwq+prZduBqMaascyCU= +go.etcd.io/etcd v0.5.0-alpha.5 h1:VOolFSo3XgsmnYDLozjvZ6JL6AAwIDu1Yx1y+4EYLDo= +go.etcd.io/etcd v3.3.25+incompatible h1:V1RzkZJj9LqsJRy+TUBgpWSbZXITLB819lstuTFoZOY= +go.etcd.io/etcd v3.3.25+incompatible/go.mod h1:yaeTdrJi5lOmYerz05bd8+V7KubZs8YSFZfzsF9A6aI= +go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM= +go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -198,6 +229,7 @@ golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= @@ -216,6 +248,8 @@ golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7 golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -236,15 +270,23 @@ golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 h1:AvbQYmiaaaza3cW3QXRyPo5kYgpFIzOAfeAAN7m3qQ4= +golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200117065230-39095c1d176c h1:FodBYPZKH5tAN2O60HlglMwXGAeV/4k+NKbli79M/2c= @@ -260,19 +302,34 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55 h1:gSJIx1SDwno+2ElGhA4+qG2zF97qiUzTM+rQ0klBOcE= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20200825200019-8632dd797987 h1:PDIOdWxZ8eRizhKa1AAvY53xsvLB1cWorMjslvY3VA8= +google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.26.0 h1:2dTRdpdFEEhJYQD8EMLB61nnrzSCTbG38PhqdhvOltg= +google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.27.0 h1:rRYRFMVgRv6E0D70Skyfsr28tDXIuuPZyWGMPdMcnXg= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.31.0 h1:T7P4R73V3SSDPhH7WW7ATbfViLtmamH0DKrP3f9AuDI= google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc v1.31.1 h1:SfXqXS5hkufcdZ/mHtYCh53P2b+92WQq/DZcKLgsFRs= +google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= +google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -293,5 +350,6 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc h1:/hemPrYIhOhy8zYrNj+069zDB68us2sMGsfkFJO0iZs= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.5 h1:nI5egYTGJakVyOryqLs1cQO5dO0ksin5XXs2pspk75k= honnef.co/go/tools v0.0.1-2020.1.5/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= diff --git a/pkg/client/orm/cmd_utils.go b/pkg/client/orm/cmd_utils.go index f6b25e8d..e045e847 100644 --- a/pkg/client/orm/cmd_utils.go +++ b/pkg/client/orm/cmd_utils.go @@ -69,7 +69,7 @@ checkColumn: // the precision of sqlite is not implemented if al.Driver == 2 || fi.timePrecision == nil { col = T["time.Time"] - }else { + } else { s := T["time.Time-precision"] col = fmt.Sprintf(s, *fi.timePrecision) } diff --git a/pkg/client/orm/models_test.go b/pkg/client/orm/models_test.go index 236206c4..81ba30df 100644 --- a/pkg/client/orm/models_test.go +++ b/pkg/client/orm/models_test.go @@ -243,7 +243,7 @@ type UserBig struct { } type TM struct { - ID int `orm:"column(id)"` + ID int `orm:"column(id)"` TMPrecision1 time.Time `orm:"type(datetime);precision(3)"` TMPrecision2 time.Time `orm:"auto_now_add;type(datetime);precision(4)"` } diff --git a/pkg/infrastructure/config/base_config_test.go b/pkg/infrastructure/config/base_config_test.go index 3d37bc91..74a669a7 100644 --- a/pkg/infrastructure/config/base_config_test.go +++ b/pkg/infrastructure/config/base_config_test.go @@ -15,6 +15,7 @@ package config import ( + "context" "errors" "testing" @@ -59,7 +60,7 @@ func TestBaseConfiger_DefaultStrings(t *testing.T) { func newBaseConfier(str1 string) *BaseConfiger { return &BaseConfiger{ - reader: func(key string) (string, error) { + reader: func(ctx context.Context, key string) (string, error) { if key == "key1" { return str1, nil } else { diff --git a/pkg/infrastructure/config/config.go b/pkg/infrastructure/config/config.go index b17f6208..3514e425 100644 --- a/pkg/infrastructure/config/config.go +++ b/pkg/infrastructure/config/config.go @@ -41,6 +41,7 @@ package config import ( + "context" "errors" "fmt" "os" @@ -56,9 +57,9 @@ type Configer interface { Set(key, val string) error // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - String(key string) string + String(key string) (string, error) // get string slice - Strings(key string) []string + Strings(key string) ([]string, error) Int(key string) (int, error) Int64(key string) (int64, error) Bool(key string) (bool, error) @@ -72,11 +73,13 @@ type Configer interface { DefaultBool(key string, defaultVal bool) bool DefaultFloat(key string, defaultVal float64) float64 DIY(key string) (interface{}, error) - GetSection(section string) (map[string]string, error) - Unmarshaler(obj interface{}) error + GetSection(section string) (map[string]string, error) + GetSectionWithCtx(ctx context.Context, section string) (map[string]string, error) + + Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...DecodeOption) error Sub(key string) (Configer, error) - OnChange(fn func(cfg Configer)) + OnChange(ctx context.Context, key string, fn func(value string)) // GetByPrefix(prefix string) ([]byte, error) // GetSerializer() Serializer SaveConfigFile(filename string) error @@ -84,11 +87,17 @@ type Configer interface { type BaseConfiger struct { // The reader should support key like "a.b.c" - reader func(key string) (string, error) + reader func(ctx context.Context, key string) (string, error) +} + +func NewBaseConfiger(reader func(ctx context.Context, key string) (string, error)) BaseConfiger { + return BaseConfiger{ + reader: reader, + } } func (c *BaseConfiger) Int(key string) (int, error) { - res, err := c.reader(key) + res, err := c.reader(context.TODO(), key) if err != nil { return 0, err } @@ -96,7 +105,7 @@ func (c *BaseConfiger) Int(key string) (int, error) { } func (c *BaseConfiger) Int64(key string) (int64, error) { - res, err := c.reader(key) + res, err := c.reader(context.TODO(), key) if err != nil { return 0, err } @@ -104,30 +113,34 @@ func (c *BaseConfiger) Int64(key string) (int64, error) { } func (c *BaseConfiger) Bool(key string) (bool, error) { - res, err := c.reader(key) + res, err := c.reader(context.TODO(), key) if err != nil { return false, err } - return strconv.ParseBool(res) + return ParseBool(res) } func (c *BaseConfiger) Float(key string) (float64, error) { - res, err := c.reader(key) + res, err := c.reader(context.TODO(), key) if err != nil { return 0, err } return strconv.ParseFloat(res, 64) } +// DefaultString returns the string value for a given key. +// if err != nil or value is empty return defaultval func (c *BaseConfiger) DefaultString(key string, defaultVal string) string { - if res := c.String(key); res != "" { + if res, err := c.String(key); res != "" && err != nil { return res } return defaultVal } +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval func (c *BaseConfiger) DefaultStrings(key string, defaultVal []string) []string { - if res := c.Strings(key); len(res) > 0 { + if res, err := c.Strings(key); len(res) > 0 && err != nil { return res } return defaultVal @@ -160,21 +173,27 @@ func (c *BaseConfiger) DefaultFloat(key string, defaultVal float64) float64 { return defaultVal } -func (c *BaseConfiger) String(key string) string { - res, _ := c.reader(key) - return res +func (c *BaseConfiger) GetSectionWithCtx(ctx context.Context, section string) (map[string]string, error) { + // TODO + return nil, nil } -func (c *BaseConfiger) Strings(key string) []string { - res, err := c.reader(key) +func (c *BaseConfiger) String(key string) (string, error) { + return c.reader(context.TODO(), key) +} + +// Strings returns the []string value for a given key. +// Return nil if config value does not exist or is empty. +func (c *BaseConfiger) Strings(key string) ([]string, error) { + res, err := c.String(key) if err != nil || res == "" { - return nil + return nil, err } - return strings.Split(res, ";") + return strings.Split(res, ";"), nil } // TODO remove this before release v2.0.0 -func (c *BaseConfiger) Unmarshaler(obj interface{}) error { +func (c *BaseConfiger) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...DecodeOption) error { return errors.New("unsupported operation") } @@ -184,7 +203,7 @@ func (c *BaseConfiger) Sub(key string) (Configer, error) { } // TODO remove this before release v2.0.0 -func (c *BaseConfiger) OnChange(fn func(cfg Configer)) { +func (c *BaseConfiger) OnChange(ctx context.Context, key string, fn func(value string)) { // do nothing } @@ -361,3 +380,8 @@ func ToString(x interface{}) string { // Fallback to fmt package for anything else like numeric types return fmt.Sprint(x) } + +type DecodeOption func(options decodeOptions) + +type decodeOptions struct { +} diff --git a/pkg/infrastructure/config/etcd/config.go b/pkg/infrastructure/config/etcd/config.go new file mode 100644 index 00000000..30f26ce1 --- /dev/null +++ b/pkg/infrastructure/config/etcd/config.go @@ -0,0 +1,219 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package etcd + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/coreos/etcd/clientv3" + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" + "github.com/mitchellh/mapstructure" + "github.com/pkg/errors" + "google.golang.org/grpc" + + "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/infrastructure/logs" +) + +const etcdOpts = "etcdOpts" + +type EtcdConfiger struct { + prefix string + client *clientv3.Client + config.BaseConfiger +} + +func newEtcdConfiger(client *clientv3.Client, prefix string) *EtcdConfiger { + res := &EtcdConfiger{ + client: client, + prefix: prefix, + } + + res.BaseConfiger = config.NewBaseConfiger(res.reader) + return res +} + +// reader is an general implementation that read config from etcd. +func (e *EtcdConfiger) reader(ctx context.Context, key string) (string, error) { + resp, err := get(e.client, ctx, e.prefix+key) + if err != nil { + return "", err + } + + if resp.Count > 0 { + return string(resp.Kvs[0].Value), nil + } + + return "", nil +} + +// Set do nothing and return an error +// I think write data to remote config center is not a good practice +func (e *EtcdConfiger) Set(key, val string) error { + return errors.New("Unsupported operation") +} + +// DIY return the original response from etcd +// be careful when you decide to use this +func (e *EtcdConfiger) DIY(key string) (interface{}, error) { + return get(e.client, context.TODO(), key) +} + +// GetSection in this implementation, we use section as prefix +func (e *EtcdConfiger) GetSection(section string) (map[string]string, error) { + return e.GetSectionWithCtx(context.Background(), section) +} + +func (e *EtcdConfiger) GetSectionWithCtx(ctx context.Context, section string) (map[string]string, error) { + + var ( + resp *clientv3.GetResponse + err error + ) + + if opts, ok := ctx.Value(etcdOpts).([]clientv3.OpOption); ok { + opts = append(opts, clientv3.WithPrefix()) + resp, err = e.client.Get(context.TODO(), e.prefix+section, opts...) + } else { + resp, err = e.client.Get(context.TODO(), e.prefix+section, clientv3.WithPrefix()) + } + + if err != nil { + return nil, errors.WithMessage(err, "GetSection failed") + } + res := make(map[string]string, len(resp.Kvs)) + for _, kv := range resp.Kvs { + res[string(kv.Key)] = string(kv.Value) + } + return res, nil +} + +func (e *EtcdConfiger) SaveConfigFile(filename string) error { + return errors.New("Unsupported operation") +} + +// Unmarshaler is not very powerful because we lost the type information when we get configuration from etcd +// for example, when we got "5", we are not sure whether it's int 5, or it's string "5" +// TODO(support more complicated decoder) +func (e *EtcdConfiger) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { + res, err := e.GetSectionWithCtx(ctx, prefix) + if err != nil { + return errors.WithMessage(err, fmt.Sprintf("could not read config with prefix: %s", prefix)) + } + + prefixLen := len(e.prefix + prefix) + m := make(map[string]string, len(res)) + for k, v := range res { + m[k[prefixLen:]] = v + } + return mapstructure.Decode(m, obj) +} + +// Sub return an sub configer. +func (e *EtcdConfiger) Sub(key string) (config.Configer, error) { + return newEtcdConfiger(e.client, e.prefix+key), nil +} + +// TODO remove this before release v2.0.0 +func (e *EtcdConfiger) OnChange(ctx context.Context, key string, fn func(value string)) { + + buildOptsFunc := func() []clientv3.OpOption { + if opts, ok := ctx.Value(etcdOpts).([]clientv3.OpOption); ok { + opts = append(opts, clientv3.WithCreatedNotify()) + return opts + } + return []clientv3.OpOption{} + } + + rch := e.client.Watch(ctx, e.prefix+key, buildOptsFunc()...) + go func() { + for { + for resp := range rch { + if err := resp.Err(); err != nil { + logs.Error("listen to key but got error callback", err) + break + } + + for _, e := range resp.Events { + if e.Kv == nil { + continue + } + fn(string(e.Kv.Value)) + } + } + time.Sleep(time.Second) + rch = e.client.Watch(ctx, e.prefix+key, buildOptsFunc()...) + } + }() + +} + +type EtcdConfigerProvider struct { +} + +// Parse = ParseData([]byte(key)) +// key must be json +func (provider *EtcdConfigerProvider) Parse(key string) (config.Configer, error) { + return provider.ParseData([]byte(key)) +} + +// ParseData try to parse key as clientv3.Config, using this to build etcdClient +func (provider *EtcdConfigerProvider) ParseData(data []byte) (config.Configer, error) { + cfg := &clientv3.Config{} + err := json.Unmarshal(data, cfg) + if err != nil { + return nil, errors.WithMessage(err, "parse data to etcd config failed, please check your input") + } + + cfg.DialOptions = []grpc.DialOption{ + grpc.WithBlock(), + grpc.WithUnaryInterceptor(grpc_prometheus.UnaryClientInterceptor), + grpc.WithStreamInterceptor(grpc_prometheus.StreamClientInterceptor), + } + client, err := clientv3.New(*cfg) + if err != nil { + return nil, errors.WithMessage(err, "create etcd client failed") + } + + return newEtcdConfiger(client, ""), nil +} + +func get(client *clientv3.Client, ctx context.Context, key string) (*clientv3.GetResponse, error) { + var ( + resp *clientv3.GetResponse + err error + ) + if opts, ok := ctx.Value(etcdOpts).([]clientv3.OpOption); ok { + resp, err = client.Get(ctx, key, opts...) + } else { + resp, err = client.Get(ctx, key) + } + + if err != nil { + return nil, errors.WithMessage(err, fmt.Sprintf("read config from etcd with key %s failed", key)) + } + return resp, err +} + +func WithEtcdOption(ctx context.Context, opts ...clientv3.OpOption) context.Context { + return context.WithValue(ctx, etcdOpts, opts) +} + +func init() { + config.Register("json", &EtcdConfigerProvider{}) +} diff --git a/pkg/infrastructure/config/etcd/config_test.go b/pkg/infrastructure/config/etcd/config_test.go new file mode 100644 index 00000000..a9cadd95 --- /dev/null +++ b/pkg/infrastructure/config/etcd/config_test.go @@ -0,0 +1,123 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package etcd + +import ( + "context" + "encoding/json" + "os" + "testing" + "time" + + "github.com/coreos/etcd/clientv3" + "github.com/stretchr/testify/assert" +) + +func TestWithEtcdOption(t *testing.T) { + ctx := WithEtcdOption(context.Background(), clientv3.WithPrefix()) + assert.NotNil(t, ctx.Value(etcdOpts)) +} + +func TestEtcdConfigerProvider_Parse(t *testing.T) { + provider := &EtcdConfigerProvider{} + cfger, err := provider.Parse(readEtcdConfig()) + assert.Nil(t, err) + assert.NotNil(t, cfger) +} + +func TestEtcdConfiger(t *testing.T) { + + provider := &EtcdConfigerProvider{} + cfger, _ := provider.Parse(readEtcdConfig()) + + subCfger, err := cfger.Sub("sub.") + assert.Nil(t, err) + assert.NotNil(t, subCfger) + + subSubCfger, err := subCfger.Sub("sub.") + assert.NotNil(t, subSubCfger) + assert.Nil(t, err) + + str, err := subSubCfger.String("key1") + assert.Nil(t, err) + assert.Equal(t, "sub.sub.key", str) + + // we cannot test it + subSubCfger.OnChange(context.Background(), "watch", func(value string) { + // do nothing + }) + + defStr := cfger.DefaultString("not_exit", "default value") + assert.Equal(t, "default value", defStr) + + defInt64 := cfger.DefaultInt64("not_exit", -1) + assert.Equal(t, int64(-1), defInt64) + + defInt := cfger.DefaultInt("not_exit", -2) + assert.Equal(t, -2, defInt) + + defFlt := cfger.DefaultFloat("not_exit", 12.3) + assert.Equal(t, 12.3, defFlt) + + defBl := cfger.DefaultBool("not_exit", true) + assert.True(t, defBl) + + defStrs := cfger.DefaultStrings("not_exit", []string{"hello"}) + assert.Equal(t, []string{"hello"}, defStrs) + + fl, err := cfger.Float("current.float") + assert.Nil(t, err) + assert.Equal(t, 1.23, fl) + + bl, err := cfger.Bool("current.bool") + assert.Nil(t, err) + assert.True(t, bl) + + it, err := cfger.Int("current.int") + assert.Nil(t, err) + assert.Equal(t, 11, it) + + str, err = cfger.String("current.string") + assert.Nil(t, err) + assert.Equal(t, "hello", str) + + tn := &TestEntity{} + err = cfger.Unmarshaler(context.Background(), "current.serialize.", tn) + assert.Nil(t, err) + assert.Equal(t, "test", tn.Name) +} + +type TestEntity struct { + Name string `yaml:"name"` + Sub SubEntity `yaml:"sub"` +} + +type SubEntity struct { + SubName string `yaml:"subName"` +} + +func readEtcdConfig() string { + addr := os.Getenv("ETCD_ADDR") + if addr == "" { + addr = "localhost:2379" + } + + obj := clientv3.Config{ + Endpoints: []string{addr}, + DialTimeout: 3 * time.Second, + } + cfg, _ := json.Marshal(obj) + return string(cfg) +} diff --git a/pkg/infrastructure/config/fake.go b/pkg/infrastructure/config/fake.go index ddbc99b8..f885d44d 100644 --- a/pkg/infrastructure/config/fake.go +++ b/pkg/infrastructure/config/fake.go @@ -15,6 +15,7 @@ package config import ( + "context" "errors" "strconv" "strings" @@ -34,34 +35,6 @@ func (c *fakeConfigContainer) Set(key, val string) error { return nil } -func (c *fakeConfigContainer) String(key string) string { - return c.getData(key) -} - -func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval - } - return v -} - -func (c *fakeConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil - } - return strings.Split(v, ";") -} - -func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval - } - return v -} - func (c *fakeConfigContainer) Int(key string) (int, error) { return strconv.Atoi(c.getData(key)) } @@ -129,7 +102,11 @@ var _ Configer = new(fakeConfigContainer) // NewFakeConfig return a fake Configer func NewFakeConfig() Configer { - return &fakeConfigContainer{ + res := &fakeConfigContainer{ data: make(map[string]string), } + res.BaseConfiger = NewBaseConfiger(func(ctx context.Context, key string) (string, error) { + return res.getData(key), nil + }) + return res } diff --git a/pkg/infrastructure/config/ini.go b/pkg/infrastructure/config/ini.go index 0bef67d4..4d3946d5 100644 --- a/pkg/infrastructure/config/ini.go +++ b/pkg/infrastructure/config/ini.go @@ -17,13 +17,14 @@ package config import ( "bufio" "bytes" + "context" "errors" + "fmt" "io" "io/ioutil" "os" "os/user" "path/filepath" - "strconv" "strings" "sync" ) @@ -65,6 +66,9 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e keyComment: make(map[string]string), RWMutex: sync.RWMutex{}, } + cfg.BaseConfiger = NewBaseConfiger(func(ctx context.Context, key string) (string, error) { + return cfg.getdata(key) + }) cfg.Lock() defer cfg.Unlock() @@ -90,7 +94,7 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e break } - //It might be a good idea to throw a error on all unknonw errors? + // It might be a good idea to throw a error on all unknonw errors? if _, ok := err.(*os.PathError); ok { return nil, err } @@ -232,101 +236,6 @@ type IniConfigContainer struct { sync.RWMutex } -// Bool returns the boolean value for a given key. -func (c *IniConfigContainer) Bool(key string) (bool, error) { - return ParseBool(c.getdata(key)) -} - -// DefaultBool returns the boolean value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) - if err != nil { - return defaultval - } - return v -} - -// Int returns the integer value for a given key. -func (c *IniConfigContainer) Int(key string) (int, error) { - return strconv.Atoi(c.getdata(key)) -} - -// DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) - if err != nil { - return defaultval - } - return v -} - -// Int64 returns the int64 value for a given key. -func (c *IniConfigContainer) Int64(key string) (int64, error) { - return strconv.ParseInt(c.getdata(key), 10, 64) -} - -// DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) - if err != nil { - return defaultval - } - return v -} - -// Float returns the float value for a given key. -func (c *IniConfigContainer) Float(key string) (float64, error) { - return strconv.ParseFloat(c.getdata(key), 64) -} - -// DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) - if err != nil { - return defaultval - } - return v -} - -// String returns the string value for a given key. -func (c *IniConfigContainer) String(key string) string { - return c.getdata(key) -} - -// DefaultString returns the string value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval - } - return v -} - -// Strings returns the []string value for a given key. -// Return nil if config value does not exist or is empty. -func (c *IniConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil - } - return strings.Split(v, ";") -} - -// DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval - } - return v -} - // GetSection returns map for the given section func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) { if v, ok := c.data[section]; ok { @@ -474,9 +383,9 @@ func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) { } // section.key or key -func (c *IniConfigContainer) getdata(key string) string { +func (c *IniConfigContainer) getdata(key string) (string, error) { if len(key) == 0 { - return "" + return "", errors.New("the key is empty") } c.RLock() defer c.RUnlock() @@ -494,10 +403,10 @@ func (c *IniConfigContainer) getdata(key string) string { } if v, ok := c.data[section]; ok { if vv, ok := v[k]; ok { - return vv + return vv, nil } } - return "" + return "", errors.New(fmt.Sprintf("config not found: %s", key)) } func init() { diff --git a/pkg/infrastructure/config/ini_test.go b/pkg/infrastructure/config/ini_test.go index ffcdb294..7daa0a6e 100644 --- a/pkg/infrastructure/config/ini_test.go +++ b/pkg/infrastructure/config/ini_test.go @@ -109,9 +109,9 @@ password = ${GOPATH} case bool: value, err = iniconf.Bool(k) case []string: - value = iniconf.Strings(k) + value, err = iniconf.Strings(k) case string: - value = iniconf.String(k) + value, err = iniconf.String(k) default: value, err = iniconf.DIY(k) } @@ -125,7 +125,8 @@ password = ${GOPATH} if err = iniconf.Set("name", "astaxie"); err != nil { t.Fatal(err) } - if iniconf.String("name") != "astaxie" { + res, _ := iniconf.String("name") + if res != "astaxie" { t.Fatal("get name error") } diff --git a/pkg/infrastructure/config/json/json.go b/pkg/infrastructure/config/json/json.go index bd28411f..b552269a 100644 --- a/pkg/infrastructure/config/json/json.go +++ b/pkg/infrastructure/config/json/json.go @@ -158,42 +158,14 @@ func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float } // String returns the string value for a given key. -func (c *JSONConfigContainer) String(key string) string { +func (c *JSONConfigContainer) String(key string) (string, error) { val := c.getData(key) if val != nil { if v, ok := val.(string); ok { - return v + return v, nil } } - return "" -} - -// DefaultString returns the string value for a given key. -// if err != nil return defaultval -func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { - // TODO FIXME should not use "" to replace non existence - if v := c.String(key); v != "" { - return v - } - return defaultval -} - -// Strings returns the []string value for a given key. -func (c *JSONConfigContainer) Strings(key string) []string { - stringVal := c.String(key) - if stringVal == "" { - return nil - } - return strings.Split(c.String(key), ";") -} - -// DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { - if v := c.Strings(key); v != nil { - return v - } - return defaultval + return "", errors.New(fmt.Sprintf("config not found or is not string, key: %s", key)) } // GetSection returns map for the given section diff --git a/pkg/infrastructure/config/json/json_test.go b/pkg/infrastructure/config/json/json_test.go index 75a42145..cf337d20 100644 --- a/pkg/infrastructure/config/json/json_test.go +++ b/pkg/infrastructure/config/json/json_test.go @@ -163,9 +163,9 @@ func TestJson(t *testing.T) { case bool: value, err = jsonconf.Bool(k) case []string: - value = jsonconf.Strings(k) + value, err = jsonconf.Strings(k) case string: - value = jsonconf.String(k) + value, err = jsonconf.String(k) default: value, err = jsonconf.DIY(k) } @@ -179,7 +179,9 @@ func TestJson(t *testing.T) { if err = jsonconf.Set("name", "astaxie"); err != nil { t.Fatal(err) } - if jsonconf.String("name") != "astaxie" { + + res, _ := jsonconf.String("name") + if res != "astaxie" { t.Fatal("get name error") } @@ -210,7 +212,7 @@ func TestJson(t *testing.T) { t.Error("unknown keys should return an error when expecting an interface{}") } - if val := jsonconf.String("unknown"); val != "" { + if val, _ := jsonconf.String("unknown"); val != "" { t.Error("unknown keys should return an empty string when expecting a String") } diff --git a/pkg/infrastructure/config/xml/xml.go b/pkg/infrastructure/config/xml/xml.go index 3413e0a5..c095ef06 100644 --- a/pkg/infrastructure/config/xml/xml.go +++ b/pkg/infrastructure/config/xml/xml.go @@ -26,7 +26,7 @@ // // cnf, err := config.NewConfig("xml", "config.xml") // -//More docs http://beego.me/docs/module/config.md +// More docs http://beego.me/docs/module/config.md package xml import ( @@ -36,11 +36,11 @@ import ( "io/ioutil" "os" "strconv" - "strings" "sync" - "github.com/astaxie/beego/pkg/infrastructure/config" "github.com/beego/x2j" + + "github.com/astaxie/beego/pkg/infrastructure/config" ) // Config is a xml config parser and implements Config interface. @@ -144,37 +144,18 @@ func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { } // String returns the string value for a given key. -func (c *ConfigContainer) String(key string) string { +func (c *ConfigContainer) String(key string) (string, error) { if v, ok := c.data[key].(string); ok { - return v + return v, nil } - return "" + return "", errors.New(fmt.Sprintf("configuration not found or not string, key: %s", key)) } // DefaultString returns the string value for a given key. // if err != nil return defaultval func (c *ConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval - } - return v -} - -// Strings returns the []string value for a given key. -func (c *ConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil - } - return strings.Split(v, ";") -} - -// DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { + v, err := c.String(key) + if v == "" || err != nil { return defaultval } return v diff --git a/pkg/infrastructure/config/xml/xml_test.go b/pkg/infrastructure/config/xml/xml_test.go index 4cd0df1f..0391efab 100644 --- a/pkg/infrastructure/config/xml/xml_test.go +++ b/pkg/infrastructure/config/xml/xml_test.go @@ -25,7 +25,7 @@ import ( func TestXML(t *testing.T) { var ( - //xml parse should incluce in tags + // xml parse should incluce in tags xmlcontext = ` beeapi @@ -102,9 +102,9 @@ func TestXML(t *testing.T) { case bool: value, err = xmlconf.Bool(k) case []string: - value = xmlconf.Strings(k) + value, err = xmlconf.Strings(k) case string: - value = xmlconf.String(k) + value, err = xmlconf.String(k) default: value, err = xmlconf.DIY(k) } @@ -119,7 +119,9 @@ func TestXML(t *testing.T) { if err = xmlconf.Set("name", "astaxie"); err != nil { t.Fatal(err) } - if xmlconf.String("name") != "astaxie" { + + res, _ := xmlconf.String("name") + if res != "astaxie" { t.Fatal("get name error") } } diff --git a/pkg/infrastructure/config/yaml/yaml.go b/pkg/infrastructure/config/yaml/yaml.go index 9a3b698a..96045365 100644 --- a/pkg/infrastructure/config/yaml/yaml.go +++ b/pkg/infrastructure/config/yaml/yaml.go @@ -26,7 +26,7 @@ // // cnf, err := config.NewConfig("yaml", "config.yaml") // -//More docs http://beego.me/docs/module/config.md +// More docs http://beego.me/docs/module/config.md package yaml import ( @@ -40,8 +40,9 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/infrastructure/config" "github.com/beego/goyaml2" + + "github.com/astaxie/beego/pkg/infrastructure/config" ) // Config is a yaml config parser and implements Config interface. @@ -209,39 +210,41 @@ func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { } // String returns the string value for a given key. -func (c *ConfigContainer) String(key string) string { +func (c *ConfigContainer) String(key string) (string, error) { if v, err := c.getData(key); err == nil { if vv, ok := v.(string); ok { - return vv + return vv, nil + } else { + return "", errors.New(fmt.Sprintf("the value is not string, key: %s, value: %v", key, v)) } } - return "" + return "", errors.New(fmt.Sprintf("configuration not found, key: %s", key)) } // DefaultString returns the string value for a given key. // if err != nil return defaultval func (c *ConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { + v, err := c.String(key) + if v == "" || err != nil { return defaultval } return v } // Strings returns the []string value for a given key. -func (c *ConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil +func (c *ConfigContainer) Strings(key string) ([]string, error) { + v, err := c.String(key) + if v == "" || err != nil { + return nil, err } - return strings.Split(v, ";") + return strings.Split(v, ";"), nil } // DefaultStrings returns the []string value for a given key. // if err != nil return defaultval func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { + v, err := c.Strings(key) + if v == nil || err != nil { return defaultval } return v diff --git a/pkg/infrastructure/config/yaml/yaml_test.go b/pkg/infrastructure/config/yaml/yaml_test.go index 2437d6c7..0fa8bc7b 100644 --- a/pkg/infrastructure/config/yaml/yaml_test.go +++ b/pkg/infrastructure/config/yaml/yaml_test.go @@ -70,7 +70,8 @@ func TestYaml(t *testing.T) { t.Fatal(err) } - if yamlconf.String("appname") != "beeapi" { + res, _ := yamlconf.String("appname") + if res != "beeapi" { t.Fatal("appname not equal to beeapi") } @@ -91,9 +92,9 @@ func TestYaml(t *testing.T) { case bool: value, err = yamlconf.Bool(k) case []string: - value = yamlconf.Strings(k) + value, err = yamlconf.Strings(k) case string: - value = yamlconf.String(k) + value, err = yamlconf.String(k) default: value, err = yamlconf.DIY(k) } @@ -108,7 +109,8 @@ func TestYaml(t *testing.T) { if err = yamlconf.Set("name", "astaxie"); err != nil { t.Fatal(err) } - if yamlconf.String("name") != "astaxie" { + res, _ = yamlconf.String("name") + if res != "astaxie" { t.Fatal("get name error") } diff --git a/scripts/prepare_etcd.sh b/scripts/prepare_etcd.sh new file mode 100644 index 00000000..a65f00a3 --- /dev/null +++ b/scripts/prepare_etcd.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +etcdctl put current.float 1.23 +etcdctl put current.bool true +etcdctl put current.int 11 +etcdctl put current.string hello +etcdctl put current.serialize.name test \ No newline at end of file diff --git a/scripts/test_docker_compose.yaml b/scripts/test_docker_compose.yaml index 54ca4097..f22b6deb 100644 --- a/scripts/test_docker_compose.yaml +++ b/scripts/test_docker_compose.yaml @@ -36,4 +36,20 @@ services: image: memcached ports: - "11211:11211" - + etcd: + command: > + sh -c " + etcdctl put current.float 1.23 + && etcdctl put current.bool true + && etcdctl put current.int 11 + && etcdctl put current.string hello + && etcdctl put current.serialize.name test + " + container_name: "beego-etcd" + environment: + - ALLOW_NONE_AUTHENTICATION=yes +# - ETCD_ADVERTISE_CLIENT_URLS=http://etcd:2379 + image: bitnami/etcd + ports: + - "2379:2379" + - "2380:2380" \ No newline at end of file From 2b39ff78374f3b99c4c874b7cf71b1f50e058e7e Mon Sep 17 00:00:00 2001 From: IamCathal Date: Fri, 28 Aug 2020 18:00:45 +0100 Subject: [PATCH 122/207] New opts formatter working for console --- pkg/logs/conn.go | 32 ++++++++++++++++---------------- pkg/logs/console.go | 31 +++++++++++++++++++------------ pkg/logs/file.go | 16 +++++++++------- pkg/logs/jianliao.go | 18 ++++++++++-------- pkg/logs/log.go | 36 ++++++++++++++++++++++++++---------- pkg/logs/multifile.go | 22 ++++++++++++---------- pkg/logs/slack.go | 17 +++++++++-------- pkg/logs/smtp.go | 18 ++++++++++-------- 8 files changed, 111 insertions(+), 79 deletions(-) diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go index e11909a0..55cbecdd 100644 --- a/pkg/logs/conn.go +++ b/pkg/logs/conn.go @@ -18,20 +18,20 @@ import ( "encoding/json" "io" "net" + + "github.com/astaxie/beego/pkg/common" ) // connWriter implements LoggerInterface. // Writes messages in keep-live tcp connection. type connWriter struct { - lg *logWriter - innerWriter io.WriteCloser - UseCustomFormatter bool - CustomFormatter func(*LogMsg) string - ReconnectOnMsg bool `json:"reconnectOnMsg"` - Reconnect bool `json:"reconnect"` - Net string `json:"net"` - Addr string `json:"addr"` - Level int `json:"level"` + lg *logWriter + innerWriter io.WriteCloser + ReconnectOnMsg bool `json:"reconnectOnMsg"` + Reconnect bool `json:"reconnect"` + Net string `json:"net"` + Addr string `json:"addr"` + Level int `json:"level"` } // NewConn creates new ConnWrite returning as LoggerInterface. @@ -47,13 +47,13 @@ func (c *connWriter) Format(lm *LogMsg) string { // Init initializes a connection writer with json config. // json config only needs they "level" key -func (c *connWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - c.UseCustomFormatter = true - c.CustomFormatter = elem - } - } +func (c *connWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { + // for _, elem := range LogFormatter { + // if elem != nil { + // c.UseCustomFormatter = true + // c.CustomFormatter = elem + // } + // } return json.Unmarshal([]byte(jsonConfig), c) } diff --git a/pkg/logs/console.go b/pkg/logs/console.go index a928de7d..55958008 100644 --- a/pkg/logs/console.go +++ b/pkg/logs/console.go @@ -19,6 +19,8 @@ import ( "os" "strings" + "github.com/astaxie/beego/pkg/common" + "github.com/shiena/ansicolor" ) @@ -47,11 +49,10 @@ var colors = []brush{ // consoleWriter implements LoggerInterface and writes messages to terminal. type consoleWriter struct { - lg *logWriter - UseCustomFormatter bool - CustomFormatter func(*LogMsg) string - Level int `json:"level"` - Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color + lg *logWriter + customFormatter func(*LogMsg) string + Level int `json:"level"` + Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color } func (c *consoleWriter) Format(lm *LogMsg) string { @@ -80,11 +81,16 @@ func NewConsole() Logger { // Init initianlizes the console logger. // jsonConfig must be in the format '{"level":LevelTrace}' -func (c *consoleWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - c.UseCustomFormatter = true - c.CustomFormatter = elem +// func (c *consoleWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { +func (c *consoleWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { + + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := GetFormatter(elem) + if err != nil { + return err + } + c.customFormatter = formatter } } @@ -107,10 +113,11 @@ func (c *consoleWriter) WriteMsg(lm *LogMsg) error { msg := "" - if c.UseCustomFormatter { - msg = c.CustomFormatter(lm) + if c.customFormatter != nil { + msg = c.customFormatter(lm) } else { msg = c.Format(lm) + } c.lg.writeln(msg) diff --git a/pkg/logs/file.go b/pkg/logs/file.go index 4576e19d..0324486e 100644 --- a/pkg/logs/file.go +++ b/pkg/logs/file.go @@ -27,6 +27,8 @@ import ( "strings" "sync" "time" + + "github.com/astaxie/beego/pkg/common" ) // fileLogWriter implements LoggerInterface. @@ -107,13 +109,13 @@ func (w *fileLogWriter) Format(lm *LogMsg) string { // "rotate":true, // "perm":"0600" // } -func (w *fileLogWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - w.UseCustomFormatter = true - w.CustomFormatter = elem - } - } +func (w *fileLogWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { + // for _, elem := range LogFormatter { + // if elem != nil { + // w.UseCustomFormatter = true + // w.CustomFormatter = elem + // } + // } err := json.Unmarshal([]byte(jsonConfig), w) if err != nil { diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go index 9877bed6..8daa8015 100644 --- a/pkg/logs/jianliao.go +++ b/pkg/logs/jianliao.go @@ -5,6 +5,8 @@ import ( "fmt" "net/http" "net/url" + + "github.com/astaxie/beego/pkg/common" ) // JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook @@ -25,15 +27,15 @@ func newJLWriter() Logger { } // Init JLWriter with json config string -func (s *JLWriter) Init(jsonconfig string, LogFormatter ...func(*LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - s.UseCustomFormatter = true - s.CustomFormatter = elem - } - } +func (s *JLWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { + // for _, elem := range LogFormatter { + // if elem != nil { + // s.UseCustomFormatter = true + // s.CustomFormatter = elem + // } + // } - return json.Unmarshal([]byte(jsonconfig), s) + return json.Unmarshal([]byte(jsonConfig), s) } func (s *JLWriter) Format(lm *LogMsg) string { diff --git a/pkg/logs/log.go b/pkg/logs/log.go index fd8fca63..9529c865 100644 --- a/pkg/logs/log.go +++ b/pkg/logs/log.go @@ -38,10 +38,13 @@ import ( "log" "os" "path" + "reflect" "runtime" "strings" "sync" "time" + + "github.com/astaxie/beego/pkg/common" ) // RFC5424 log message levels. @@ -84,7 +87,7 @@ type newLoggerFunc func() Logger // Logger defines the behavior of a log provider. type Logger interface { - Init(config string, LogFormatter ...func(*LogMsg) string) error + Init(config string, opts ...common.SimpleKV) error WriteMsg(lm *LogMsg) error Format(lm *LogMsg) string Destroy() @@ -210,7 +213,7 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { // Global formatter overrides the default set formatter // but not adapter specific formatters set with logs.SetLoggerWithOpts() if bl.globalFormatter != nil { - err = lg.Init(config, bl.globalFormatter) + err = lg.Init(config) } else { err = lg.Init(config) } @@ -401,9 +404,21 @@ func (bl *BeeLogger) startLogger() { } } +// Get the formatter from the opts common.SimpleKV structure +// Looks for a key: "formatter" with value: func(*LogMsg) string +func GetFormatter(opts common.SimpleKV) (func(*LogMsg) string, error) { + if strings.ToLower(opts.Key.(string)) == "formatter" { + formatterInterface := reflect.ValueOf(opts.Value).Interface() + formatterFunc := formatterInterface.(func(*LogMsg) string) + return formatterFunc, nil + } + + return nil, fmt.Errorf("no \"formatter\" key given in simpleKV") +} + // SetLoggerWithOpts sets a log adapter with a user defined logging format. Config must be valid JSON // such as: {"interval":360} -func (bl *BeeLogger) setLoggerWithOpts(adapterName string, formatterFunc func(*LogMsg) string, configs ...string) error { +func (bl *BeeLogger) setLoggerWithOpts(adapterName string, opts common.SimpleKV, configs ...string) error { config := append(configs, "{}")[0] for _, l := range bl.outputs { if l.name == adapterName { @@ -416,12 +431,12 @@ func (bl *BeeLogger) setLoggerWithOpts(adapterName string, formatterFunc func(*L return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) } - if formatterFunc == nil { - return fmt.Errorf("No formatter set for %s log adapter", adapterName) + if opts.Key == nil { + return fmt.Errorf("No SimpleKV struct set for %s log adapter", adapterName) } lg := logAdapter() - err := lg.Init(config, formatterFunc) + err := lg.Init(config, opts) if err != nil { fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) return err @@ -436,21 +451,22 @@ func (bl *BeeLogger) setLoggerWithOpts(adapterName string, formatterFunc func(*L } // SetLogger provides a given logger adapter into BeeLogger with config string. -func (bl *BeeLogger) SetLoggerWithOpts(adapterName string, formatterFunc func(*LogMsg) string, configs ...string) error { +func (bl *BeeLogger) SetLoggerWithOpts(adapterName string, opts common.SimpleKV, configs ...string) error { bl.lock.Lock() defer bl.lock.Unlock() if !bl.init { bl.outputs = []*nameLogger{} bl.init = true } - return bl.setLoggerWithOpts(adapterName, formatterFunc, configs...) + return bl.setLoggerWithOpts(adapterName, opts, configs...) } // SetLoggerWIthOpts sets a given log adapter with a custom log adapter. // Log Adapter must be given in the form common.SimpleKV{Key: "formatter": Value: struct.FormatFunc} // where FormatFunc has the signature func(*LogMsg) string -func SetLoggerWithOpts(adapter string, config []string, formatterFunc func(*LogMsg) string) error { - err := beeLogger.SetLoggerWithOpts(adapter, formatterFunc, config...) +// func SetLoggerWithOpts(adapter string, config []string, formatterFunc func(*LogMsg) string) error { +func SetLoggerWithOpts(adapter string, config []string, opts common.SimpleKV) error { + err := beeLogger.SetLoggerWithOpts(adapter, opts, config...) if err != nil { log.Fatal(err) } diff --git a/pkg/logs/multifile.go b/pkg/logs/multifile.go index bcd4dd4e..720f5125 100644 --- a/pkg/logs/multifile.go +++ b/pkg/logs/multifile.go @@ -16,6 +16,8 @@ package logs import ( "encoding/json" + + "github.com/astaxie/beego/pkg/common" ) // A filesLogWriter manages several fileLogWriter @@ -46,16 +48,16 @@ var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning // "separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"], // } -func (f *multiFileLogWriter) Init(config string, LogFormatter ...func(*LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - f.UseCustomFormatter = true - f.CustomFormatter = elem - } - } +func (f *multiFileLogWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { + // for _, elem := range LogFormatter { + // if elem != nil { + // f.UseCustomFormatter = true + // f.CustomFormatter = elem + // } + // } writer := newFileWriter().(*fileLogWriter) - err := writer.Init(config) + err := writer.Init(jsonConfig) if err != nil { return err } @@ -63,10 +65,10 @@ func (f *multiFileLogWriter) Init(config string, LogFormatter ...func(*LogMsg) s f.writers[LevelDebug+1] = writer //unmarshal "separate" field to f.Separate - json.Unmarshal([]byte(config), f) + json.Unmarshal([]byte(jsonConfig), f) jsonMap := map[string]interface{}{} - json.Unmarshal([]byte(config), &jsonMap) + json.Unmarshal([]byte(jsonConfig), &jsonMap) for i := LevelEmergency; i < LevelDebug+1; i++ { for _, v := range f.Separate { diff --git a/pkg/logs/slack.go b/pkg/logs/slack.go index 9407b48a..0fc75149 100644 --- a/pkg/logs/slack.go +++ b/pkg/logs/slack.go @@ -5,6 +5,8 @@ import ( "fmt" "net/http" "net/url" + + "github.com/astaxie/beego/pkg/common" ) // SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook @@ -25,15 +27,14 @@ func (s *SLACKWriter) Format(lm *LogMsg) string { } // Init SLACKWriter with json config string -func (s *SLACKWriter) Init(jsonconfig string, LogFormatter ...func(*LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - s.UseCustomFormatter = true - s.CustomFormatter = elem - } - } +func (s *SLACKWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { + // if elem != nil { + // s.UseCustomFormatter = true + // s.CustomFormatter = elem + // } + // } - return json.Unmarshal([]byte(jsonconfig), s) + return json.Unmarshal([]byte(jsonConfig), s) } // WriteMsg write message in smtp writer. diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go index b81be68f..17148812 100644 --- a/pkg/logs/smtp.go +++ b/pkg/logs/smtp.go @@ -21,6 +21,8 @@ import ( "net" "net/smtp" "strings" + + "github.com/astaxie/beego/pkg/common" ) // SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server. @@ -52,15 +54,15 @@ func newSMTPWriter() Logger { // "sendTos":["email1","email2"], // "level":LevelError // } -func (s *SMTPWriter) Init(jsonconfig string, LogFormatter ...func(*LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - s.UseCustomFormatter = true - s.CustomFormatter = elem - } - } +func (s *SMTPWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { + // for _, elem := range LogFormatter { + // if elem != nil { + // s.UseCustomFormatter = true + // s.CustomFormatter = elem + // } + // } - return json.Unmarshal([]byte(jsonconfig), s) + return json.Unmarshal([]byte(jsonConfig), s) } func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { From 8178f035a08231ec7a04ab7a825ef1cacffac4d4 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Fri, 28 Aug 2020 18:18:28 +0100 Subject: [PATCH 123/207] Custom formatting opts implementation --- pkg/logs/alils/alils.go | 36 ++++++++++++++++++++---------------- pkg/logs/conn.go | 40 ++++++++++++++++++++++++++-------------- pkg/logs/console.go | 1 - pkg/logs/es/es.go | 28 ++++++++++++++++------------ pkg/logs/file.go | 24 ++++++++++++++---------- pkg/logs/jianliao.go | 40 +++++++++++++++++++++++++--------------- pkg/logs/multifile.go | 24 +++++++++++++----------- pkg/logs/smtp.go | 19 +++++++++++-------- 8 files changed, 125 insertions(+), 87 deletions(-) diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go index 183d9b24..425071f8 100644 --- a/pkg/logs/alils/alils.go +++ b/pkg/logs/alils/alils.go @@ -5,6 +5,7 @@ import ( "strings" "sync" + "github.com/astaxie/beego/pkg/common" "github.com/astaxie/beego/pkg/logs" "github.com/gogo/protobuf/proto" ) @@ -32,13 +33,12 @@ type Config struct { // aliLSWriter implements LoggerInterface. // Writes messages in keep-live tcp connection. type aliLSWriter struct { - store *LogStore - group []*LogGroup - withMap bool - groupMap map[string]*LogGroup - lock *sync.Mutex - UseCustomFormatter bool - CustomFormatter func(*logs.LogMsg) string + store *LogStore + group []*LogGroup + withMap bool + groupMap map[string]*LogGroup + lock *sync.Mutex + customFormatter func(*logs.LogMsg) string Config } @@ -50,15 +50,17 @@ func NewAliLS() logs.Logger { } // Init parses config and initializes struct -func (c *aliLSWriter) Init(jsonConfig string, LogFormatter ...func(*logs.LogMsg) string) (err error) { +func (c *aliLSWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { - for _, elem := range LogFormatter { - if elem != nil { - c.UseCustomFormatter = true - c.CustomFormatter = elem + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := logs.GetFormatter(elem) + if err != nil { + return err + } + c.customFormatter = formatter } } - json.Unmarshal([]byte(jsonConfig), c) if c.FlushWhen > CacheSize { @@ -72,11 +74,13 @@ func (c *aliLSWriter) Init(jsonConfig string, LogFormatter ...func(*logs.LogMsg) AccessKeySecret: c.KeySecret, } - c.store, err = prj.GetLogStore(c.LogStore) + store, err := prj.GetLogStore(c.LogStore) if err != nil { return err } + c.store = store + // Create default Log Group c.group = append(c.group, &LogGroup{ Topic: proto.String(""), @@ -141,8 +145,8 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { lg = c.group[0] } - if c.UseCustomFormatter { - content = c.CustomFormatter(lm) + if c.customFormatter != nil { + content = c.customFormatter(lm) } else { content = c.Format(lm) } diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go index 55cbecdd..9a520bda 100644 --- a/pkg/logs/conn.go +++ b/pkg/logs/conn.go @@ -25,13 +25,14 @@ import ( // connWriter implements LoggerInterface. // Writes messages in keep-live tcp connection. type connWriter struct { - lg *logWriter - innerWriter io.WriteCloser - ReconnectOnMsg bool `json:"reconnectOnMsg"` - Reconnect bool `json:"reconnect"` - Net string `json:"net"` - Addr string `json:"addr"` - Level int `json:"level"` + lg *logWriter + innerWriter io.WriteCloser + customFormatter func(*LogMsg) string + ReconnectOnMsg bool `json:"reconnectOnMsg"` + Reconnect bool `json:"reconnect"` + Net string `json:"net"` + Addr string `json:"addr"` + Level int `json:"level"` } // NewConn creates new ConnWrite returning as LoggerInterface. @@ -48,12 +49,16 @@ func (c *connWriter) Format(lm *LogMsg) string { // Init initializes a connection writer with json config. // json config only needs they "level" key func (c *connWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { - // for _, elem := range LogFormatter { - // if elem != nil { - // c.UseCustomFormatter = true - // c.CustomFormatter = elem - // } - // } + + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := GetFormatter(elem) + if err != nil { + return err + } + c.customFormatter = formatter + } + } return json.Unmarshal([]byte(jsonConfig), c) } @@ -75,7 +80,14 @@ func (c *connWriter) WriteMsg(lm *LogMsg) error { defer c.innerWriter.Close() } - msg := c.Format(lm) + msg := "" + if c.customFormatter != nil { + msg = c.customFormatter(lm) + } else { + msg = c.Format(lm) + + } + _, err := c.lg.writeln(msg) if err != nil { return err diff --git a/pkg/logs/console.go b/pkg/logs/console.go index 55958008..34114e4a 100644 --- a/pkg/logs/console.go +++ b/pkg/logs/console.go @@ -81,7 +81,6 @@ func NewConsole() Logger { // Init initianlizes the console logger. // jsonConfig must be in the format '{"level":LevelTrace}' -// func (c *consoleWriter) Init(jsonConfig string, LogFormatter ...func(*LogMsg) string) error { func (c *consoleWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { for _, elem := range opts { diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go index 4dfc4160..dc9304c8 100644 --- a/pkg/logs/es/es.go +++ b/pkg/logs/es/es.go @@ -12,6 +12,7 @@ import ( "github.com/elastic/go-elasticsearch/v6" "github.com/elastic/go-elasticsearch/v6/esapi" + "github.com/astaxie/beego/pkg/common" "github.com/astaxie/beego/pkg/logs" ) @@ -31,10 +32,9 @@ func NewES() logs.Logger { // import _ "github.com/astaxie/beego/logs/es" type esLogger struct { *elasticsearch.Client - DSN string `json:"dsn"` - Level int `json:"level"` - UseCustomFormatter bool - CustomFormatter func(*logs.LogMsg) string + DSN string `json:"dsn"` + Level int `json:"level"` + customFormatter func(*logs.LogMsg) string } func (el *esLogger) Format(lm *logs.LogMsg) string { @@ -42,15 +42,19 @@ func (el *esLogger) Format(lm *logs.LogMsg) string { } // {"dsn":"http://localhost:9200/","level":1} -func (el *esLogger) Init(jsonconfig string, LogFormatter ...func(*logs.LogMsg) string) error { - for _, elem := range LogFormatter { - if elem != nil { - el.UseCustomFormatter = true - el.CustomFormatter = elem +func (el *esLogger) Init(jsonConfig string, opts ...common.SimpleKV) error { + + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := logs.GetFormatter(elem) + if err != nil { + return err + } + el.customFormatter = formatter } } - err := json.Unmarshal([]byte(jsonconfig), el) + err := json.Unmarshal([]byte(jsonConfig), el) if err != nil { return err } @@ -79,8 +83,8 @@ func (el *esLogger) WriteMsg(lm *logs.LogMsg) error { } msg := "" - if el.UseCustomFormatter { - msg = el.CustomFormatter(lm) + if el.customFormatter != nil { + msg = el.customFormatter(lm) } else { msg = el.Format(lm) } diff --git a/pkg/logs/file.go b/pkg/logs/file.go index 0324486e..42148c3a 100644 --- a/pkg/logs/file.go +++ b/pkg/logs/file.go @@ -62,8 +62,7 @@ type fileLogWriter struct { hourlyOpenDate int hourlyOpenTime time.Time - UseCustomFormatter bool - CustomFormatter func(*LogMsg) string + customFormatter func(*LogMsg) string Rotate bool `json:"rotate"` @@ -110,12 +109,16 @@ func (w *fileLogWriter) Format(lm *LogMsg) string { // "perm":"0600" // } func (w *fileLogWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { - // for _, elem := range LogFormatter { - // if elem != nil { - // w.UseCustomFormatter = true - // w.CustomFormatter = elem - // } - // } + + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := GetFormatter(elem) + if err != nil { + return err + } + w.customFormatter = formatter + } + } err := json.Unmarshal([]byte(jsonConfig), w) if err != nil { @@ -166,8 +169,9 @@ func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { } hd, d, h := formatTimeHeader(lm.When) msg := "" - if w.UseCustomFormatter { - msg = w.CustomFormatter(lm) + + if w.customFormatter != nil { + msg = w.customFormatter(lm) } else { msg = w.Format(lm) } diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go index 8daa8015..81d0195b 100644 --- a/pkg/logs/jianliao.go +++ b/pkg/logs/jianliao.go @@ -11,14 +11,13 @@ import ( // JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook type JLWriter struct { - AuthorName string `json:"authorname"` - Title string `json:"title"` - WebhookURL string `json:"webhookurl"` - RedirectURL string `json:"redirecturl,omitempty"` - ImageURL string `json:"imageurl,omitempty"` - Level int `json:"level"` - UseCustomFormatter bool - CustomFormatter func(*LogMsg) string + AuthorName string `json:"authorname"` + Title string `json:"title"` + WebhookURL string `json:"webhookurl"` + RedirectURL string `json:"redirecturl,omitempty"` + ImageURL string `json:"imageurl,omitempty"` + Level int `json:"level"` + customFormatter func(*LogMsg) string } // newJLWriter creates jiaoliao writer. @@ -28,12 +27,15 @@ func newJLWriter() Logger { // Init JLWriter with json config string func (s *JLWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { - // for _, elem := range LogFormatter { - // if elem != nil { - // s.UseCustomFormatter = true - // s.CustomFormatter = elem - // } - // } + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := GetFormatter(elem) + if err != nil { + return err + } + s.customFormatter = formatter + } + } return json.Unmarshal([]byte(jsonConfig), s) } @@ -49,7 +51,15 @@ func (s *JLWriter) WriteMsg(lm *LogMsg) error { return nil } - text := fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), s.Format(lm)) + text := "" + + if s.customFormatter != nil { + text = fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), s.customFormatter(lm)) + } else { + text = fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), s.Format(lm)) + + } + form := url.Values{} form.Add("authorName", s.AuthorName) form.Add("title", s.Title) diff --git a/pkg/logs/multifile.go b/pkg/logs/multifile.go index 720f5125..c1b7cfdd 100644 --- a/pkg/logs/multifile.go +++ b/pkg/logs/multifile.go @@ -26,11 +26,10 @@ import ( // and write the error-level logs to project.error.log and write the debug-level logs to project.debug.log // the rotate attribute also acts like fileLogWriter type multiFileLogWriter struct { - writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter - fullLogWriter *fileLogWriter - Separate []string `json:"separate"` - UseCustomFormatter bool - CustomFormatter func(*LogMsg) string + writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter + fullLogWriter *fileLogWriter + Separate []string `json:"separate"` + customFormatter func(*LogMsg) string } var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"} @@ -49,12 +48,15 @@ var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning // } func (f *multiFileLogWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { - // for _, elem := range LogFormatter { - // if elem != nil { - // f.UseCustomFormatter = true - // f.CustomFormatter = elem - // } - // } + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := GetFormatter(elem) + if err != nil { + return err + } + f.customFormatter = formatter + } + } writer := newFileWriter().(*fileLogWriter) err := writer.Init(jsonConfig) diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go index 17148812..9b67e343 100644 --- a/pkg/logs/smtp.go +++ b/pkg/logs/smtp.go @@ -34,8 +34,7 @@ type SMTPWriter struct { FromAddress string `json:"fromAddress"` RecipientAddresses []string `json:"sendTos"` Level int `json:"level"` - UseCustomFormatter bool - CustomFormatter func(*LogMsg) string + customFormatter func(*LogMsg) string } // NewSMTPWriter creates the smtp writer. @@ -55,12 +54,16 @@ func newSMTPWriter() Logger { // "level":LevelError // } func (s *SMTPWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { - // for _, elem := range LogFormatter { - // if elem != nil { - // s.UseCustomFormatter = true - // s.CustomFormatter = elem - // } - // } + + for _, elem := range opts { + if elem.Key == "formatter" { + formatter, err := GetFormatter(elem) + if err != nil { + return err + } + s.customFormatter = formatter + } + } return json.Unmarshal([]byte(jsonConfig), s) } From 0189e6329a4e1700ea589de4eebe29e8624b422d Mon Sep 17 00:00:00 2001 From: IamCathal Date: Fri, 28 Aug 2020 18:47:28 +0100 Subject: [PATCH 126/207] Add global logging override --- pkg/logs/log.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/logs/log.go b/pkg/logs/log.go index 9529c865..e18ea95b 100644 --- a/pkg/logs/log.go +++ b/pkg/logs/log.go @@ -213,7 +213,7 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { // Global formatter overrides the default set formatter // but not adapter specific formatters set with logs.SetLoggerWithOpts() if bl.globalFormatter != nil { - err = lg.Init(config) + err = lg.Init(config, common.SimpleKV{Key: "formatter", Value: bl.globalFormatter}) } else { err = lg.Init(config) } From 81b9a1382a65fd4d663dfbb6b1b5f53d5ad38705 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 29 Aug 2020 01:17:43 +0800 Subject: [PATCH 127/207] Fix UT --- .travis.yml | 42 +++++++--- pkg/client/httplib/testing/client.go | 5 +- pkg/infrastructure/config/config.go | 4 +- pkg/infrastructure/config/ini.go | 111 ++++++++++++++++++++++--- pkg/infrastructure/config/json/json.go | 30 ++++++- pkg/infrastructure/config/xml/xml.go | 27 +++++- pkg/infrastructure/config/yaml/yaml.go | 4 +- pkg/server/web/config.go | 40 ++++----- pkg/server/web/hooks.go | 4 +- pkg/server/web/templatefunc.go | 2 +- scripts/prepare_etcd.sh | 3 +- 11 files changed, 215 insertions(+), 57 deletions(-) diff --git a/.travis.yml b/.travis.yml index f3f1b576..67efe057 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,7 +7,7 @@ services: - mysql - postgresql - memcached - - etcd + - docker env: global: - GO_REPO_FULLNAME="github.com/astaxie/beego" @@ -27,17 +27,33 @@ before_install: - cd ssdb - make - cd .. + # - prepare etcd # - prepare for etcd unit tests - - git clone https://github.com/etcd-io/etcd.git - - cd etcd - - ./build - - ./bin/etcd - - ./bin/etcdctl put current.float 1.23 - - ./bin/etcdctl put current.bool true - - ./bin/etcdctl put current.int 11 - - ./bin/etcdctl put current.string hello - - ./bin/etcdctl put current.serialize.name test - - cd .. + - rm -rf /tmp/etcd-data.tmp + - mkdir -p /tmp/etcd-data.tmp + - docker rmi gcr.io/etcd-development/etcd:v3.3.25 || true && + docker run -d + -p 2379:2379 + -p 2380:2380 + --mount type=bind,source=/tmp/etcd-data.tmp,destination=/etcd-data + --name etcd-gcr-v3.3.25 + gcr.io/etcd-development/etcd:v3.3.25 + /usr/local/bin/etcd + --name s1 + --data-dir /etcd-data + --listen-client-urls http://0.0.0.0:2379 + --advertise-client-urls http://0.0.0.0:2379 + --listen-peer-urls http://0.0.0.0:2380 + --initial-advertise-peer-urls http://0.0.0.0:2380 + --initial-cluster s1=http://0.0.0.0:2380 + --initial-cluster-token tkn + --initial-cluster-state new + - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.float 1.23" + - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.bool true" + - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.int 11" + - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.string hello" + - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.serialize.name test" + - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put sub.sub.key1 sub.sub.key" install: - go get github.com/lib/pq - go get github.com/go-sql-driver/mysql @@ -64,6 +80,8 @@ install: - go get -u golang.org/x/lint/golint - go get -u github.com/go-redis/redis before_script: + + # - - psql --version # - prepare for orm unit tests - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" @@ -84,4 +102,4 @@ script: - find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s - golint ./... addons: - postgresql: "9.6" + postgresql: "9.6" \ No newline at end of file diff --git a/pkg/client/httplib/testing/client.go b/pkg/client/httplib/testing/client.go index 19e6cd23..863ed0e8 100644 --- a/pkg/client/httplib/testing/client.go +++ b/pkg/client/httplib/testing/client.go @@ -34,7 +34,10 @@ func getPort() string { if err != nil { return "8080" } - port = config.String("httpport") + port, err = config.String("httpport") + if err != nil { + return "8080" + } return port } return port diff --git a/pkg/infrastructure/config/config.go b/pkg/infrastructure/config/config.go index 3514e425..c7f45469 100644 --- a/pkg/infrastructure/config/config.go +++ b/pkg/infrastructure/config/config.go @@ -131,7 +131,7 @@ func (c *BaseConfiger) Float(key string) (float64, error) { // DefaultString returns the string value for a given key. // if err != nil or value is empty return defaultval func (c *BaseConfiger) DefaultString(key string, defaultVal string) string { - if res, err := c.String(key); res != "" && err != nil { + if res, err := c.String(key); res != "" && err == nil { return res } return defaultVal @@ -140,7 +140,7 @@ func (c *BaseConfiger) DefaultString(key string, defaultVal string) string { // DefaultStrings returns the []string value for a given key. // if err != nil return defaultval func (c *BaseConfiger) DefaultStrings(key string, defaultVal []string) []string { - if res, err := c.Strings(key); len(res) > 0 && err != nil { + if res, err := c.Strings(key); len(res) > 0 && err == nil { return res } return defaultVal diff --git a/pkg/infrastructure/config/ini.go b/pkg/infrastructure/config/ini.go index 4d3946d5..2338b3cf 100644 --- a/pkg/infrastructure/config/ini.go +++ b/pkg/infrastructure/config/ini.go @@ -17,14 +17,13 @@ package config import ( "bufio" "bytes" - "context" "errors" - "fmt" "io" "io/ioutil" "os" "os/user" "path/filepath" + "strconv" "strings" "sync" ) @@ -66,9 +65,6 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e keyComment: make(map[string]string), RWMutex: sync.RWMutex{}, } - cfg.BaseConfiger = NewBaseConfiger(func(ctx context.Context, key string) (string, error) { - return cfg.getdata(key) - }) cfg.Lock() defer cfg.Unlock() @@ -94,7 +90,7 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e break } - // It might be a good idea to throw a error on all unknonw errors? + //It might be a good idea to throw a error on all unknonw errors? if _, ok := err.(*os.PathError); ok { return nil, err } @@ -236,6 +232,101 @@ type IniConfigContainer struct { sync.RWMutex } +// Bool returns the boolean value for a given key. +func (c *IniConfigContainer) Bool(key string) (bool, error) { + return ParseBool(c.getdata(key)) +} + +// DefaultBool returns the boolean value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +// Int returns the integer value for a given key. +func (c *IniConfigContainer) Int(key string) (int, error) { + return strconv.Atoi(c.getdata(key)) +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +// Int64 returns the int64 value for a given key. +func (c *IniConfigContainer) Int64(key string) (int64, error) { + return strconv.ParseInt(c.getdata(key), 10, 64) +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v +} + +// Float returns the float value for a given key. +func (c *IniConfigContainer) Float(key string) (float64, error) { + return strconv.ParseFloat(c.getdata(key), 64) +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +// String returns the string value for a given key. +func (c *IniConfigContainer) String(key string) (string, error) { + return c.getdata(key), nil +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { + v, err := c.String(key) + if v == "" || err != nil { + return defaultval + } + return v +} + +// Strings returns the []string value for a given key. +// Return nil if config value does not exist or is empty. +func (c *IniConfigContainer) Strings(key string) ([]string, error) { + v, err := c.String(key) + if v == "" || err != nil { + return nil, err + } + return strings.Split(v, ";"), nil +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v, err := c.Strings(key) + if v == nil || err != nil { + return defaultval + } + return v +} + // GetSection returns map for the given section func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) { if v, ok := c.data[section]; ok { @@ -383,9 +474,9 @@ func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) { } // section.key or key -func (c *IniConfigContainer) getdata(key string) (string, error) { +func (c *IniConfigContainer) getdata(key string) string { if len(key) == 0 { - return "", errors.New("the key is empty") + return "" } c.RLock() defer c.RUnlock() @@ -403,10 +494,10 @@ func (c *IniConfigContainer) getdata(key string) (string, error) { } if v, ok := c.data[section]; ok { if vv, ok := v[k]; ok { - return vv, nil + return vv } } - return "", errors.New(fmt.Sprintf("config not found: %s", key)) + return "" } func init() { diff --git a/pkg/infrastructure/config/json/json.go b/pkg/infrastructure/config/json/json.go index b552269a..975e1523 100644 --- a/pkg/infrastructure/config/json/json.go +++ b/pkg/infrastructure/config/json/json.go @@ -165,7 +165,35 @@ func (c *JSONConfigContainer) String(key string) (string, error) { return v, nil } } - return "", errors.New(fmt.Sprintf("config not found or is not string, key: %s", key)) + return "", nil +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { + // TODO FIXME should not use "" to replace non existence + if v, err := c.String(key); v != "" && err == nil { + return v + } + return defaultval +} + +// Strings returns the []string value for a given key. +func (c *JSONConfigContainer) Strings(key string) ([]string, error) { + stringVal, err := c.String(key) + if stringVal == "" || err != nil { + return nil, err + } + return strings.Split(stringVal, ";"), nil +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { + if v, err := c.Strings(key); v != nil && err == nil { + return v + } + return defaultval } // GetSection returns map for the given section diff --git a/pkg/infrastructure/config/xml/xml.go b/pkg/infrastructure/config/xml/xml.go index c095ef06..49aab33e 100644 --- a/pkg/infrastructure/config/xml/xml.go +++ b/pkg/infrastructure/config/xml/xml.go @@ -26,7 +26,7 @@ // // cnf, err := config.NewConfig("xml", "config.xml") // -// More docs http://beego.me/docs/module/config.md +//More docs http://beego.me/docs/module/config.md package xml import ( @@ -36,11 +36,11 @@ import ( "io/ioutil" "os" "strconv" + "strings" "sync" - "github.com/beego/x2j" - "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/beego/x2j" ) // Config is a xml config parser and implements Config interface. @@ -148,7 +148,7 @@ func (c *ConfigContainer) String(key string) (string, error) { if v, ok := c.data[key].(string); ok { return v, nil } - return "", errors.New(fmt.Sprintf("configuration not found or not string, key: %s", key)) + return "", nil } // DefaultString returns the string value for a given key. @@ -161,6 +161,25 @@ func (c *ConfigContainer) DefaultString(key string, defaultval string) string { return v } +// Strings returns the []string value for a given key. +func (c *ConfigContainer) Strings(key string) ([]string, error) { + v, err := c.String(key) + if v == "" || err != nil { + return nil, err + } + return strings.Split(v, ";"), nil +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v, err := c.Strings(key) + if v == nil || err != nil { + return defaultval + } + return v +} + // GetSection returns map for the given section func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { if v, ok := c.data[section].(map[string]interface{}); ok { diff --git a/pkg/infrastructure/config/yaml/yaml.go b/pkg/infrastructure/config/yaml/yaml.go index 96045365..ddd556e6 100644 --- a/pkg/infrastructure/config/yaml/yaml.go +++ b/pkg/infrastructure/config/yaml/yaml.go @@ -214,11 +214,9 @@ func (c *ConfigContainer) String(key string) (string, error) { if v, err := c.getData(key); err == nil { if vv, ok := v.(string); ok { return vv, nil - } else { - return "", errors.New(fmt.Sprintf("the value is not string, key: %s, value: %v", key, v)) } } - return "", errors.New(fmt.Sprintf("configuration not found, key: %s", key)) + return "", nil } // DefaultString returns the string value for a given key. diff --git a/pkg/server/web/config.go b/pkg/server/web/config.go index 3abe255e..b2e38a80 100644 --- a/pkg/server/web/config.go +++ b/pkg/server/web/config.go @@ -32,8 +32,8 @@ import ( // Config is the main struct for BConfig type Config struct { - AppName string //Application name - RunMode string //Running Mode: dev | prod + AppName string // Application name + RunMode string // Running Mode: dev | prod RouterCaseSensitive bool ServerName string RecoverPanic bool @@ -113,8 +113,8 @@ type SessionConfig struct { // LogConfig holds Log related config type LogConfig struct { AccessLogs bool - EnableStaticLogs bool //log static files requests default: false - AccessLogsFormat string //access log format: JSON_FORMAT, APACHE_FORMAT or empty string + EnableStaticLogs bool // log static files requests default: false + AccessLogsFormat string // access log format: JSON_FORMAT, APACHE_FORMAT or empty string FileLineNum bool Outputs map[string]string // Store Adaptor : config } @@ -210,7 +210,7 @@ func newBConfig() *Config { RecoverFunc: recoverPanic, CopyRequestBody: false, EnableGzip: false, - MaxMemory: 1 << 26, //64MB + MaxMemory: 1 << 26, // 64MB EnableErrorsShow: true, EnableErrorsRender: true, Listen: Listen{ @@ -258,7 +258,7 @@ func newBConfig() *Config { SessionGCMaxLifetime: 3600, SessionProviderConfig: "", SessionDisableHTTPOnly: false, - SessionCookieLifeTime: 0, //set cookie default is the browser life + SessionCookieLifeTime: 0, // set cookie default is the browser life SessionAutoSetCookie: true, SessionDomain: "", SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers @@ -292,11 +292,11 @@ func assignConfig(ac config.Configer) error { // set the run mode first if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { BConfig.RunMode = envRunMode - } else if runMode := ac.String("RunMode"); runMode != "" { + } else if runMode, err := ac.String("RunMode"); runMode != "" && err == nil { BConfig.RunMode = runMode } - if sd := ac.String("StaticDir"); sd != "" { + if sd, err := ac.String("StaticDir"); sd != "" && err == nil { BConfig.WebConfig.StaticDir = map[string]string{} sds := strings.Fields(sd) for _, v := range sds { @@ -308,7 +308,7 @@ func assignConfig(ac config.Configer) error { } } - if sgz := ac.String("StaticExtensionsToGzip"); sgz != "" { + if sgz, err := ac.String("StaticExtensionsToGzip"); sgz != "" && err == nil { extensions := strings.Split(sgz, ",") fileExts := []string{} for _, ext := range extensions { @@ -334,7 +334,7 @@ func assignConfig(ac config.Configer) error { BConfig.WebConfig.StaticCacheFileNum = sfn } - if lo := ac.String("LogOutputs"); lo != "" { + if lo, err := ac.String("LogOutputs"); lo != "" && err == nil { // if lo is not nil or empty // means user has set his own LogOutputs // clear the default setting to BConfig.Log.Outputs @@ -349,7 +349,7 @@ func assignConfig(ac config.Configer) error { } } - //init log + // init log logs.Reset() for adaptor, config := range BConfig.Log.Outputs { err := logs.SetLogger(adaptor, config) @@ -388,7 +388,7 @@ func assignSingleConfig(p interface{}, ac config.Configer) { pf.SetBool(ac.DefaultBool(name, pf.Bool())) case reflect.Struct: default: - //do nothing here + // do nothing here } } @@ -431,16 +431,16 @@ func (b *beegoAppConfig) Set(key, val string) error { return nil } -func (b *beegoAppConfig) String(key string) string { - if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" { - return v +func (b *beegoAppConfig) String(key string) (string, error) { + if v, err := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" && err == nil { + return v, nil } return b.innerConfig.String(key) } -func (b *beegoAppConfig) Strings(key string) []string { - if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 { - return v +func (b *beegoAppConfig) Strings(key string) ([]string, error) { + if v, err := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 && err == nil { + return v, nil } return b.innerConfig.Strings(key) } @@ -474,14 +474,14 @@ func (b *beegoAppConfig) Float(key string) (float64, error) { } func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { - if v := b.String(key); v != "" { + if v, err := b.String(key); v != "" && err == nil { return v } return defaultVal } func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { - if v := b.Strings(key); len(v) != 0 { + if v, err := b.Strings(key); len(v) != 0 && err == nil { return v } return defaultVal diff --git a/pkg/server/web/hooks.go b/pkg/server/web/hooks.go index 13194733..ae54f190 100644 --- a/pkg/server/web/hooks.go +++ b/pkg/server/web/hooks.go @@ -48,9 +48,9 @@ func registerDefaultErrorHandler() error { func registerSession() error { if BConfig.WebConfig.Session.SessionOn { var err error - sessionConfig := AppConfig.String("sessionConfig") + sessionConfig, err := AppConfig.String("sessionConfig") conf := new(session.ManagerConfig) - if sessionConfig == "" { + if sessionConfig == "" || err != nil { conf.CookieName = BConfig.WebConfig.Session.SessionName conf.EnableSetCookie = BConfig.WebConfig.Session.SessionAutoSetCookie conf.Gclifetime = BConfig.WebConfig.Session.SessionGCMaxLifetime diff --git a/pkg/server/web/templatefunc.go b/pkg/server/web/templatefunc.go index 6d132bf0..34d71aab 100644 --- a/pkg/server/web/templatefunc.go +++ b/pkg/server/web/templatefunc.go @@ -160,7 +160,7 @@ func NotNil(a interface{}) (isNil bool) { func GetConfig(returnType, key string, defaultVal interface{}) (value interface{}, err error) { switch returnType { case "String": - value = AppConfig.String(key) + value, err = AppConfig.String(key) case "Bool": value, err = AppConfig.Bool(key) case "Int": diff --git a/scripts/prepare_etcd.sh b/scripts/prepare_etcd.sh index a65f00a3..d34c05a3 100644 --- a/scripts/prepare_etcd.sh +++ b/scripts/prepare_etcd.sh @@ -4,4 +4,5 @@ etcdctl put current.float 1.23 etcdctl put current.bool true etcdctl put current.int 11 etcdctl put current.string hello -etcdctl put current.serialize.name test \ No newline at end of file +etcdctl put current.serialize.name test +etcdctl put sub.sub.key1 sub.sub.key \ No newline at end of file From 03bec05714ce8f74ceb5d63a110004e3d8e8935a Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 29 Aug 2020 15:05:18 +0000 Subject: [PATCH 128/207] Add contect as first parameter for all config method --- pkg/client/httplib/testing/client.go | 2 +- pkg/infrastructure/config/base_config_test.go | 24 ++--- pkg/infrastructure/config/config.go | 82 +++++++--------- pkg/infrastructure/config/etcd/config.go | 17 ++-- pkg/infrastructure/config/etcd/config_test.go | 26 ++--- pkg/infrastructure/config/fake.go | 40 ++++---- pkg/infrastructure/config/ini.go | 73 +++++++------- pkg/infrastructure/config/ini_test.go | 20 ++-- pkg/infrastructure/config/json/json.go | 59 +++++------ pkg/infrastructure/config/json/json_test.go | 36 +++---- pkg/infrastructure/config/xml/xml.go | 71 +++++++------- pkg/infrastructure/config/xml/xml_test.go | 20 ++-- pkg/infrastructure/config/yaml/yaml.go | 71 +++++++------- pkg/infrastructure/config/yaml/yaml_test.go | 20 ++-- pkg/server/web/config.go | 97 ++++++++++--------- pkg/server/web/config_test.go | 12 +-- pkg/server/web/hooks.go | 9 +- pkg/server/web/parser.go | 5 +- pkg/server/web/templatefunc.go | 23 ++--- 19 files changed, 351 insertions(+), 356 deletions(-) diff --git a/pkg/client/httplib/testing/client.go b/pkg/client/httplib/testing/client.go index 863ed0e8..00fa3059 100644 --- a/pkg/client/httplib/testing/client.go +++ b/pkg/client/httplib/testing/client.go @@ -34,7 +34,7 @@ func getPort() string { if err != nil { return "8080" } - port, err = config.String("httpport") + port, err = config.String(nil, "httpport") if err != nil { return "8080" } diff --git a/pkg/infrastructure/config/base_config_test.go b/pkg/infrastructure/config/base_config_test.go index 74a669a7..74cef184 100644 --- a/pkg/infrastructure/config/base_config_test.go +++ b/pkg/infrastructure/config/base_config_test.go @@ -24,38 +24,38 @@ import ( func TestBaseConfiger_DefaultBool(t *testing.T) { bc := newBaseConfier("true") - assert.True(t, bc.DefaultBool("key1", false)) - assert.True(t, bc.DefaultBool("key2", true)) + assert.True(t, bc.DefaultBool(context.Background(), "key1", false)) + assert.True(t, bc.DefaultBool(context.Background(), "key2", true)) } func TestBaseConfiger_DefaultFloat(t *testing.T) { bc := newBaseConfier("12.3") - assert.Equal(t, 12.3, bc.DefaultFloat("key1", 0.1)) - assert.Equal(t, 0.1, bc.DefaultFloat("key2", 0.1)) + assert.Equal(t, 12.3, bc.DefaultFloat(context.Background(), "key1", 0.1)) + assert.Equal(t, 0.1, bc.DefaultFloat(context.Background(), "key2", 0.1)) } func TestBaseConfiger_DefaultInt(t *testing.T) { bc := newBaseConfier("10") - assert.Equal(t, 10, bc.DefaultInt("key1", 8)) - assert.Equal(t, 8, bc.DefaultInt("key2", 8)) + assert.Equal(t, 10, bc.DefaultInt(context.Background(), "key1", 8)) + assert.Equal(t, 8, bc.DefaultInt(context.Background(), "key2", 8)) } func TestBaseConfiger_DefaultInt64(t *testing.T) { bc := newBaseConfier("64") - assert.Equal(t, int64(64), bc.DefaultInt64("key1", int64(8))) - assert.Equal(t, int64(8), bc.DefaultInt64("key2", int64(8))) + assert.Equal(t, int64(64), bc.DefaultInt64(context.Background(), "key1", int64(8))) + assert.Equal(t, int64(8), bc.DefaultInt64(context.Background(), "key2", int64(8))) } func TestBaseConfiger_DefaultString(t *testing.T) { bc := newBaseConfier("Hello") - assert.Equal(t, "Hello", bc.DefaultString("key1", "world")) - assert.Equal(t, "world", bc.DefaultString("key2", "world")) + assert.Equal(t, "Hello", bc.DefaultString(context.Background(), "key1", "world")) + assert.Equal(t, "world", bc.DefaultString(context.Background(), "key2", "world")) } func TestBaseConfiger_DefaultStrings(t *testing.T) { bc := newBaseConfier("Hello;world") - assert.Equal(t, []string{"Hello", "world"}, bc.DefaultStrings("key1", []string{"world"})) - assert.Equal(t, []string{"world"}, bc.DefaultStrings("key2", []string{"world"})) + assert.Equal(t, []string{"Hello", "world"}, bc.DefaultStrings(context.Background(), "key1", []string{"world"})) + assert.Equal(t, []string{"world"}, bc.DefaultStrings(context.Background(), "key2", []string{"world"})) } func newBaseConfier(str1 string) *BaseConfiger { diff --git a/pkg/infrastructure/config/config.go b/pkg/infrastructure/config/config.go index c7f45469..0891e571 100644 --- a/pkg/infrastructure/config/config.go +++ b/pkg/infrastructure/config/config.go @@ -54,35 +54,32 @@ import ( // Configer defines how to get and set value from configuration raw data. type Configer interface { // support section::key type in given key when using ini type. - Set(key, val string) error + Set(ctx context.Context, key, val string) error // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - String(key string) (string, error) + String(ctx context.Context, key string) (string, error) // get string slice - Strings(key string) ([]string, error) - Int(key string) (int, error) - Int64(key string) (int64, error) - Bool(key string) (bool, error) - Float(key string) (float64, error) + Strings(ctx context.Context, key string) ([]string, error) + Int(ctx context.Context, key string) (int, error) + Int64(ctx context.Context, key string) (int64, error) + Bool(ctx context.Context, key string) (bool, error) + Float(ctx context.Context, key string) (float64, error) // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - DefaultString(key string, defaultVal string) string + DefaultString(ctx context.Context, key string, defaultVal string) string // get string slice - DefaultStrings(key string, defaultVal []string) []string - DefaultInt(key string, defaultVal int) int - DefaultInt64(key string, defaultVal int64) int64 - DefaultBool(key string, defaultVal bool) bool - DefaultFloat(key string, defaultVal float64) float64 - DIY(key string) (interface{}, error) + DefaultStrings(ctx context.Context, key string, defaultVal []string) []string + DefaultInt(ctx context.Context, key string, defaultVal int) int + DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 + DefaultBool(ctx context.Context, key string, defaultVal bool) bool + DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 + DIY(ctx context.Context, key string) (interface{}, error) - GetSection(section string) (map[string]string, error) - GetSectionWithCtx(ctx context.Context, section string) (map[string]string, error) + GetSection(ctx context.Context, section string) (map[string]string, error) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...DecodeOption) error - Sub(key string) (Configer, error) + Sub(ctx context.Context, key string) (Configer, error) OnChange(ctx context.Context, key string, fn func(value string)) - // GetByPrefix(prefix string) ([]byte, error) - // GetSerializer() Serializer - SaveConfigFile(filename string) error + SaveConfigFile(ctx context.Context, filename string) error } type BaseConfiger struct { @@ -96,7 +93,7 @@ func NewBaseConfiger(reader func(ctx context.Context, key string) (string, error } } -func (c *BaseConfiger) Int(key string) (int, error) { +func (c *BaseConfiger) Int(ctx context.Context, key string) (int, error) { res, err := c.reader(context.TODO(), key) if err != nil { return 0, err @@ -104,7 +101,7 @@ func (c *BaseConfiger) Int(key string) (int, error) { return strconv.Atoi(res) } -func (c *BaseConfiger) Int64(key string) (int64, error) { +func (c *BaseConfiger) Int64(ctx context.Context, key string) (int64, error) { res, err := c.reader(context.TODO(), key) if err != nil { return 0, err @@ -112,7 +109,7 @@ func (c *BaseConfiger) Int64(key string) (int64, error) { return strconv.ParseInt(res, 10, 64) } -func (c *BaseConfiger) Bool(key string) (bool, error) { +func (c *BaseConfiger) Bool(ctx context.Context, key string) (bool, error) { res, err := c.reader(context.TODO(), key) if err != nil { return false, err @@ -120,7 +117,7 @@ func (c *BaseConfiger) Bool(key string) (bool, error) { return ParseBool(res) } -func (c *BaseConfiger) Float(key string) (float64, error) { +func (c *BaseConfiger) Float(ctx context.Context, key string) (float64, error) { res, err := c.reader(context.TODO(), key) if err != nil { return 0, err @@ -130,8 +127,8 @@ func (c *BaseConfiger) Float(key string) (float64, error) { // DefaultString returns the string value for a given key. // if err != nil or value is empty return defaultval -func (c *BaseConfiger) DefaultString(key string, defaultVal string) string { - if res, err := c.String(key); res != "" && err == nil { +func (c *BaseConfiger) DefaultString(ctx context.Context, key string, defaultVal string) string { + if res, err := c.String(ctx, key); res != "" && err == nil { return res } return defaultVal @@ -139,53 +136,48 @@ func (c *BaseConfiger) DefaultString(key string, defaultVal string) string { // DefaultStrings returns the []string value for a given key. // if err != nil return defaultval -func (c *BaseConfiger) DefaultStrings(key string, defaultVal []string) []string { - if res, err := c.Strings(key); len(res) > 0 && err == nil { +func (c *BaseConfiger) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + if res, err := c.Strings(ctx, key); len(res) > 0 && err == nil { return res } return defaultVal } -func (c *BaseConfiger) DefaultInt(key string, defaultVal int) int { - if res, err := c.Int(key); err == nil { +func (c *BaseConfiger) DefaultInt(ctx context.Context, key string, defaultVal int) int { + if res, err := c.Int(ctx, key); err == nil { return res } return defaultVal } -func (c *BaseConfiger) DefaultInt64(key string, defaultVal int64) int64 { - if res, err := c.Int64(key); err == nil { +func (c *BaseConfiger) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + if res, err := c.Int64(ctx, key); err == nil { return res } return defaultVal } -func (c *BaseConfiger) DefaultBool(key string, defaultVal bool) bool { - if res, err := c.Bool(key); err == nil { +func (c *BaseConfiger) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + if res, err := c.Bool(ctx, key); err == nil { return res } return defaultVal } -func (c *BaseConfiger) DefaultFloat(key string, defaultVal float64) float64 { - if res, err := c.Float(key); err == nil { +func (c *BaseConfiger) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + if res, err := c.Float(ctx, key); err == nil { return res } return defaultVal } -func (c *BaseConfiger) GetSectionWithCtx(ctx context.Context, section string) (map[string]string, error) { - // TODO - return nil, nil -} - -func (c *BaseConfiger) String(key string) (string, error) { +func (c *BaseConfiger) String(ctx context.Context, key string) (string, error) { return c.reader(context.TODO(), key) } // Strings returns the []string value for a given key. // Return nil if config value does not exist or is empty. -func (c *BaseConfiger) Strings(key string) ([]string, error) { - res, err := c.String(key) +func (c *BaseConfiger) Strings(ctx context.Context, key string) ([]string, error) { + res, err := c.String(nil, key) if err != nil || res == "" { return nil, err } @@ -198,7 +190,7 @@ func (c *BaseConfiger) Unmarshaler(ctx context.Context, prefix string, obj inter } // TODO remove this before release v2.0.0 -func (c *BaseConfiger) Sub(key string) (Configer, error) { +func (c *BaseConfiger) Sub(ctx context.Context, key string) (Configer, error) { return nil, errors.New("unsupported operation") } diff --git a/pkg/infrastructure/config/etcd/config.go b/pkg/infrastructure/config/etcd/config.go index 30f26ce1..94057d73 100644 --- a/pkg/infrastructure/config/etcd/config.go +++ b/pkg/infrastructure/config/etcd/config.go @@ -64,23 +64,18 @@ func (e *EtcdConfiger) reader(ctx context.Context, key string) (string, error) { // Set do nothing and return an error // I think write data to remote config center is not a good practice -func (e *EtcdConfiger) Set(key, val string) error { +func (e *EtcdConfiger) Set(ctx context.Context, key, val string) error { return errors.New("Unsupported operation") } // DIY return the original response from etcd // be careful when you decide to use this -func (e *EtcdConfiger) DIY(key string) (interface{}, error) { +func (e *EtcdConfiger) DIY(ctx context.Context, key string) (interface{}, error) { return get(e.client, context.TODO(), key) } // GetSection in this implementation, we use section as prefix -func (e *EtcdConfiger) GetSection(section string) (map[string]string, error) { - return e.GetSectionWithCtx(context.Background(), section) -} - -func (e *EtcdConfiger) GetSectionWithCtx(ctx context.Context, section string) (map[string]string, error) { - +func (e *EtcdConfiger) GetSection(ctx context.Context, section string) (map[string]string, error) { var ( resp *clientv3.GetResponse err error @@ -103,7 +98,7 @@ func (e *EtcdConfiger) GetSectionWithCtx(ctx context.Context, section string) (m return res, nil } -func (e *EtcdConfiger) SaveConfigFile(filename string) error { +func (e *EtcdConfiger) SaveConfigFile(ctx context.Context, filename string) error { return errors.New("Unsupported operation") } @@ -111,7 +106,7 @@ func (e *EtcdConfiger) SaveConfigFile(filename string) error { // for example, when we got "5", we are not sure whether it's int 5, or it's string "5" // TODO(support more complicated decoder) func (e *EtcdConfiger) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { - res, err := e.GetSectionWithCtx(ctx, prefix) + res, err := e.GetSection(ctx, prefix) if err != nil { return errors.WithMessage(err, fmt.Sprintf("could not read config with prefix: %s", prefix)) } @@ -125,7 +120,7 @@ func (e *EtcdConfiger) Unmarshaler(ctx context.Context, prefix string, obj inter } // Sub return an sub configer. -func (e *EtcdConfiger) Sub(key string) (config.Configer, error) { +func (e *EtcdConfiger) Sub(ctx context.Context, key string) (config.Configer, error) { return newEtcdConfiger(e.client, e.prefix+key), nil } diff --git a/pkg/infrastructure/config/etcd/config_test.go b/pkg/infrastructure/config/etcd/config_test.go index a9cadd95..7ccf6b96 100644 --- a/pkg/infrastructure/config/etcd/config_test.go +++ b/pkg/infrastructure/config/etcd/config_test.go @@ -42,15 +42,15 @@ func TestEtcdConfiger(t *testing.T) { provider := &EtcdConfigerProvider{} cfger, _ := provider.Parse(readEtcdConfig()) - subCfger, err := cfger.Sub("sub.") + subCfger, err := cfger.Sub(nil, "sub.") assert.Nil(t, err) assert.NotNil(t, subCfger) - subSubCfger, err := subCfger.Sub("sub.") + subSubCfger, err := subCfger.Sub(nil, "sub.") assert.NotNil(t, subSubCfger) assert.Nil(t, err) - str, err := subSubCfger.String("key1") + str, err := subSubCfger.String(nil, "key1") assert.Nil(t, err) assert.Equal(t, "sub.sub.key", str) @@ -59,37 +59,37 @@ func TestEtcdConfiger(t *testing.T) { // do nothing }) - defStr := cfger.DefaultString("not_exit", "default value") + defStr := cfger.DefaultString(nil, "not_exit", "default value") assert.Equal(t, "default value", defStr) - defInt64 := cfger.DefaultInt64("not_exit", -1) + defInt64 := cfger.DefaultInt64(nil, "not_exit", -1) assert.Equal(t, int64(-1), defInt64) - defInt := cfger.DefaultInt("not_exit", -2) + defInt := cfger.DefaultInt(nil, "not_exit", -2) assert.Equal(t, -2, defInt) - defFlt := cfger.DefaultFloat("not_exit", 12.3) + defFlt := cfger.DefaultFloat(nil, "not_exit", 12.3) assert.Equal(t, 12.3, defFlt) - defBl := cfger.DefaultBool("not_exit", true) + defBl := cfger.DefaultBool(nil, "not_exit", true) assert.True(t, defBl) - defStrs := cfger.DefaultStrings("not_exit", []string{"hello"}) + defStrs := cfger.DefaultStrings(nil, "not_exit", []string{"hello"}) assert.Equal(t, []string{"hello"}, defStrs) - fl, err := cfger.Float("current.float") + fl, err := cfger.Float(nil, "current.float") assert.Nil(t, err) assert.Equal(t, 1.23, fl) - bl, err := cfger.Bool("current.bool") + bl, err := cfger.Bool(nil, "current.bool") assert.Nil(t, err) assert.True(t, bl) - it, err := cfger.Int("current.int") + it, err := cfger.Int(nil, "current.int") assert.Nil(t, err) assert.Equal(t, 11, it) - str, err = cfger.String("current.string") + str, err = cfger.String(nil, "current.string") assert.Nil(t, err) assert.Equal(t, "hello", str) diff --git a/pkg/infrastructure/config/fake.go b/pkg/infrastructure/config/fake.go index f885d44d..b606be01 100644 --- a/pkg/infrastructure/config/fake.go +++ b/pkg/infrastructure/config/fake.go @@ -30,71 +30,71 @@ func (c *fakeConfigContainer) getData(key string) string { return c.data[strings.ToLower(key)] } -func (c *fakeConfigContainer) Set(key, val string) error { +func (c *fakeConfigContainer) Set(ctx context.Context, key, val string) error { c.data[strings.ToLower(key)] = val return nil } -func (c *fakeConfigContainer) Int(key string) (int, error) { +func (c *fakeConfigContainer) Int(ctx context.Context, key string) (int, error) { return strconv.Atoi(c.getData(key)) } -func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) +func (c *fakeConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { + v, err := c.Int(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } -func (c *fakeConfigContainer) Int64(key string) (int64, error) { +func (c *fakeConfigContainer) Int64(ctx context.Context, key string) (int64, error) { return strconv.ParseInt(c.getData(key), 10, 64) } -func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) +func (c *fakeConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + v, err := c.Int64(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } -func (c *fakeConfigContainer) Bool(key string) (bool, error) { +func (c *fakeConfigContainer) Bool(ctx context.Context, key string) (bool, error) { return ParseBool(c.getData(key)) } -func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) +func (c *fakeConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + v, err := c.Bool(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } -func (c *fakeConfigContainer) Float(key string) (float64, error) { +func (c *fakeConfigContainer) Float(ctx context.Context, key string) (float64, error) { return strconv.ParseFloat(c.getData(key), 64) } -func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) +func (c *fakeConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + v, err := c.Float(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } -func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { +func (c *fakeConfigContainer) DIY(ctx context.Context, key string) (interface{}, error) { if v, ok := c.data[strings.ToLower(key)]; ok { return v, nil } return nil, errors.New("key not find") } -func (c *fakeConfigContainer) GetSection(section string) (map[string]string, error) { +func (c *fakeConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { return nil, errors.New("not implement in the fakeConfigContainer") } -func (c *fakeConfigContainer) SaveConfigFile(filename string) error { +func (c *fakeConfigContainer) SaveConfigFile(ctx context.Context, filename string) error { return errors.New("not implement in the fakeConfigContainer") } diff --git a/pkg/infrastructure/config/ini.go b/pkg/infrastructure/config/ini.go index 2338b3cf..92ed8df8 100644 --- a/pkg/infrastructure/config/ini.go +++ b/pkg/infrastructure/config/ini.go @@ -17,6 +17,7 @@ package config import ( "bufio" "bytes" + "context" "errors" "io" "io/ioutil" @@ -233,84 +234,84 @@ type IniConfigContainer struct { } // Bool returns the boolean value for a given key. -func (c *IniConfigContainer) Bool(key string) (bool, error) { +func (c *IniConfigContainer) Bool(ctx context.Context, key string) (bool, error) { return ParseBool(c.getdata(key)) } // DefaultBool returns the boolean value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + v, err := c.Bool(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Int returns the integer value for a given key. -func (c *IniConfigContainer) Int(key string) (int, error) { +func (c *IniConfigContainer) Int(ctx context.Context, key string) (int, error) { return strconv.Atoi(c.getdata(key)) } // DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { + v, err := c.Int(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Int64 returns the int64 value for a given key. -func (c *IniConfigContainer) Int64(key string) (int64, error) { +func (c *IniConfigContainer) Int64(ctx context.Context, key string) (int64, error) { return strconv.ParseInt(c.getdata(key), 10, 64) } // DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + v, err := c.Int64(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Float returns the float value for a given key. -func (c *IniConfigContainer) Float(key string) (float64, error) { +func (c *IniConfigContainer) Float(ctx context.Context, key string) (float64, error) { return strconv.ParseFloat(c.getdata(key), 64) } // DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + v, err := c.Float(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // String returns the string value for a given key. -func (c *IniConfigContainer) String(key string) (string, error) { +func (c *IniConfigContainer) String(ctx context.Context, key string) (string, error) { return c.getdata(key), nil } // DefaultString returns the string value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { - v, err := c.String(key) +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { + v, err := c.String(nil, key) if v == "" || err != nil { - return defaultval + return defaultVal } return v } // Strings returns the []string value for a given key. // Return nil if config value does not exist or is empty. -func (c *IniConfigContainer) Strings(key string) ([]string, error) { - v, err := c.String(key) +func (c *IniConfigContainer) Strings(ctx context.Context, key string) ([]string, error) { + v, err := c.String(nil, key) if v == "" || err != nil { return nil, err } @@ -318,17 +319,17 @@ func (c *IniConfigContainer) Strings(key string) ([]string, error) { } // DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v, err := c.Strings(key) +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + v, err := c.Strings(ctx, key) if v == nil || err != nil { - return defaultval + return defaultVal } return v } // GetSection returns map for the given section -func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) { +func (c *IniConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { if v, ok := c.data[section]; ok { return v, nil } @@ -338,7 +339,7 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro // SaveConfigFile save the config into file. // // BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Function. -func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { +func (c *IniConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -438,7 +439,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { // Set writes a new value for key. // if write to one section, the key need be "section::key". // if the section is not existed, it panics. -func (c *IniConfigContainer) Set(key, value string) error { +func (c *IniConfigContainer) Set(ctx context.Context, key, val string) error { c.Lock() defer c.Unlock() if len(key) == 0 { @@ -461,12 +462,12 @@ func (c *IniConfigContainer) Set(key, value string) error { if _, ok := c.data[section]; !ok { c.data[section] = make(map[string]string) } - c.data[section][k] = value + c.data[section][k] = val return nil } // DIY returns the raw value by a given key. -func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) { +func (c *IniConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) { if v, ok := c.data[strings.ToLower(key)]; ok { return v, nil } diff --git a/pkg/infrastructure/config/ini_test.go b/pkg/infrastructure/config/ini_test.go index 7daa0a6e..d4972ddd 100644 --- a/pkg/infrastructure/config/ini_test.go +++ b/pkg/infrastructure/config/ini_test.go @@ -101,19 +101,19 @@ password = ${GOPATH} var value interface{} switch v.(type) { case int: - value, err = iniconf.Int(k) + value, err = iniconf.Int(nil, k) case int64: - value, err = iniconf.Int64(k) + value, err = iniconf.Int64(nil, k) case float64: - value, err = iniconf.Float(k) + value, err = iniconf.Float(nil, k) case bool: - value, err = iniconf.Bool(k) + value, err = iniconf.Bool(nil, k) case []string: - value, err = iniconf.Strings(k) + value, err = iniconf.Strings(nil, k) case string: - value, err = iniconf.String(k) + value, err = iniconf.String(nil, k) default: - value, err = iniconf.DIY(k) + value, err = iniconf.DIY(nil, k) } if err != nil { t.Fatalf("get key %q value fail,err %s", k, err) @@ -122,10 +122,10 @@ password = ${GOPATH} } } - if err = iniconf.Set("name", "astaxie"); err != nil { + if err = iniconf.Set(nil, "name", "astaxie"); err != nil { t.Fatal(err) } - res, _ := iniconf.String("name") + res, _ := iniconf.String(nil, "name") if res != "astaxie" { t.Fatal("get name error") } @@ -171,7 +171,7 @@ name=mysql t.Fatal(err) } name := "newIniConfig.ini" - if err := cfg.SaveConfigFile(name); err != nil { + if err := cfg.SaveConfigFile(nil, name); err != nil { t.Fatal(err) } defer os.Remove(name) diff --git a/pkg/infrastructure/config/json/json.go b/pkg/infrastructure/config/json/json.go index 975e1523..dae55118 100644 --- a/pkg/infrastructure/config/json/json.go +++ b/pkg/infrastructure/config/json/json.go @@ -15,6 +15,7 @@ package json import ( + "context" "encoding/json" "errors" "fmt" @@ -75,7 +76,7 @@ type JSONConfigContainer struct { } // Bool returns the boolean value for a given key. -func (c *JSONConfigContainer) Bool(key string) (bool, error) { +func (c *JSONConfigContainer) Bool(ctx context.Context, key string) (bool, error) { val := c.getData(key) if val != nil { return config.ParseBool(val) @@ -85,15 +86,15 @@ func (c *JSONConfigContainer) Bool(key string) (bool, error) { // DefaultBool return the bool value if has no error // otherwise return the defaultval -func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool { - if v, err := c.Bool(key); err == nil { +func (c *JSONConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + if v, err := c.Bool(ctx, key); err == nil { return v } - return defaultval + return defaultVal } // Int returns the integer value for a given key. -func (c *JSONConfigContainer) Int(key string) (int, error) { +func (c *JSONConfigContainer) Int(ctx context.Context, key string) (int, error) { val := c.getData(key) if val != nil { if v, ok := val.(float64); ok { @@ -108,15 +109,15 @@ func (c *JSONConfigContainer) Int(key string) (int, error) { // DefaultInt returns the integer value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int { - if v, err := c.Int(key); err == nil { +func (c *JSONConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { + if v, err := c.Int(ctx, key); err == nil { return v } - return defaultval + return defaultVal } // Int64 returns the int64 value for a given key. -func (c *JSONConfigContainer) Int64(key string) (int64, error) { +func (c *JSONConfigContainer) Int64(ctx context.Context, key string) (int64, error) { val := c.getData(key) if val != nil { if v, ok := val.(float64); ok { @@ -129,15 +130,15 @@ func (c *JSONConfigContainer) Int64(key string) (int64, error) { // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - if v, err := c.Int64(key); err == nil { +func (c *JSONConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + if v, err := c.Int64(ctx, key); err == nil { return v } - return defaultval + return defaultVal } // Float returns the float value for a given key. -func (c *JSONConfigContainer) Float(key string) (float64, error) { +func (c *JSONConfigContainer) Float(ctx context.Context, key string) (float64, error) { val := c.getData(key) if val != nil { if v, ok := val.(float64); ok { @@ -150,15 +151,15 @@ func (c *JSONConfigContainer) Float(key string) (float64, error) { // DefaultFloat returns the float64 value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - if v, err := c.Float(key); err == nil { +func (c *JSONConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + if v, err := c.Float(ctx, key); err == nil { return v } - return defaultval + return defaultVal } // String returns the string value for a given key. -func (c *JSONConfigContainer) String(key string) (string, error) { +func (c *JSONConfigContainer) String(ctx context.Context, key string) (string, error) { val := c.getData(key) if val != nil { if v, ok := val.(string); ok { @@ -170,17 +171,17 @@ func (c *JSONConfigContainer) String(key string) (string, error) { // DefaultString returns the string value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { +func (c *JSONConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { // TODO FIXME should not use "" to replace non existence - if v, err := c.String(key); v != "" && err == nil { + if v, err := c.String(ctx, key); v != "" && err == nil { return v } - return defaultval + return defaultVal } // Strings returns the []string value for a given key. -func (c *JSONConfigContainer) Strings(key string) ([]string, error) { - stringVal, err := c.String(key) +func (c *JSONConfigContainer) Strings(ctx context.Context, key string) ([]string, error) { + stringVal, err := c.String(nil, key) if stringVal == "" || err != nil { return nil, err } @@ -189,15 +190,15 @@ func (c *JSONConfigContainer) Strings(key string) ([]string, error) { // DefaultStrings returns the []string value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { - if v, err := c.Strings(key); v != nil && err == nil { +func (c *JSONConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + if v, err := c.Strings(ctx, key); v != nil && err == nil { return v } - return defaultval + return defaultVal } // GetSection returns map for the given section -func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) { +func (c *JSONConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { if v, ok := c.data[section]; ok { return v.(map[string]string), nil } @@ -205,7 +206,7 @@ func (c *JSONConfigContainer) GetSection(section string) (map[string]string, err } // SaveConfigFile save the config into file -func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) { +func (c *JSONConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -221,7 +222,7 @@ func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) { } // Set writes a new value for key. -func (c *JSONConfigContainer) Set(key, val string) error { +func (c *JSONConfigContainer) Set(ctx context.Context, key, val string) error { c.Lock() defer c.Unlock() c.data[key] = val @@ -229,7 +230,7 @@ func (c *JSONConfigContainer) Set(key, val string) error { } // DIY returns the raw value by a given key. -func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) { +func (c *JSONConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) { val := c.getData(key) if val != nil { return val, nil diff --git a/pkg/infrastructure/config/json/json_test.go b/pkg/infrastructure/config/json/json_test.go index cf337d20..486d2b11 100644 --- a/pkg/infrastructure/config/json/json_test.go +++ b/pkg/infrastructure/config/json/json_test.go @@ -49,7 +49,7 @@ func TestJsonStartsWithArray(t *testing.T) { if err != nil { t.Fatal(err) } - rootArray, err := jsonconf.DIY("rootArray") + rootArray, err := jsonconf.DIY(nil, "rootArray") if err != nil { t.Error("array does not exist as element") } @@ -155,19 +155,19 @@ func TestJson(t *testing.T) { var value interface{} switch v.(type) { case int: - value, err = jsonconf.Int(k) + value, err = jsonconf.Int(nil, k) case int64: - value, err = jsonconf.Int64(k) + value, err = jsonconf.Int64(nil, k) case float64: - value, err = jsonconf.Float(k) + value, err = jsonconf.Float(nil, k) case bool: - value, err = jsonconf.Bool(k) + value, err = jsonconf.Bool(nil, k) case []string: - value, err = jsonconf.Strings(k) + value, err = jsonconf.Strings(nil, k) case string: - value, err = jsonconf.String(k) + value, err = jsonconf.String(nil, k) default: - value, err = jsonconf.DIY(k) + value, err = jsonconf.DIY(nil, k) } if err != nil { t.Fatalf("get key %q value fatal,%v err %s", k, v, err) @@ -176,16 +176,16 @@ func TestJson(t *testing.T) { } } - if err = jsonconf.Set("name", "astaxie"); err != nil { + if err = jsonconf.Set(nil, "name", "astaxie"); err != nil { t.Fatal(err) } - res, _ := jsonconf.String("name") + res, _ := jsonconf.String(nil, "name") if res != "astaxie" { t.Fatal("get name error") } - if db, err := jsonconf.DIY("database"); err != nil { + if db, err := jsonconf.DIY(nil, "database"); err != nil { t.Fatal(err) } else if m, ok := db.(map[string]interface{}); !ok { t.Log(db) @@ -196,31 +196,31 @@ func TestJson(t *testing.T) { } } - if _, err := jsonconf.Int("unknown"); err == nil { + if _, err := jsonconf.Int(nil, "unknown"); err == nil { t.Error("unknown keys should return an error when expecting an Int") } - if _, err := jsonconf.Int64("unknown"); err == nil { + if _, err := jsonconf.Int64(nil, "unknown"); err == nil { t.Error("unknown keys should return an error when expecting an Int64") } - if _, err := jsonconf.Float("unknown"); err == nil { + if _, err := jsonconf.Float(nil, "unknown"); err == nil { t.Error("unknown keys should return an error when expecting a Float") } - if _, err := jsonconf.DIY("unknown"); err == nil { + if _, err := jsonconf.DIY(nil, "unknown"); err == nil { t.Error("unknown keys should return an error when expecting an interface{}") } - if val, _ := jsonconf.String("unknown"); val != "" { + if val, _ := jsonconf.String(nil, "unknown"); val != "" { t.Error("unknown keys should return an empty string when expecting a String") } - if _, err := jsonconf.Bool("unknown"); err == nil { + if _, err := jsonconf.Bool(nil, "unknown"); err == nil { t.Error("unknown keys should return an error when expecting a Bool") } - if !jsonconf.DefaultBool("unknown", true) { + if !jsonconf.DefaultBool(nil, "unknown", true) { t.Error("unknown keys with default value wrong") } } diff --git a/pkg/infrastructure/config/xml/xml.go b/pkg/infrastructure/config/xml/xml.go index 49aab33e..e3e93b01 100644 --- a/pkg/infrastructure/config/xml/xml.go +++ b/pkg/infrastructure/config/xml/xml.go @@ -30,6 +30,7 @@ package xml import ( + "context" "encoding/xml" "errors" "fmt" @@ -80,7 +81,7 @@ type ConfigContainer struct { } // Bool returns the boolean value for a given key. -func (c *ConfigContainer) Bool(key string) (bool, error) { +func (c *ConfigContainer) Bool(ctx context.Context, key string) (bool, error) { if v := c.data[key]; v != nil { return config.ParseBool(v) } @@ -88,63 +89,63 @@ func (c *ConfigContainer) Bool(key string) (bool, error) { } // DefaultBool return the bool value if has no error -// otherwise return the defaultval -func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) +// otherwise return the defaultVal +func (c *ConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + v, err := c.Bool(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Int returns the integer value for a given key. -func (c *ConfigContainer) Int(key string) (int, error) { +func (c *ConfigContainer) Int(ctx context.Context, key string) (int, error) { return strconv.Atoi(c.data[key].(string)) } // DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { + v, err := c.Int(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Int64 returns the int64 value for a given key. -func (c *ConfigContainer) Int64(key string) (int64, error) { +func (c *ConfigContainer) Int64(ctx context.Context, key string) (int64, error) { return strconv.ParseInt(c.data[key].(string), 10, 64) } // DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + v, err := c.Int64(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Float returns the float value for a given key. -func (c *ConfigContainer) Float(key string) (float64, error) { +func (c *ConfigContainer) Float(ctx context.Context, key string) (float64, error) { return strconv.ParseFloat(c.data[key].(string), 64) } // DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + v, err := c.Float(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // String returns the string value for a given key. -func (c *ConfigContainer) String(key string) (string, error) { +func (c *ConfigContainer) String(ctx context.Context, key string) (string, error) { if v, ok := c.data[key].(string); ok { return v, nil } @@ -152,18 +153,18 @@ func (c *ConfigContainer) String(key string) (string, error) { } // DefaultString returns the string value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultString(key string, defaultval string) string { - v, err := c.String(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { + v, err := c.String(nil, key) if v == "" || err != nil { - return defaultval + return defaultVal } return v } // Strings returns the []string value for a given key. -func (c *ConfigContainer) Strings(key string) ([]string, error) { - v, err := c.String(key) +func (c *ConfigContainer) Strings(ctx context.Context, key string) ([]string, error) { + v, err := c.String(ctx, key) if v == "" || err != nil { return nil, err } @@ -171,17 +172,17 @@ func (c *ConfigContainer) Strings(key string) ([]string, error) { } // DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v, err := c.Strings(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + v, err := c.Strings(ctx, key) if v == nil || err != nil { - return defaultval + return defaultVal } return v } // GetSection returns map for the given section -func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { +func (c *ConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { if v, ok := c.data[section].(map[string]interface{}); ok { mapstr := make(map[string]string) for k, val := range v { @@ -193,7 +194,7 @@ func (c *ConfigContainer) GetSection(section string) (map[string]string, error) } // SaveConfigFile save the config into file -func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { +func (c *ConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -209,7 +210,7 @@ func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { } // Set writes a new value for key. -func (c *ConfigContainer) Set(key, val string) error { +func (c *ConfigContainer) Set(ctx context.Context, key, val string) error { c.Lock() defer c.Unlock() c.data[key] = val @@ -217,7 +218,7 @@ func (c *ConfigContainer) Set(key, val string) error { } // DIY returns the raw value by a given key. -func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { +func (c *ConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) { if v, ok := c.data[key]; ok { return v, nil } diff --git a/pkg/infrastructure/config/xml/xml_test.go b/pkg/infrastructure/config/xml/xml_test.go index 0391efab..470280e0 100644 --- a/pkg/infrastructure/config/xml/xml_test.go +++ b/pkg/infrastructure/config/xml/xml_test.go @@ -76,7 +76,7 @@ func TestXML(t *testing.T) { } var xmlsection map[string]string - xmlsection, err = xmlconf.GetSection("mysection") + xmlsection, err = xmlconf.GetSection(nil, "mysection") if err != nil { t.Fatal(err) } @@ -94,19 +94,19 @@ func TestXML(t *testing.T) { switch v.(type) { case int: - value, err = xmlconf.Int(k) + value, err = xmlconf.Int(nil, k) case int64: - value, err = xmlconf.Int64(k) + value, err = xmlconf.Int64(nil, k) case float64: - value, err = xmlconf.Float(k) + value, err = xmlconf.Float(nil, k) case bool: - value, err = xmlconf.Bool(k) + value, err = xmlconf.Bool(nil, k) case []string: - value, err = xmlconf.Strings(k) + value, err = xmlconf.Strings(nil, k) case string: - value, err = xmlconf.String(k) + value, err = xmlconf.String(nil, k) default: - value, err = xmlconf.DIY(k) + value, err = xmlconf.DIY(nil, k) } if err != nil { t.Errorf("get key %q value fatal,%v err %s", k, v, err) @@ -116,11 +116,11 @@ func TestXML(t *testing.T) { } - if err = xmlconf.Set("name", "astaxie"); err != nil { + if err = xmlconf.Set(nil, "name", "astaxie"); err != nil { t.Fatal(err) } - res, _ := xmlconf.String("name") + res, _ := xmlconf.String(nil, "name") if res != "astaxie" { t.Fatal("get name error") } diff --git a/pkg/infrastructure/config/yaml/yaml.go b/pkg/infrastructure/config/yaml/yaml.go index ddd556e6..1f4f1d23 100644 --- a/pkg/infrastructure/config/yaml/yaml.go +++ b/pkg/infrastructure/config/yaml/yaml.go @@ -31,6 +31,7 @@ package yaml import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -125,7 +126,7 @@ type ConfigContainer struct { } // Bool returns the boolean value for a given key. -func (c *ConfigContainer) Bool(key string) (bool, error) { +func (c *ConfigContainer) Bool(ctx context.Context, key string) (bool, error) { v, err := c.getData(key) if err != nil { return false, err @@ -134,17 +135,17 @@ func (c *ConfigContainer) Bool(key string) (bool, error) { } // DefaultBool return the bool value if has no error -// otherwise return the defaultval -func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) +// otherwise return the defaultVal +func (c *ConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + v, err := c.Bool(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Int returns the integer value for a given key. -func (c *ConfigContainer) Int(key string) (int, error) { +func (c *ConfigContainer) Int(ctx context.Context, key string) (int, error) { if v, err := c.getData(key); err != nil { return 0, err } else if vv, ok := v.(int); ok { @@ -156,17 +157,17 @@ func (c *ConfigContainer) Int(key string) (int, error) { } // DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { + v, err := c.Int(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Int64 returns the int64 value for a given key. -func (c *ConfigContainer) Int64(key string) (int64, error) { +func (c *ConfigContainer) Int64(ctx context.Context, key string) (int64, error) { if v, err := c.getData(key); err != nil { return 0, err } else if vv, ok := v.(int64); ok { @@ -176,17 +177,17 @@ func (c *ConfigContainer) Int64(key string) (int64, error) { } // DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + v, err := c.Int64(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Float returns the float value for a given key. -func (c *ConfigContainer) Float(key string) (float64, error) { +func (c *ConfigContainer) Float(ctx context.Context, key string) (float64, error) { if v, err := c.getData(key); err != nil { return 0.0, err } else if vv, ok := v.(float64); ok { @@ -200,17 +201,17 @@ func (c *ConfigContainer) Float(key string) (float64, error) { } // DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + v, err := c.Float(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // String returns the string value for a given key. -func (c *ConfigContainer) String(key string) (string, error) { +func (c *ConfigContainer) String(ctx context.Context, key string) (string, error) { if v, err := c.getData(key); err == nil { if vv, ok := v.(string); ok { return vv, nil @@ -220,18 +221,18 @@ func (c *ConfigContainer) String(key string) (string, error) { } // DefaultString returns the string value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultString(key string, defaultval string) string { - v, err := c.String(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { + v, err := c.String(nil, key) if v == "" || err != nil { - return defaultval + return defaultVal } return v } // Strings returns the []string value for a given key. -func (c *ConfigContainer) Strings(key string) ([]string, error) { - v, err := c.String(key) +func (c *ConfigContainer) Strings(ctx context.Context, key string) ([]string, error) { + v, err := c.String(nil, key) if v == "" || err != nil { return nil, err } @@ -239,17 +240,17 @@ func (c *ConfigContainer) Strings(key string) ([]string, error) { } // DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v, err := c.Strings(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + v, err := c.Strings(ctx, key) if v == nil || err != nil { - return defaultval + return defaultVal } return v } // GetSection returns map for the given section -func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { +func (c *ConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { if v, ok := c.data[section]; ok { return v.(map[string]string), nil @@ -258,7 +259,7 @@ func (c *ConfigContainer) GetSection(section string) (map[string]string, error) } // SaveConfigFile save the config into file -func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { +func (c *ConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -270,7 +271,7 @@ func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { } // Set writes a new value for key. -func (c *ConfigContainer) Set(key, val string) error { +func (c *ConfigContainer) Set(ctx context.Context, key, val string) error { c.Lock() defer c.Unlock() c.data[key] = val @@ -278,7 +279,7 @@ func (c *ConfigContainer) Set(key, val string) error { } // DIY returns the raw value by a given key. -func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { +func (c *ConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) { return c.getData(key) } diff --git a/pkg/infrastructure/config/yaml/yaml_test.go b/pkg/infrastructure/config/yaml/yaml_test.go index 0fa8bc7b..197b68e4 100644 --- a/pkg/infrastructure/config/yaml/yaml_test.go +++ b/pkg/infrastructure/config/yaml/yaml_test.go @@ -70,7 +70,7 @@ func TestYaml(t *testing.T) { t.Fatal(err) } - res, _ := yamlconf.String("appname") + res, _ := yamlconf.String(nil, "appname") if res != "beeapi" { t.Fatal("appname not equal to beeapi") } @@ -84,19 +84,19 @@ func TestYaml(t *testing.T) { switch v.(type) { case int: - value, err = yamlconf.Int(k) + value, err = yamlconf.Int(nil, k) case int64: - value, err = yamlconf.Int64(k) + value, err = yamlconf.Int64(nil, k) case float64: - value, err = yamlconf.Float(k) + value, err = yamlconf.Float(nil, k) case bool: - value, err = yamlconf.Bool(k) + value, err = yamlconf.Bool(nil, k) case []string: - value, err = yamlconf.Strings(k) + value, err = yamlconf.Strings(nil, k) case string: - value, err = yamlconf.String(k) + value, err = yamlconf.String(nil, k) default: - value, err = yamlconf.DIY(k) + value, err = yamlconf.DIY(nil, k) } if err != nil { t.Errorf("get key %q value fatal,%v err %s", k, v, err) @@ -106,10 +106,10 @@ func TestYaml(t *testing.T) { } - if err = yamlconf.Set("name", "astaxie"); err != nil { + if err = yamlconf.Set(nil, "name", "astaxie"); err != nil { t.Fatal(err) } - res, _ = yamlconf.String("name") + res, _ = yamlconf.String(nil, "name") if res != "astaxie" { t.Fatal("get name error") } diff --git a/pkg/server/web/config.go b/pkg/server/web/config.go index b2e38a80..bf8db30e 100644 --- a/pkg/server/web/config.go +++ b/pkg/server/web/config.go @@ -15,6 +15,7 @@ package web import ( + context2 "context" "fmt" "os" "path/filepath" @@ -292,11 +293,11 @@ func assignConfig(ac config.Configer) error { // set the run mode first if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { BConfig.RunMode = envRunMode - } else if runMode, err := ac.String("RunMode"); runMode != "" && err == nil { + } else if runMode, err := ac.String(nil, "RunMode"); runMode != "" && err == nil { BConfig.RunMode = runMode } - if sd, err := ac.String("StaticDir"); sd != "" && err == nil { + if sd, err := ac.String(nil, "StaticDir"); sd != "" && err == nil { BConfig.WebConfig.StaticDir = map[string]string{} sds := strings.Fields(sd) for _, v := range sds { @@ -308,7 +309,7 @@ func assignConfig(ac config.Configer) error { } } - if sgz, err := ac.String("StaticExtensionsToGzip"); sgz != "" && err == nil { + if sgz, err := ac.String(nil, "StaticExtensionsToGzip"); sgz != "" && err == nil { extensions := strings.Split(sgz, ",") fileExts := []string{} for _, ext := range extensions { @@ -326,15 +327,15 @@ func assignConfig(ac config.Configer) error { } } - if sfs, err := ac.Int("StaticCacheFileSize"); err == nil { + if sfs, err := ac.Int(nil, "StaticCacheFileSize"); err == nil { BConfig.WebConfig.StaticCacheFileSize = sfs } - if sfn, err := ac.Int("StaticCacheFileNum"); err == nil { + if sfn, err := ac.Int(nil, "StaticCacheFileNum"); err == nil { BConfig.WebConfig.StaticCacheFileNum = sfn } - if lo, err := ac.String("LogOutputs"); lo != "" && err == nil { + if lo, err := ac.String(nil, "LogOutputs"); lo != "" && err == nil { // if lo is not nil or empty // means user has set his own LogOutputs // clear the default setting to BConfig.Log.Outputs @@ -381,11 +382,11 @@ func assignSingleConfig(p interface{}, ac config.Configer) { name := pt.Field(i).Name switch pf.Kind() { case reflect.String: - pf.SetString(ac.DefaultString(name, pf.String())) + pf.SetString(ac.DefaultString(nil, name, pf.String())) case reflect.Int, reflect.Int64: - pf.SetInt(ac.DefaultInt64(name, pf.Int())) + pf.SetInt(ac.DefaultInt64(nil, name, pf.Int())) case reflect.Bool: - pf.SetBool(ac.DefaultBool(name, pf.Bool())) + pf.SetBool(ac.DefaultBool(nil, name, pf.Bool())) case reflect.Struct: default: // do nothing here @@ -424,105 +425,105 @@ func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, err return &beegoAppConfig{innerConfig: ac}, nil } -func (b *beegoAppConfig) Set(key, val string) error { - if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { - return b.innerConfig.Set(key, val) +func (b *beegoAppConfig) Set(ctx context2.Context, key, val string) error { + if err := b.innerConfig.Set(nil, BConfig.RunMode+"::"+key, val); err != nil { + return b.innerConfig.Set(nil, key, val) } return nil } -func (b *beegoAppConfig) String(key string) (string, error) { - if v, err := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" && err == nil { +func (b *beegoAppConfig) String(ctx context2.Context, key string) (string, error) { + if v, err := b.innerConfig.String(nil, BConfig.RunMode+"::"+key); v != "" && err == nil { return v, nil } - return b.innerConfig.String(key) + return b.innerConfig.String(nil, key) } -func (b *beegoAppConfig) Strings(key string) ([]string, error) { - if v, err := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 && err == nil { +func (b *beegoAppConfig) Strings(ctx context2.Context, key string) ([]string, error) { + if v, err := b.innerConfig.Strings(nil, BConfig.RunMode+"::"+key); len(v) > 0 && err == nil { return v, nil } - return b.innerConfig.Strings(key) + return b.innerConfig.Strings(nil, key) } -func (b *beegoAppConfig) Int(key string) (int, error) { - if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil { +func (b *beegoAppConfig) Int(ctx context2.Context, key string) (int, error) { + if v, err := b.innerConfig.Int(nil, BConfig.RunMode+"::"+key); err == nil { return v, nil } - return b.innerConfig.Int(key) + return b.innerConfig.Int(nil, key) } -func (b *beegoAppConfig) Int64(key string) (int64, error) { - if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil { +func (b *beegoAppConfig) Int64(ctx context2.Context, key string) (int64, error) { + if v, err := b.innerConfig.Int64(nil, BConfig.RunMode+"::"+key); err == nil { return v, nil } - return b.innerConfig.Int64(key) + return b.innerConfig.Int64(nil, key) } -func (b *beegoAppConfig) Bool(key string) (bool, error) { - if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil { +func (b *beegoAppConfig) Bool(ctx context2.Context, key string) (bool, error) { + if v, err := b.innerConfig.Bool(nil, BConfig.RunMode+"::"+key); err == nil { return v, nil } - return b.innerConfig.Bool(key) + return b.innerConfig.Bool(nil, key) } -func (b *beegoAppConfig) Float(key string) (float64, error) { - if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil { +func (b *beegoAppConfig) Float(ctx context2.Context, key string) (float64, error) { + if v, err := b.innerConfig.Float(nil, BConfig.RunMode+"::"+key); err == nil { return v, nil } - return b.innerConfig.Float(key) + return b.innerConfig.Float(nil, key) } -func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { - if v, err := b.String(key); v != "" && err == nil { +func (b *beegoAppConfig) DefaultString(ctx context2.Context, key string, defaultVal string) string { + if v, err := b.String(nil, key); v != "" && err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { - if v, err := b.Strings(key); len(v) != 0 && err == nil { +func (b *beegoAppConfig) DefaultStrings(ctx context2.Context, key string, defaultVal []string) []string { + if v, err := b.Strings(ctx, key); len(v) != 0 && err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { - if v, err := b.Int(key); err == nil { +func (b *beegoAppConfig) DefaultInt(ctx context2.Context, key string, defaultVal int) int { + if v, err := b.Int(ctx, key); err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { - if v, err := b.Int64(key); err == nil { +func (b *beegoAppConfig) DefaultInt64(ctx context2.Context, key string, defaultVal int64) int64 { + if v, err := b.Int64(ctx, key); err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { - if v, err := b.Bool(key); err == nil { +func (b *beegoAppConfig) DefaultBool(ctx context2.Context, key string, defaultVal bool) bool { + if v, err := b.Bool(ctx, key); err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { - if v, err := b.Float(key); err == nil { +func (b *beegoAppConfig) DefaultFloat(ctx context2.Context, key string, defaultVal float64) float64 { + if v, err := b.Float(ctx, key); err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DIY(key string) (interface{}, error) { - return b.innerConfig.DIY(key) +func (b *beegoAppConfig) DIY(ctx context2.Context, key string) (interface{}, error) { + return b.innerConfig.DIY(nil, key) } -func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { - return b.innerConfig.GetSection(section) +func (b *beegoAppConfig) GetSection(ctx context2.Context, section string) (map[string]string, error) { + return b.innerConfig.GetSection(nil, section) } -func (b *beegoAppConfig) SaveConfigFile(filename string) error { - return b.innerConfig.SaveConfigFile(filename) +func (b *beegoAppConfig) SaveConfigFile(ctx context2.Context, filename string) error { + return b.innerConfig.SaveConfigFile(nil, filename) } diff --git a/pkg/server/web/config_test.go b/pkg/server/web/config_test.go index 1d6de695..4961d3a9 100644 --- a/pkg/server/web/config_test.go +++ b/pkg/server/web/config_test.go @@ -111,12 +111,12 @@ func TestAssignConfig_02(t *testing.T) { func TestAssignConfig_03(t *testing.T) { jcf := &beeJson.JSONConfig{} ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`)) - ac.Set("AppName", "test_app") - ac.Set("RunMode", "online") - ac.Set("StaticDir", "download:down download2:down2") - ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") - ac.Set("StaticCacheFileSize", "87456") - ac.Set("StaticCacheFileNum", "1254") + ac.Set(nil, "AppName", "test_app") + ac.Set(nil, "RunMode", "online") + ac.Set(nil, "StaticDir", "download:down download2:down2") + ac.Set(nil, "StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") + ac.Set(nil, "StaticCacheFileSize", "87456") + ac.Set(nil, "StaticCacheFileNum", "1254") assignConfig(ac) t.Logf("%#v", BConfig) diff --git a/pkg/server/web/hooks.go b/pkg/server/web/hooks.go index ae54f190..080b2006 100644 --- a/pkg/server/web/hooks.go +++ b/pkg/server/web/hooks.go @@ -1,6 +1,7 @@ package web import ( + context2 "context" "encoding/json" "mime" "net/http" @@ -48,7 +49,7 @@ func registerDefaultErrorHandler() error { func registerSession() error { if BConfig.WebConfig.Session.SessionOn { var err error - sessionConfig, err := AppConfig.String("sessionConfig") + sessionConfig, err := AppConfig.String(nil, "sessionConfig") conf := new(session.ManagerConfig) if sessionConfig == "" || err != nil { conf.CookieName = BConfig.WebConfig.Session.SessionName @@ -96,9 +97,9 @@ func registerAdmin() error { func registerGzip() error { if BConfig.EnableGzip { context.InitGzip( - AppConfig.DefaultInt("gzipMinLength", -1), - AppConfig.DefaultInt("gzipCompressLevel", -1), - AppConfig.DefaultStrings("includedMethods", []string{"GET"}), + AppConfig.DefaultInt(context2.Background(), "gzipMinLength", -1), + AppConfig.DefaultInt(context2.Background(), "gzipCompressLevel", -1), + AppConfig.DefaultStrings(context2.Background(), "includedMethods", []string{"GET"}), ) } return nil diff --git a/pkg/server/web/parser.go b/pkg/server/web/parser.go index ce63a0be..a4507010 100644 --- a/pkg/server/web/parser.go +++ b/pkg/server/web/parser.go @@ -15,6 +15,7 @@ package web import ( + "context" "encoding/json" "errors" "fmt" @@ -516,7 +517,7 @@ func genRouterCode(pkgRealpath string) { } defer f.Close() - routersDir := AppConfig.DefaultString("routersdir", "routers") + routersDir := AppConfig.DefaultString(context.Background(), "routersdir", "routers") content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) content = strings.Replace(content, "{{.routersDir}}", routersDir, -1) content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) @@ -585,7 +586,7 @@ func getpathTime(pkgRealpath string) (lastupdate int64, err error) { func getRouterDir(pkgRealpath string) string { dir := filepath.Dir(pkgRealpath) for { - routersDir := AppConfig.DefaultString("routersdir", "routers") + routersDir := AppConfig.DefaultString(context.Background(), "routersdir", "routers") d := filepath.Join(dir, routersDir) if utils.FileExists(d) { return d diff --git a/pkg/server/web/templatefunc.go b/pkg/server/web/templatefunc.go index 34d71aab..f3301e50 100644 --- a/pkg/server/web/templatefunc.go +++ b/pkg/server/web/templatefunc.go @@ -15,6 +15,7 @@ package web import ( + "context" "errors" "fmt" "html" @@ -58,11 +59,11 @@ func HTML2str(html string) string { re := regexp.MustCompile(`\<[\S\s]+?\>`) html = re.ReplaceAllStringFunc(html, strings.ToLower) - //remove STYLE + // remove STYLE re = regexp.MustCompile(`\`) html = re.ReplaceAllString(html, "") - //remove SCRIPT + // remove SCRIPT re = regexp.MustCompile(`\`) html = re.ReplaceAllString(html, "") @@ -85,7 +86,7 @@ func DateFormat(t time.Time, layout string) (datestring string) { var datePatterns = []string{ // year "Y", "2006", // A full numeric representation of a year, 4 digits Examples: 1999 or 2003 - "y", "06", //A two digit representation of a year Examples: 99 or 03 + "y", "06", // A two digit representation of a year Examples: 99 or 03 // month "m", "01", // Numeric representation of a month, with leading zeros 01 through 12 @@ -160,17 +161,17 @@ func NotNil(a interface{}) (isNil bool) { func GetConfig(returnType, key string, defaultVal interface{}) (value interface{}, err error) { switch returnType { case "String": - value, err = AppConfig.String(key) + value, err = AppConfig.String(context.Background(), key) case "Bool": - value, err = AppConfig.Bool(key) + value, err = AppConfig.Bool(context.Background(), key) case "Int": - value, err = AppConfig.Int(key) + value, err = AppConfig.Int(context.Background(), key) case "Int64": - value, err = AppConfig.Int64(key) + value, err = AppConfig.Int64(context.Background(), key) case "Float": - value, err = AppConfig.Float(key) + value, err = AppConfig.Float(context.Background(), key) case "DIY": - value, err = AppConfig.DIY(key) + value, err = AppConfig.DIY(context.Background(), key) default: err = errors.New("config keys must be of type String, Bool, Int, Int64, Float, or DIY") } @@ -201,7 +202,7 @@ func Str2html(raw string) template.HTML { // Htmlquote returns quoted html string. func Htmlquote(text string) string { - //HTML编码为实体符号 + // HTML编码为实体符号 /* Encodes `text` for raw use in HTML. >>> htmlquote("<'&\\">") @@ -220,7 +221,7 @@ func Htmlquote(text string) string { // Htmlunquote returns unquoted html string. func Htmlunquote(text string) string { - //实体符号解释为HTML + // 实体符号解释为HTML /* Decodes `text` that's HTML quoted. >>> htmlunquote('<'&">') From 670064686e1afb78d7b43dfedbd165bd3ca2611e Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 30 Aug 2020 15:39:07 +0000 Subject: [PATCH 129/207] Add ctx to session API --- pkg/client/cache/redis/redis_test.go | 4 +- .../session/couchbase/sess_couchbase.go | 27 ++-- .../session/ledis/ledis_session.go | 29 +++-- .../session/memcache/sess_memcache.go | 27 ++-- .../session/mysql/sess_mysql.go | 27 ++-- .../session/postgres/sess_postgresql.go | 27 ++-- .../session/redis/sess_redis.go | 29 +++-- .../session/redis/sess_redis_test.go | 26 ++-- .../session/redis_cluster/redis_cluster.go | 29 +++-- .../redis_sentinel/sess_redis_sentinel.go | 29 +++-- .../sess_redis_sentinel_test.go | 26 ++-- pkg/infrastructure/session/sess_cookie.go | 27 ++-- .../session/sess_cookie_test.go | 10 +- pkg/infrastructure/session/sess_file.go | 27 ++-- pkg/infrastructure/session/sess_file_test.go | 123 +++++++++--------- pkg/infrastructure/session/sess_mem.go | 27 ++-- pkg/infrastructure/session/sess_mem_test.go | 6 +- pkg/infrastructure/session/session.go | 47 +++---- pkg/infrastructure/session/ssdb/sess_ssdb.go | 27 ++-- pkg/server/web/context/input.go | 2 +- pkg/server/web/context/output.go | 2 +- pkg/server/web/controller.go | 10 +- pkg/server/web/router.go | 2 +- 23 files changed, 302 insertions(+), 288 deletions(-) diff --git a/pkg/client/cache/redis/redis_test.go b/pkg/client/cache/redis/redis_test.go index 00206157..dc0ca40f 100644 --- a/pkg/client/cache/redis/redis_test.go +++ b/pkg/client/cache/redis/redis_test.go @@ -129,7 +129,7 @@ func TestCache_Scan(t *testing.T) { t.Error("init err") } // insert all - for i := 0; i < 10000; i++ { + for i := 0; i < 100; i++ { if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { t.Error("set Error", err) } @@ -141,7 +141,7 @@ func TestCache_Scan(t *testing.T) { t.Error("scan Error", err) } - assert.Equal(t, 10000, len(keys), "scan all error") + assert.Equal(t, 100, len(keys), "scan all error") // clear all if err = bm.ClearAll(); err != nil { diff --git a/pkg/infrastructure/session/couchbase/sess_couchbase.go b/pkg/infrastructure/session/couchbase/sess_couchbase.go index 378cfc9f..ddb4be58 100644 --- a/pkg/infrastructure/session/couchbase/sess_couchbase.go +++ b/pkg/infrastructure/session/couchbase/sess_couchbase.go @@ -33,6 +33,7 @@ package couchbase import ( + "context" "net/http" "strings" "sync" @@ -63,7 +64,7 @@ type Provider struct { } // Set value to couchabse session -func (cs *SessionStore) Set(key, value interface{}) error { +func (cs *SessionStore) Set(ctx context.Context, key, value interface{}) error { cs.lock.Lock() defer cs.lock.Unlock() cs.values[key] = value @@ -71,7 +72,7 @@ func (cs *SessionStore) Set(key, value interface{}) error { } // Get value from couchabse session -func (cs *SessionStore) Get(key interface{}) interface{} { +func (cs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { cs.lock.RLock() defer cs.lock.RUnlock() if v, ok := cs.values[key]; ok { @@ -81,7 +82,7 @@ func (cs *SessionStore) Get(key interface{}) interface{} { } // Delete value in couchbase session by given key -func (cs *SessionStore) Delete(key interface{}) error { +func (cs *SessionStore) Delete(ctx context.Context, key interface{}) error { cs.lock.Lock() defer cs.lock.Unlock() delete(cs.values, key) @@ -89,7 +90,7 @@ func (cs *SessionStore) Delete(key interface{}) error { } // Flush Clean all values in couchbase session -func (cs *SessionStore) Flush() error { +func (cs *SessionStore) Flush(context.Context) error { cs.lock.Lock() defer cs.lock.Unlock() cs.values = make(map[interface{}]interface{}) @@ -97,12 +98,12 @@ func (cs *SessionStore) Flush() error { } // SessionID Get couchbase session store id -func (cs *SessionStore) SessionID() string { +func (cs *SessionStore) SessionID(context.Context) string { return cs.sid } // SessionRelease Write couchbase session with Gob string -func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { defer cs.b.Close() bo, err := session.EncodeGob(cs.values) @@ -135,7 +136,7 @@ func (cp *Provider) getBucket() *couchbase.Bucket { // SessionInit init couchbase session // savepath like couchbase server REST/JSON URL // e.g. http://host:port/, Pool, Bucket -func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { cp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") if len(configs) > 0 { @@ -152,7 +153,7 @@ func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read couchbase session by sid -func (cp *Provider) SessionRead(sid string) (session.Store, error) { +func (cp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { cp.b = cp.getBucket() var ( @@ -179,7 +180,7 @@ func (cp *Provider) SessionRead(sid string) (session.Store, error) { // SessionExist Check couchbase session exist. // it checkes sid exist or not. -func (cp *Provider) SessionExist(sid string) (bool, error) { +func (cp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { cp.b = cp.getBucket() defer cp.b.Close() @@ -192,7 +193,7 @@ func (cp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate remove oldsid and use sid to generate new session -func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (cp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { cp.b = cp.getBucket() var doc []byte @@ -225,7 +226,7 @@ func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy Remove bucket in this couchbase -func (cp *Provider) SessionDestroy(sid string) error { +func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error { cp.b = cp.getBucket() defer cp.b.Close() @@ -234,11 +235,11 @@ func (cp *Provider) SessionDestroy(sid string) error { } // SessionGC Recycle -func (cp *Provider) SessionGC() { +func (cp *Provider) SessionGC(context.Context) { } // SessionAll return all active session -func (cp *Provider) SessionAll() int { +func (cp *Provider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/infrastructure/session/ledis/ledis_session.go b/pkg/infrastructure/session/ledis/ledis_session.go index 96e6efa3..74bf9b65 100644 --- a/pkg/infrastructure/session/ledis/ledis_session.go +++ b/pkg/infrastructure/session/ledis/ledis_session.go @@ -2,6 +2,7 @@ package ledis import ( + "context" "net/http" "strconv" "strings" @@ -27,7 +28,7 @@ type SessionStore struct { } // Set value in ledis session -func (ls *SessionStore) Set(key, value interface{}) error { +func (ls *SessionStore) Set(ctx context.Context, key, value interface{}) error { ls.lock.Lock() defer ls.lock.Unlock() ls.values[key] = value @@ -35,7 +36,7 @@ func (ls *SessionStore) Set(key, value interface{}) error { } // Get value in ledis session -func (ls *SessionStore) Get(key interface{}) interface{} { +func (ls *SessionStore) Get(ctx context.Context, key interface{}) interface{} { ls.lock.RLock() defer ls.lock.RUnlock() if v, ok := ls.values[key]; ok { @@ -45,7 +46,7 @@ func (ls *SessionStore) Get(key interface{}) interface{} { } // Delete value in ledis session -func (ls *SessionStore) Delete(key interface{}) error { +func (ls *SessionStore) Delete(ctx context.Context, key interface{}) error { ls.lock.Lock() defer ls.lock.Unlock() delete(ls.values, key) @@ -53,7 +54,7 @@ func (ls *SessionStore) Delete(key interface{}) error { } // Flush clear all values in ledis session -func (ls *SessionStore) Flush() error { +func (ls *SessionStore) Flush(context.Context) error { ls.lock.Lock() defer ls.lock.Unlock() ls.values = make(map[interface{}]interface{}) @@ -61,12 +62,12 @@ func (ls *SessionStore) Flush() error { } // SessionID get ledis session id -func (ls *SessionStore) SessionID() string { +func (ls *SessionStore) SessionID(context.Context) string { return ls.sid } // SessionRelease save session values to ledis -func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { +func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(ls.values) if err != nil { return @@ -85,7 +86,7 @@ type Provider struct { // SessionInit init ledis session // savepath like ledis server saveDataPath,pool size // e.g. 127.0.0.1:6379,100,astaxie -func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { var err error lp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") @@ -111,7 +112,7 @@ func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read ledis session by sid -func (lp *Provider) SessionRead(sid string) (session.Store, error) { +func (lp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { var ( kv map[interface{}]interface{} err error @@ -132,13 +133,13 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check ledis session exist by sid -func (lp *Provider) SessionExist(sid string) (bool, error) { +func (lp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { count, _ := c.Exists([]byte(sid)) return count != 0, nil } // SessionRegenerate generate new sid for ledis session -func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (lp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { count, _ := c.Exists([]byte(sid)) if count == 0 { // oldsid doesn't exists, set the new sid directly @@ -151,21 +152,21 @@ func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Set([]byte(sid), data) c.Expire([]byte(sid), lp.maxlifetime) } - return lp.SessionRead(sid) + return lp.SessionRead(context.Background(), sid) } // SessionDestroy delete ledis session by id -func (lp *Provider) SessionDestroy(sid string) error { +func (lp *Provider) SessionDestroy(ctx context.Context, sid string) error { c.Del([]byte(sid)) return nil } // SessionGC Impelment method, no used. -func (lp *Provider) SessionGC() { +func (lp *Provider) SessionGC(context.Context) { } // SessionAll return all active session -func (lp *Provider) SessionAll() int { +func (lp *Provider) SessionAll(context.Context) int { return 0 } func init() { diff --git a/pkg/infrastructure/session/memcache/sess_memcache.go b/pkg/infrastructure/session/memcache/sess_memcache.go index 0758c43f..57df2844 100644 --- a/pkg/infrastructure/session/memcache/sess_memcache.go +++ b/pkg/infrastructure/session/memcache/sess_memcache.go @@ -33,6 +33,7 @@ package memcache import ( + "context" "net/http" "strings" "sync" @@ -54,7 +55,7 @@ type SessionStore struct { } // Set value in memcache session -func (rs *SessionStore) Set(key, value interface{}) error { +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value @@ -62,7 +63,7 @@ func (rs *SessionStore) Set(key, value interface{}) error { } // Get value in memcache session -func (rs *SessionStore) Get(key interface{}) interface{} { +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { @@ -72,7 +73,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} { } // Delete value in memcache session -func (rs *SessionStore) Delete(key interface{}) error { +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) @@ -80,7 +81,7 @@ func (rs *SessionStore) Delete(key interface{}) error { } // Flush clear all values in memcache session -func (rs *SessionStore) Flush() error { +func (rs *SessionStore) Flush(context.Context) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) @@ -88,12 +89,12 @@ func (rs *SessionStore) Flush() error { } // SessionID get memcache session id -func (rs *SessionStore) SessionID() string { +func (rs *SessionStore) SessionID(context.Context) string { return rs.sid } // SessionRelease save session values to memcache -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(rs.values) if err != nil { return @@ -113,7 +114,7 @@ type MemProvider struct { // SessionInit init memcache session // savepath like // e.g. 127.0.0.1:9090 -func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { +func (rp *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime rp.conninfo = strings.Split(savePath, ";") client = memcache.New(rp.conninfo...) @@ -121,7 +122,7 @@ func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read memcache session by sid -func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { +func (rp *MemProvider) SessionRead(ctx context.Context, sid string) (session.Store, error) { if client == nil { if err := rp.connectInit(); err != nil { return nil, err @@ -149,7 +150,7 @@ func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { } // SessionExist check memcache session exist by sid -func (rp *MemProvider) SessionExist(sid string) (bool, error) { +func (rp *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) { if client == nil { if err := rp.connectInit(); err != nil { return false, err @@ -162,7 +163,7 @@ func (rp *MemProvider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for memcache session -func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (rp *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { if client == nil { if err := rp.connectInit(); err != nil { return nil, err @@ -201,7 +202,7 @@ func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, err } // SessionDestroy delete memcache session by id -func (rp *MemProvider) SessionDestroy(sid string) error { +func (rp *MemProvider) SessionDestroy(ctx context.Context, sid string) error { if client == nil { if err := rp.connectInit(); err != nil { return err @@ -217,11 +218,11 @@ func (rp *MemProvider) connectInit() error { } // SessionGC Impelment method, no used. -func (rp *MemProvider) SessionGC() { +func (rp *MemProvider) SessionGC(context.Context) { } // SessionAll return all activeSession -func (rp *MemProvider) SessionAll() int { +func (rp *MemProvider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/infrastructure/session/mysql/sess_mysql.go b/pkg/infrastructure/session/mysql/sess_mysql.go index 2dadd317..fe1d69dc 100644 --- a/pkg/infrastructure/session/mysql/sess_mysql.go +++ b/pkg/infrastructure/session/mysql/sess_mysql.go @@ -41,6 +41,7 @@ package mysql import ( + "context" "database/sql" "net/http" "sync" @@ -67,7 +68,7 @@ type SessionStore struct { // Set value in mysql session. // it is temp value in map. -func (st *SessionStore) Set(key, value interface{}) error { +func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.values[key] = value @@ -75,7 +76,7 @@ func (st *SessionStore) Set(key, value interface{}) error { } // Get value from mysql session -func (st *SessionStore) Get(key interface{}) interface{} { +func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.values[key]; ok { @@ -85,7 +86,7 @@ func (st *SessionStore) Get(key interface{}) interface{} { } // Delete value in mysql session -func (st *SessionStore) Delete(key interface{}) error { +func (st *SessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.values, key) @@ -93,7 +94,7 @@ func (st *SessionStore) Delete(key interface{}) error { } // Flush clear all values in mysql session -func (st *SessionStore) Flush() error { +func (st *SessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.values = make(map[interface{}]interface{}) @@ -101,13 +102,13 @@ func (st *SessionStore) Flush() error { } // SessionID get session id of this mysql session store -func (st *SessionStore) SessionID() string { +func (st *SessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease save mysql session values to database. // must call this method to save values to database. -func (st *SessionStore) SessionRelease(w http.ResponseWriter) { +func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { defer st.c.Close() b, err := session.EncodeGob(st.values) if err != nil { @@ -134,14 +135,14 @@ func (mp *Provider) connectInit() *sql.DB { // SessionInit init mysql session. // savepath is the connection string of mysql. -func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { mp.maxlifetime = maxlifetime mp.savePath = savePath return nil } // SessionRead get mysql session by sid -func (mp *Provider) SessionRead(sid string) (session.Store, error) { +func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte @@ -164,7 +165,7 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check mysql session exist -func (mp *Provider) SessionExist(sid string) (bool, error) { +func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) @@ -180,7 +181,7 @@ func (mp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for mysql session -func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid) var sessiondata []byte @@ -203,7 +204,7 @@ func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy delete mysql session by sid -func (mp *Provider) SessionDestroy(sid string) error { +func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := mp.connectInit() c.Exec("DELETE FROM "+TableName+" where session_key=?", sid) c.Close() @@ -211,14 +212,14 @@ func (mp *Provider) SessionDestroy(sid string) error { } // SessionGC delete expired values in mysql session -func (mp *Provider) SessionGC() { +func (mp *Provider) SessionGC(context.Context) { c := mp.connectInit() c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) c.Close() } // SessionAll count values in mysql session -func (mp *Provider) SessionAll() int { +func (mp *Provider) SessionAll(context.Context) int { c := mp.connectInit() defer c.Close() var total int diff --git a/pkg/infrastructure/session/postgres/sess_postgresql.go b/pkg/infrastructure/session/postgres/sess_postgresql.go index adcf647b..2fadbed0 100644 --- a/pkg/infrastructure/session/postgres/sess_postgresql.go +++ b/pkg/infrastructure/session/postgres/sess_postgresql.go @@ -51,6 +51,7 @@ package postgres import ( + "context" "database/sql" "net/http" "sync" @@ -73,7 +74,7 @@ type SessionStore struct { // Set value in postgresql session. // it is temp value in map. -func (st *SessionStore) Set(key, value interface{}) error { +func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.values[key] = value @@ -81,7 +82,7 @@ func (st *SessionStore) Set(key, value interface{}) error { } // Get value from postgresql session -func (st *SessionStore) Get(key interface{}) interface{} { +func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.values[key]; ok { @@ -91,7 +92,7 @@ func (st *SessionStore) Get(key interface{}) interface{} { } // Delete value in postgresql session -func (st *SessionStore) Delete(key interface{}) error { +func (st *SessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.values, key) @@ -99,7 +100,7 @@ func (st *SessionStore) Delete(key interface{}) error { } // Flush clear all values in postgresql session -func (st *SessionStore) Flush() error { +func (st *SessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.values = make(map[interface{}]interface{}) @@ -107,13 +108,13 @@ func (st *SessionStore) Flush() error { } // SessionID get session id of this postgresql session store -func (st *SessionStore) SessionID() string { +func (st *SessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease save postgresql session values to database. // must call this method to save values to database. -func (st *SessionStore) SessionRelease(w http.ResponseWriter) { +func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { defer st.c.Close() b, err := session.EncodeGob(st.values) if err != nil { @@ -141,14 +142,14 @@ func (mp *Provider) connectInit() *sql.DB { // SessionInit init postgresql session. // savepath is the connection string of postgresql. -func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { mp.maxlifetime = maxlifetime mp.savePath = savePath return nil } // SessionRead get postgresql session by sid -func (mp *Provider) SessionRead(sid string) (session.Store, error) { +func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from session where session_key=$1", sid) var sessiondata []byte @@ -178,7 +179,7 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check postgresql session exist -func (mp *Provider) SessionExist(sid string) (bool, error) { +func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from session where session_key=$1", sid) @@ -194,7 +195,7 @@ func (mp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for postgresql session -func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from session where session_key=$1", oldsid) var sessiondata []byte @@ -218,7 +219,7 @@ func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy delete postgresql session by sid -func (mp *Provider) SessionDestroy(sid string) error { +func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := mp.connectInit() c.Exec("DELETE FROM session where session_key=$1", sid) c.Close() @@ -226,14 +227,14 @@ func (mp *Provider) SessionDestroy(sid string) error { } // SessionGC delete expired values in postgresql session -func (mp *Provider) SessionGC() { +func (mp *Provider) SessionGC(context.Context) { c := mp.connectInit() c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime) c.Close() } // SessionAll count values in postgresql session -func (mp *Provider) SessionAll() int { +func (mp *Provider) SessionAll(context.Context) int { c := mp.connectInit() defer c.Close() var total int diff --git a/pkg/infrastructure/session/redis/sess_redis.go b/pkg/infrastructure/session/redis/sess_redis.go index e775102c..c7bfbcbf 100644 --- a/pkg/infrastructure/session/redis/sess_redis.go +++ b/pkg/infrastructure/session/redis/sess_redis.go @@ -33,6 +33,7 @@ package redis import ( + "context" "net/http" "strconv" "strings" @@ -59,7 +60,7 @@ type SessionStore struct { } // Set value in redis session -func (rs *SessionStore) Set(key, value interface{}) error { +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value @@ -67,7 +68,7 @@ func (rs *SessionStore) Set(key, value interface{}) error { } // Get value in redis session -func (rs *SessionStore) Get(key interface{}) interface{} { +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { @@ -77,7 +78,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} { } // Delete value in redis session -func (rs *SessionStore) Delete(key interface{}) error { +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) @@ -85,7 +86,7 @@ func (rs *SessionStore) Delete(key interface{}) error { } // Flush clear all values in redis session -func (rs *SessionStore) Flush() error { +func (rs *SessionStore) Flush(context.Context) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) @@ -93,12 +94,12 @@ func (rs *SessionStore) Flush() error { } // SessionID get redis session id -func (rs *SessionStore) SessionID() string { +func (rs *SessionStore) SessionID(context.Context) string { return rs.sid } // SessionRelease save session values to redis -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(rs.values) if err != nil { return @@ -123,7 +124,7 @@ type Provider struct { // SessionInit init redis session // savepath like redis server addr,pool size,password,dbnum,IdleTimeout second // e.g. 127.0.0.1:6379,100,astaxie,0,30 -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") if len(configs) > 0 { @@ -185,7 +186,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read redis session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { +func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { var kv map[interface{}]interface{} kvs, err := rp.poollist.Get(sid).Result() @@ -205,7 +206,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis session exist by sid -func (rp *Provider) SessionExist(sid string) (bool, error) { +func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { @@ -215,7 +216,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for redis session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := rp.poollist if existed, _ := c.Exists(oldsid).Result(); existed == 0 { // oldsid doesn't exists, set the new sid directly @@ -226,11 +227,11 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Rename(oldsid, sid) c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) } - return rp.SessionRead(sid) + return rp.SessionRead(context.Background(), sid) } // SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { +func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := rp.poollist c.Del(sid) @@ -238,11 +239,11 @@ func (rp *Provider) SessionDestroy(sid string) error { } // SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { +func (rp *Provider) SessionGC(context.Context) { } // SessionAll return all activeSession -func (rp *Provider) SessionAll() int { +func (rp *Provider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/infrastructure/session/redis/sess_redis_test.go b/pkg/infrastructure/session/redis/sess_redis_test.go index ef466eab..df77204d 100644 --- a/pkg/infrastructure/session/redis/sess_redis_test.go +++ b/pkg/infrastructure/session/redis/sess_redis_test.go @@ -40,57 +40,57 @@ func TestRedis(t *testing.T) { if err != nil { t.Fatal("session start failed:", err) } - defer sess.SessionRelease(w) + defer sess.SessionRelease(nil, w) // SET AND GET - err = sess.Set("username", "astaxie") + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set username failed:", err) } - username := sess.Get("username") + username := sess.Get(nil, "username") if username != "astaxie" { t.Fatal("get username failed") } // DELETE - err = sess.Delete("username") + err = sess.Delete(nil, "username") if err != nil { t.Fatal("delete username failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != nil { t.Fatal("delete username failed") } // FLUSH - err = sess.Set("username", "astaxie") + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set failed:", err) } - err = sess.Set("password", "1qaz2wsx") + err = sess.Set(nil, "password", "1qaz2wsx") if err != nil { t.Fatal("set failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != "astaxie" { t.Fatal("get username failed") } - password := sess.Get("password") + password := sess.Get(nil, "password") if password != "1qaz2wsx" { t.Fatal("get password failed") } - err = sess.Flush() + err = sess.Flush(nil) if err != nil { t.Fatal("flush failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != nil { t.Fatal("flush failed") } - password = sess.Get("password") + password = sess.Get(nil, "password") if password != nil { t.Fatal("flush failed") } - sess.SessionRelease(w) + sess.SessionRelease(nil, w) } diff --git a/pkg/infrastructure/session/redis_cluster/redis_cluster.go b/pkg/infrastructure/session/redis_cluster/redis_cluster.go index 40487d76..95907a5f 100644 --- a/pkg/infrastructure/session/redis_cluster/redis_cluster.go +++ b/pkg/infrastructure/session/redis_cluster/redis_cluster.go @@ -33,6 +33,7 @@ package redis_cluster import ( + "context" "net/http" "strconv" "strings" @@ -58,7 +59,7 @@ type SessionStore struct { } // Set value in redis_cluster session -func (rs *SessionStore) Set(key, value interface{}) error { +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value @@ -66,7 +67,7 @@ func (rs *SessionStore) Set(key, value interface{}) error { } // Get value in redis_cluster session -func (rs *SessionStore) Get(key interface{}) interface{} { +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { @@ -76,7 +77,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} { } // Delete value in redis_cluster session -func (rs *SessionStore) Delete(key interface{}) error { +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) @@ -84,7 +85,7 @@ func (rs *SessionStore) Delete(key interface{}) error { } // Flush clear all values in redis_cluster session -func (rs *SessionStore) Flush() error { +func (rs *SessionStore) Flush(context.Context) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) @@ -92,12 +93,12 @@ func (rs *SessionStore) Flush() error { } // SessionID get redis_cluster session id -func (rs *SessionStore) SessionID() string { +func (rs *SessionStore) SessionID(context.Context) string { return rs.sid } // SessionRelease save session values to redis_cluster -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(rs.values) if err != nil { return @@ -122,7 +123,7 @@ type Provider struct { // SessionInit init redis_cluster session // savepath like redis server addr,pool size,password,dbnum // e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") if len(configs) > 0 { @@ -182,7 +183,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read redis_cluster session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { +func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { var kv map[interface{}]interface{} kvs, err := rp.poollist.Get(sid).Result() if err != nil && err != rediss.Nil { @@ -201,7 +202,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis_cluster session exist by sid -func (rp *Provider) SessionExist(sid string) (bool, error) { +func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { return false, err @@ -210,7 +211,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for redis_cluster session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := rp.poollist if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { @@ -222,22 +223,22 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Rename(oldsid, sid) c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) } - return rp.SessionRead(sid) + return rp.SessionRead(context.Background(), sid) } // SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { +func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := rp.poollist c.Del(sid) return nil } // SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { +func (rp *Provider) SessionGC(context.Context) { } // SessionAll return all activeSession -func (rp *Provider) SessionAll() int { +func (rp *Provider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel.go b/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel.go index 1f6ebaa7..1b9c841b 100644 --- a/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel.go +++ b/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel.go @@ -33,6 +33,7 @@ package redis_sentinel import ( + "context" "net/http" "strconv" "strings" @@ -58,7 +59,7 @@ type SessionStore struct { } // Set value in redis_sentinel session -func (rs *SessionStore) Set(key, value interface{}) error { +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value @@ -66,7 +67,7 @@ func (rs *SessionStore) Set(key, value interface{}) error { } // Get value in redis_sentinel session -func (rs *SessionStore) Get(key interface{}) interface{} { +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { @@ -76,7 +77,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} { } // Delete value in redis_sentinel session -func (rs *SessionStore) Delete(key interface{}) error { +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) @@ -84,7 +85,7 @@ func (rs *SessionStore) Delete(key interface{}) error { } // Flush clear all values in redis_sentinel session -func (rs *SessionStore) Flush() error { +func (rs *SessionStore) Flush(context.Context) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) @@ -92,12 +93,12 @@ func (rs *SessionStore) Flush() error { } // SessionID get redis_sentinel session id -func (rs *SessionStore) SessionID() string { +func (rs *SessionStore) SessionID(context.Context) string { return rs.sid } // SessionRelease save session values to redis_sentinel -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(rs.values) if err != nil { return @@ -123,7 +124,7 @@ type Provider struct { // SessionInit init redis_sentinel session // savepath like redis sentinel addr,pool size,password,dbnum,masterName // e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") if len(configs) > 0 { @@ -195,7 +196,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read redis_sentinel session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { +func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { var kv map[interface{}]interface{} kvs, err := rp.poollist.Get(sid).Result() if err != nil && err != redis.Nil { @@ -214,7 +215,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis_sentinel session exist by sid -func (rp *Provider) SessionExist(sid string) (bool, error) { +func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { return false, err @@ -223,7 +224,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for redis_sentinel session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := rp.poollist if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { @@ -235,22 +236,22 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Rename(oldsid, sid) c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) } - return rp.SessionRead(sid) + return rp.SessionRead(context.Background(), sid) } // SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { +func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := rp.poollist c.Del(sid) return nil } // SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { +func (rp *Provider) SessionGC(context.Context) { } // SessionAll return all activeSession -func (rp *Provider) SessionAll() int { +func (rp *Provider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel_test.go b/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel_test.go index 0dc3520a..fcec9806 100644 --- a/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel_test.go @@ -33,58 +33,58 @@ func TestRedisSentinel(t *testing.T) { if err != nil { t.Fatal("session start failed:", err) } - defer sess.SessionRelease(w) + defer sess.SessionRelease(nil, w) // SET AND GET - err = sess.Set("username", "astaxie") + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set username failed:", err) } - username := sess.Get("username") + username := sess.Get(nil, "username") if username != "astaxie" { t.Fatal("get username failed") } // DELETE - err = sess.Delete("username") + err = sess.Delete(nil, "username") if err != nil { t.Fatal("delete username failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != nil { t.Fatal("delete username failed") } // FLUSH - err = sess.Set("username", "astaxie") + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set failed:", err) } - err = sess.Set("password", "1qaz2wsx") + err = sess.Set(nil, "password", "1qaz2wsx") if err != nil { t.Fatal("set failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != "astaxie" { t.Fatal("get username failed") } - password := sess.Get("password") + password := sess.Get(nil, "password") if password != "1qaz2wsx" { t.Fatal("get password failed") } - err = sess.Flush() + err = sess.Flush(nil) if err != nil { t.Fatal("flush failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != nil { t.Fatal("flush failed") } - password = sess.Get("password") + password = sess.Get(nil, "password") if password != nil { t.Fatal("flush failed") } - sess.SessionRelease(w) + sess.SessionRelease(nil, w) } diff --git a/pkg/infrastructure/session/sess_cookie.go b/pkg/infrastructure/session/sess_cookie.go index 30a7032e..ffb19fb7 100644 --- a/pkg/infrastructure/session/sess_cookie.go +++ b/pkg/infrastructure/session/sess_cookie.go @@ -15,6 +15,7 @@ package session import ( + "context" "crypto/aes" "crypto/cipher" "encoding/json" @@ -34,7 +35,7 @@ type CookieSessionStore struct { // Set value to cookie session. // the value are encoded as gob with hash block string. -func (st *CookieSessionStore) Set(key, value interface{}) error { +func (st *CookieSessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.values[key] = value @@ -42,7 +43,7 @@ func (st *CookieSessionStore) Set(key, value interface{}) error { } // Get value from cookie session -func (st *CookieSessionStore) Get(key interface{}) interface{} { +func (st *CookieSessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.values[key]; ok { @@ -52,7 +53,7 @@ func (st *CookieSessionStore) Get(key interface{}) interface{} { } // Delete value in cookie session -func (st *CookieSessionStore) Delete(key interface{}) error { +func (st *CookieSessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.values, key) @@ -60,7 +61,7 @@ func (st *CookieSessionStore) Delete(key interface{}) error { } // Flush Clean all values in cookie session -func (st *CookieSessionStore) Flush() error { +func (st *CookieSessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.values = make(map[interface{}]interface{}) @@ -68,12 +69,12 @@ func (st *CookieSessionStore) Flush() error { } // SessionID Return id of this cookie session -func (st *CookieSessionStore) SessionID() string { +func (st *CookieSessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease Write cookie session to http response cookie -func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { +func (st *CookieSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { st.lock.Lock() encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values) st.lock.Unlock() @@ -112,7 +113,7 @@ type CookieProvider struct { // securityName - recognized name in encoded cookie string // cookieName - cookie name // maxage - cookie max life time. -func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { +func (pder *CookieProvider) SessionInit(ctx context.Context, maxlifetime int64, config string) error { pder.config = &cookieConfig{} err := json.Unmarshal([]byte(config), pder.config) if err != nil { @@ -134,7 +135,7 @@ func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error // SessionRead Get SessionStore in cooke. // decode cooke string to map and put into SessionStore with sid. -func (pder *CookieProvider) SessionRead(sid string) (Store, error) { +func (pder *CookieProvider) SessionRead(ctx context.Context, sid string) (Store, error) { maps, _ := decodeCookie(pder.block, pder.config.SecurityKey, pder.config.SecurityName, @@ -147,26 +148,26 @@ func (pder *CookieProvider) SessionRead(sid string) (Store, error) { } // SessionExist Cookie session is always existed -func (pder *CookieProvider) SessionExist(sid string) (bool, error) { +func (pder *CookieProvider) SessionExist(ctx context.Context, sid string) (bool, error) { return true, nil } // SessionRegenerate Implement method, no used. -func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { +func (pder *CookieProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { return nil, nil } // SessionDestroy Implement method, no used. -func (pder *CookieProvider) SessionDestroy(sid string) error { +func (pder *CookieProvider) SessionDestroy(ctx context.Context, sid string) error { return nil } // SessionGC Implement method, no used. -func (pder *CookieProvider) SessionGC() { +func (pder *CookieProvider) SessionGC(context.Context) { } // SessionAll Implement method, return 0. -func (pder *CookieProvider) SessionAll() int { +func (pder *CookieProvider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/infrastructure/session/sess_cookie_test.go b/pkg/infrastructure/session/sess_cookie_test.go index b6726005..a9fc876d 100644 --- a/pkg/infrastructure/session/sess_cookie_test.go +++ b/pkg/infrastructure/session/sess_cookie_test.go @@ -38,14 +38,14 @@ func TestCookie(t *testing.T) { if err != nil { t.Fatal("set error,", err) } - err = sess.Set("username", "astaxie") + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set error,", err) } - if username := sess.Get("username"); username != "astaxie" { + if username := sess.Get(nil, "username"); username != "astaxie" { t.Fatal("get username error") } - sess.SessionRelease(w) + sess.SessionRelease(nil, w) if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { t.Fatal("setcookie error") } else { @@ -85,7 +85,7 @@ func TestDestorySessionCookie(t *testing.T) { if err != nil { t.Fatal("session start err,", err) } - if newSession.SessionID() != session.SessionID() { + if newSession.SessionID(nil) != session.SessionID(nil) { t.Fatal("get cookie session id is not the same again.") } @@ -99,7 +99,7 @@ func TestDestorySessionCookie(t *testing.T) { if err != nil { t.Fatal("session start error") } - if newSession.SessionID() == session.SessionID() { + if newSession.SessionID(nil) == session.SessionID(nil) { t.Fatal("after destroy session and reqeust again ,get cookie session id is same.") } } diff --git a/pkg/infrastructure/session/sess_file.go b/pkg/infrastructure/session/sess_file.go index 37d5bd68..90de9a79 100644 --- a/pkg/infrastructure/session/sess_file.go +++ b/pkg/infrastructure/session/sess_file.go @@ -15,6 +15,7 @@ package session import ( + "context" "errors" "fmt" "io/ioutil" @@ -40,7 +41,7 @@ type FileSessionStore struct { } // Set value to file session -func (fs *FileSessionStore) Set(key, value interface{}) error { +func (fs *FileSessionStore) Set(ctx context.Context, key, value interface{}) error { fs.lock.Lock() defer fs.lock.Unlock() fs.values[key] = value @@ -48,7 +49,7 @@ func (fs *FileSessionStore) Set(key, value interface{}) error { } // Get value from file session -func (fs *FileSessionStore) Get(key interface{}) interface{} { +func (fs *FileSessionStore) Get(ctx context.Context, key interface{}) interface{} { fs.lock.RLock() defer fs.lock.RUnlock() if v, ok := fs.values[key]; ok { @@ -58,7 +59,7 @@ func (fs *FileSessionStore) Get(key interface{}) interface{} { } // Delete value in file session by given key -func (fs *FileSessionStore) Delete(key interface{}) error { +func (fs *FileSessionStore) Delete(ctx context.Context, key interface{}) error { fs.lock.Lock() defer fs.lock.Unlock() delete(fs.values, key) @@ -66,7 +67,7 @@ func (fs *FileSessionStore) Delete(key interface{}) error { } // Flush Clean all values in file session -func (fs *FileSessionStore) Flush() error { +func (fs *FileSessionStore) Flush(context.Context) error { fs.lock.Lock() defer fs.lock.Unlock() fs.values = make(map[interface{}]interface{}) @@ -74,12 +75,12 @@ func (fs *FileSessionStore) Flush() error { } // SessionID Get file session store id -func (fs *FileSessionStore) SessionID() string { +func (fs *FileSessionStore) SessionID(context.Context) string { return fs.sid } // SessionRelease Write file session to local file with Gob string -func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { +func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { filepder.lock.Lock() defer filepder.lock.Unlock() b, err := EncodeGob(fs.values) @@ -119,7 +120,7 @@ type FileProvider struct { // SessionInit Init file session provider. // savePath sets the session files path. -func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { +func (fp *FileProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { fp.maxlifetime = maxlifetime fp.savePath = savePath return nil @@ -128,7 +129,7 @@ func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { // SessionRead Read file session by sid. // if file is not exist, create it. // the file path is generated from sid string. -func (fp *FileProvider) SessionRead(sid string) (Store, error) { +func (fp *FileProvider) SessionRead(ctx context.Context, sid string) (Store, error) { invalidChars := "./" if strings.ContainsAny(sid, invalidChars) { return nil, errors.New("the sid shouldn't have following characters: " + invalidChars) @@ -176,7 +177,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { // SessionExist Check file session exist. // it checks the file named from sid exist or not. -func (fp *FileProvider) SessionExist(sid string) (bool, error) { +func (fp *FileProvider) SessionExist(ctx context.Context, sid string) (bool, error) { filepder.lock.Lock() defer filepder.lock.Unlock() @@ -190,7 +191,7 @@ func (fp *FileProvider) SessionExist(sid string) (bool, error) { } // SessionDestroy Remove all files in this save path -func (fp *FileProvider) SessionDestroy(sid string) error { +func (fp *FileProvider) SessionDestroy(ctx context.Context, sid string) error { filepder.lock.Lock() defer filepder.lock.Unlock() os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) @@ -198,7 +199,7 @@ func (fp *FileProvider) SessionDestroy(sid string) error { } // SessionGC Recycle files in save path -func (fp *FileProvider) SessionGC() { +func (fp *FileProvider) SessionGC(context.Context) { filepder.lock.Lock() defer filepder.lock.Unlock() @@ -208,7 +209,7 @@ func (fp *FileProvider) SessionGC() { // SessionAll Get active file session number. // it walks save path to count files. -func (fp *FileProvider) SessionAll() int { +func (fp *FileProvider) SessionAll(context.Context) int { a := &activeSession{} err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error { return a.visit(path, f, err) @@ -222,7 +223,7 @@ func (fp *FileProvider) SessionAll() int { // SessionRegenerate Generate new sid for file session. // it delete old file and create new file named from new sid. -func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { +func (fp *FileProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { filepder.lock.Lock() defer filepder.lock.Unlock() diff --git a/pkg/infrastructure/session/sess_file_test.go b/pkg/infrastructure/session/sess_file_test.go index a27d30a6..f40de69f 100644 --- a/pkg/infrastructure/session/sess_file_test.go +++ b/pkg/infrastructure/session/sess_file_test.go @@ -15,6 +15,7 @@ package session import ( + "context" "fmt" "os" "sync" @@ -37,7 +38,7 @@ func TestFileProvider_SessionInit(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) if fp.maxlifetime != 180 { t.Error() } @@ -54,9 +55,9 @@ func TestFileProvider_SessionExist(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) - exists, err := fp.SessionExist(sid) + exists, err := fp.SessionExist(context.Background(), sid) if err != nil { t.Error(err) } @@ -64,12 +65,12 @@ func TestFileProvider_SessionExist(t *testing.T) { t.Error() } - _, err = fp.SessionRead(sid) + _, err = fp.SessionRead(context.Background(), sid) if err != nil { t.Error(err) } - exists, err = fp.SessionExist(sid) + exists, err = fp.SessionExist(context.Background(), sid) if err != nil { t.Error(err) } @@ -85,9 +86,9 @@ func TestFileProvider_SessionExist2(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) - exists, err := fp.SessionExist(sid) + exists, err := fp.SessionExist(context.Background(), sid) if err != nil { t.Error(err) } @@ -95,7 +96,7 @@ func TestFileProvider_SessionExist2(t *testing.T) { t.Error() } - exists, err = fp.SessionExist("") + exists, err = fp.SessionExist(context.Background(), "") if err == nil { t.Error() } @@ -103,7 +104,7 @@ func TestFileProvider_SessionExist2(t *testing.T) { t.Error() } - exists, err = fp.SessionExist("1") + exists, err = fp.SessionExist(context.Background(), "1") if err == nil { t.Error() } @@ -119,15 +120,15 @@ func TestFileProvider_SessionRead(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) - s, err := fp.SessionRead(sid) + s, err := fp.SessionRead(context.Background(), sid) if err != nil { t.Error(err) } - _ = s.Set("sessionValue", 18975) - v := s.Get("sessionValue") + _ = s.Set(nil, "sessionValue", 18975) + v := s.Get(nil, "sessionValue") if v.(int) != 18975 { t.Error() @@ -141,14 +142,14 @@ func TestFileProvider_SessionRead1(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) - _, err := fp.SessionRead("") + _, err := fp.SessionRead(context.Background(), "") if err == nil { t.Error(err) } - _, err = fp.SessionRead("1") + _, err = fp.SessionRead(context.Background(), "1") if err == nil { t.Error(err) } @@ -161,18 +162,18 @@ func TestFileProvider_SessionAll(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) sessionCount := 546 for i := 1; i <= sessionCount; i++ { - _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + _, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) if err != nil { t.Error(err) } } - if fp.SessionAll() != sessionCount { + if fp.SessionAll(nil) != sessionCount { t.Error() } } @@ -184,14 +185,14 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) - _, err := fp.SessionRead(sid) + _, err := fp.SessionRead(context.Background(), sid) if err != nil { t.Error(err) } - exists, err := fp.SessionExist(sid) + exists, err := fp.SessionExist(context.Background(), sid) if err != nil { t.Error(err) } @@ -199,12 +200,12 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { t.Error() } - _, err = fp.SessionRegenerate(sid, sidNew) + _, err = fp.SessionRegenerate(context.Background(), sid, sidNew) if err != nil { t.Error(err) } - exists, err = fp.SessionExist(sid) + exists, err = fp.SessionExist(context.Background(), sid) if err != nil { t.Error(err) } @@ -212,7 +213,7 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { t.Error() } - exists, err = fp.SessionExist(sidNew) + exists, err = fp.SessionExist(context.Background(), sidNew) if err != nil { t.Error(err) } @@ -228,14 +229,14 @@ func TestFileProvider_SessionDestroy(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) - _, err := fp.SessionRead(sid) + _, err := fp.SessionRead(context.Background(), sid) if err != nil { t.Error(err) } - exists, err := fp.SessionExist(sid) + exists, err := fp.SessionExist(context.Background(), sid) if err != nil { t.Error(err) } @@ -243,12 +244,12 @@ func TestFileProvider_SessionDestroy(t *testing.T) { t.Error() } - err = fp.SessionDestroy(sid) + err = fp.SessionDestroy(context.Background(), sid) if err != nil { t.Error(err) } - exists, err = fp.SessionExist(sid) + exists, err = fp.SessionExist(context.Background(), sid) if err != nil { t.Error(err) } @@ -264,12 +265,12 @@ func TestFileProvider_SessionGC(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(1, sessionPath) + _ = fp.SessionInit(context.Background(), 1, sessionPath) sessionCount := 412 for i := 1; i <= sessionCount; i++ { - _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + _, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) if err != nil { t.Error(err) } @@ -277,8 +278,8 @@ func TestFileProvider_SessionGC(t *testing.T) { time.Sleep(2 * time.Second) - fp.SessionGC() - if fp.SessionAll() != 0 { + fp.SessionGC(nil) + if fp.SessionAll(nil) != 0 { t.Error() } } @@ -290,12 +291,12 @@ func TestFileSessionStore_Set(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) sessionCount := 100 - s, _ := fp.SessionRead(sid) + s, _ := fp.SessionRead(context.Background(), sid) for i := 1; i <= sessionCount; i++ { - err := s.Set(i, i) + err := s.Set(nil, i, i) if err != nil { t.Error(err) } @@ -309,14 +310,14 @@ func TestFileSessionStore_Get(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) sessionCount := 100 - s, _ := fp.SessionRead(sid) + s, _ := fp.SessionRead(context.Background(), sid) for i := 1; i <= sessionCount; i++ { - _ = s.Set(i, i) + _ = s.Set(nil, i, i) - v := s.Get(i) + v := s.Get(nil, i) if v.(int) != i { t.Error() } @@ -330,18 +331,18 @@ func TestFileSessionStore_Delete(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) - s, _ := fp.SessionRead(sid) - s.Set("1", 1) + s, _ := fp.SessionRead(context.Background(), sid) + s.Set(nil, "1", 1) - if s.Get("1") == nil { + if s.Get(nil, "1") == nil { t.Error() } - s.Delete("1") + s.Delete(nil, "1") - if s.Get("1") != nil { + if s.Get(nil, "1") != nil { t.Error() } } @@ -353,18 +354,18 @@ func TestFileSessionStore_Flush(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) sessionCount := 100 - s, _ := fp.SessionRead(sid) + s, _ := fp.SessionRead(context.Background(), sid) for i := 1; i <= sessionCount; i++ { - _ = s.Set(i, i) + _ = s.Set(nil, i, i) } - _ = s.Flush() + _ = s.Flush(nil) for i := 1; i <= sessionCount; i++ { - if s.Get(i) != nil { + if s.Get(nil, i) != nil { t.Error() } } @@ -377,16 +378,16 @@ func TestFileSessionStore_SessionID(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) sessionCount := 85 for i := 1; i <= sessionCount; i++ { - s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) if err != nil { t.Error(err) } - if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) { + if s.SessionID(nil) != fmt.Sprintf("%s_%d", sid, i) { t.Error(err) } } @@ -399,27 +400,27 @@ func TestFileSessionStore_SessionRelease(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) filepder.savePath = sessionPath sessionCount := 85 for i := 1; i <= sessionCount; i++ { - s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) if err != nil { t.Error(err) } - s.Set(i, i) - s.SessionRelease(nil) + s.Set(nil, i, i) + s.SessionRelease(nil, nil) } for i := 1; i <= sessionCount; i++ { - s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) if err != nil { t.Error(err) } - if s.Get(i).(int) != i { + if s.Get(nil, i).(int) != i { t.Error() } } diff --git a/pkg/infrastructure/session/sess_mem.go b/pkg/infrastructure/session/sess_mem.go index bd69ff80..9a27c331 100644 --- a/pkg/infrastructure/session/sess_mem.go +++ b/pkg/infrastructure/session/sess_mem.go @@ -16,6 +16,7 @@ package session import ( "container/list" + "context" "net/http" "sync" "time" @@ -33,7 +34,7 @@ type MemSessionStore struct { } // Set value to memory session -func (st *MemSessionStore) Set(key, value interface{}) error { +func (st *MemSessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.value[key] = value @@ -41,7 +42,7 @@ func (st *MemSessionStore) Set(key, value interface{}) error { } // Get value from memory session by key -func (st *MemSessionStore) Get(key interface{}) interface{} { +func (st *MemSessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.value[key]; ok { @@ -51,7 +52,7 @@ func (st *MemSessionStore) Get(key interface{}) interface{} { } // Delete in memory session by key -func (st *MemSessionStore) Delete(key interface{}) error { +func (st *MemSessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.value, key) @@ -59,7 +60,7 @@ func (st *MemSessionStore) Delete(key interface{}) error { } // Flush clear all values in memory session -func (st *MemSessionStore) Flush() error { +func (st *MemSessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.value = make(map[interface{}]interface{}) @@ -67,12 +68,12 @@ func (st *MemSessionStore) Flush() error { } // SessionID get this id of memory session store -func (st *MemSessionStore) SessionID() string { +func (st *MemSessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease Implement method, no used. -func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { +func (st *MemSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { } // MemProvider Implement the provider interface @@ -85,14 +86,14 @@ type MemProvider struct { } // SessionInit init memory session -func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { +func (pder *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { pder.maxlifetime = maxlifetime pder.savePath = savePath return nil } // SessionRead get memory session store by sid -func (pder *MemProvider) SessionRead(sid string) (Store, error) { +func (pder *MemProvider) SessionRead(ctx context.Context, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[sid]; ok { go pder.SessionUpdate(sid) @@ -109,7 +110,7 @@ func (pder *MemProvider) SessionRead(sid string) (Store, error) { } // SessionExist check session store exist in memory session by sid -func (pder *MemProvider) SessionExist(sid string) (bool, error) { +func (pder *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) { pder.lock.RLock() defer pder.lock.RUnlock() if _, ok := pder.sessions[sid]; ok { @@ -119,7 +120,7 @@ func (pder *MemProvider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for session store in memory session -func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { +func (pder *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[oldsid]; ok { go pder.SessionUpdate(oldsid) @@ -141,7 +142,7 @@ func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { } // SessionDestroy delete session store in memory session by id -func (pder *MemProvider) SessionDestroy(sid string) error { +func (pder *MemProvider) SessionDestroy(ctx context.Context, sid string) error { pder.lock.Lock() defer pder.lock.Unlock() if element, ok := pder.sessions[sid]; ok { @@ -153,7 +154,7 @@ func (pder *MemProvider) SessionDestroy(sid string) error { } // SessionGC clean expired session stores in memory session -func (pder *MemProvider) SessionGC() { +func (pder *MemProvider) SessionGC(context.Context) { pder.lock.RLock() for { element := pder.list.Back() @@ -175,7 +176,7 @@ func (pder *MemProvider) SessionGC() { } // SessionAll get count number of memory session -func (pder *MemProvider) SessionAll() int { +func (pder *MemProvider) SessionAll(context.Context) int { return pder.list.Len() } diff --git a/pkg/infrastructure/session/sess_mem_test.go b/pkg/infrastructure/session/sess_mem_test.go index 2e8934b8..e6d35476 100644 --- a/pkg/infrastructure/session/sess_mem_test.go +++ b/pkg/infrastructure/session/sess_mem_test.go @@ -36,12 +36,12 @@ func TestMem(t *testing.T) { if err != nil { t.Fatal("set error,", err) } - defer sess.SessionRelease(w) - err = sess.Set("username", "astaxie") + defer sess.SessionRelease(nil, w) + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set error,", err) } - if username := sess.Get("username"); username != "astaxie" { + if username := sess.Get(nil, "username"); username != "astaxie" { t.Fatal("get username error") } if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { diff --git a/pkg/infrastructure/session/session.go b/pkg/infrastructure/session/session.go index 92e35de4..bb7e5bd6 100644 --- a/pkg/infrastructure/session/session.go +++ b/pkg/infrastructure/session/session.go @@ -28,6 +28,7 @@ package session import ( + "context" "crypto/rand" "encoding/hex" "errors" @@ -43,24 +44,24 @@ import ( // Store contains all data for one session process with specific id. type Store interface { - Set(key, value interface{}) error //set session value - Get(key interface{}) interface{} //get session value - Delete(key interface{}) error //delete session value - SessionID() string //back current sessionID - SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data - Flush() error //delete all data + Set(ctx context.Context, key, value interface{}) error //set session value + Get(ctx context.Context, key interface{}) interface{} //get session value + Delete(ctx context.Context, key interface{}) error //delete session value + SessionID(ctx context.Context) string //back current sessionID + SessionRelease(ctx context.Context, w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush(ctx context.Context) error //delete all data } // Provider contains global session methods and saved SessionStores. // it can operate a SessionStore by its id. type Provider interface { - SessionInit(gclifetime int64, config string) error - SessionRead(sid string) (Store, error) - SessionExist(sid string) (bool, error) - SessionRegenerate(oldsid, sid string) (Store, error) - SessionDestroy(sid string) error - SessionAll() int //get all active session - SessionGC() + SessionInit(ctx context.Context, gclifetime int64, config string) error + SessionRead(ctx context.Context, sid string) (Store, error) + SessionExist(ctx context.Context, sid string) (bool, error) + SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) + SessionDestroy(ctx context.Context, sid string) error + SessionAll(ctx context.Context) int //get all active session + SessionGC(ctx context.Context) } var provides = make(map[string]Provider) @@ -148,7 +149,7 @@ func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { } } - err := provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) + err := provider.SessionInit(nil, cf.Maxlifetime, cf.ProviderConfig) if err != nil { return nil, err } @@ -212,12 +213,12 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se } if sid != "" { - exists, err := manager.provider.SessionExist(sid) + exists, err := manager.provider.SessionExist(nil, sid) if err != nil { return nil, err } if exists { - return manager.provider.SessionRead(sid) + return manager.provider.SessionRead(nil, sid) } } @@ -227,7 +228,7 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se return nil, errs } - session, err = manager.provider.SessionRead(sid) + session, err = manager.provider.SessionRead(nil, sid) if err != nil { return nil, err } @@ -269,7 +270,7 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { } sid, _ := url.QueryUnescape(cookie.Value) - manager.provider.SessionDestroy(sid) + manager.provider.SessionDestroy(nil, sid) if manager.config.EnableSetCookie { expiration := time.Now() cookie = &http.Cookie{Name: manager.config.CookieName, @@ -285,14 +286,14 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { // GetSessionStore Get SessionStore by its id. func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) { - sessions, err = manager.provider.SessionRead(sid) + sessions, err = manager.provider.SessionRead(nil, sid) return } // GC Start session gc process. // it can do gc in times after gc lifetime. func (manager *Manager) GC() { - manager.provider.SessionGC() + manager.provider.SessionGC(nil) time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) } @@ -305,7 +306,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque cookie, err := r.Cookie(manager.config.CookieName) if err != nil || cookie.Value == "" { //delete old cookie - session, _ = manager.provider.SessionRead(sid) + session, _ = manager.provider.SessionRead(nil, sid) cookie = &http.Cookie{Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", @@ -315,7 +316,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque } } else { oldsid, _ := url.QueryUnescape(cookie.Value) - session, _ = manager.provider.SessionRegenerate(oldsid, sid) + session, _ = manager.provider.SessionRegenerate(nil, oldsid, sid) cookie.Value = url.QueryEscape(sid) cookie.HttpOnly = true cookie.Path = "/" @@ -339,7 +340,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque // GetActiveSession Get all active sessions count number. func (manager *Manager) GetActiveSession() int { - return manager.provider.SessionAll() + return manager.provider.SessionAll(nil) } // SetSecure Set cookie with https. diff --git a/pkg/infrastructure/session/ssdb/sess_ssdb.go b/pkg/infrastructure/session/ssdb/sess_ssdb.go index 77d0c5c2..6e4f341e 100644 --- a/pkg/infrastructure/session/ssdb/sess_ssdb.go +++ b/pkg/infrastructure/session/ssdb/sess_ssdb.go @@ -1,6 +1,7 @@ package ssdb import ( + "context" "errors" "net/http" "strconv" @@ -31,7 +32,7 @@ func (p *Provider) connectInit() error { } // SessionInit init the ssdb with the config -func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { +func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, savePath string) error { p.maxLifetime = maxLifetime address := strings.Split(savePath, ":") p.host = address[0] @@ -44,7 +45,7 @@ func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { } // SessionRead return a ssdb client session Store -func (p *Provider) SessionRead(sid string) (session.Store, error) { +func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { if p.client == nil { if err := p.connectInit(); err != nil { return nil, err @@ -68,7 +69,7 @@ func (p *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist judged whether sid is exist in session -func (p *Provider) SessionExist(sid string) (bool, error) { +func (p *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { if p.client == nil { if err := p.connectInit(); err != nil { return false, err @@ -85,7 +86,7 @@ func (p *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate regenerate session with new sid and delete oldsid -func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (p *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { //conn.Do("setx", key, v, ttl) if p.client == nil { if err := p.connectInit(); err != nil { @@ -118,7 +119,7 @@ func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy destroy the sid -func (p *Provider) SessionDestroy(sid string) error { +func (p *Provider) SessionDestroy(ctx context.Context, sid string) error { if p.client == nil { if err := p.connectInit(); err != nil { return err @@ -129,11 +130,11 @@ func (p *Provider) SessionDestroy(sid string) error { } // SessionGC not implemented -func (p *Provider) SessionGC() { +func (p *Provider) SessionGC(context.Context) { } // SessionAll not implemented -func (p *Provider) SessionAll() int { +func (p *Provider) SessionAll(context.Context) int { return 0 } @@ -147,7 +148,7 @@ type SessionStore struct { } // Set the key and value -func (s *SessionStore) Set(key, value interface{}) error { +func (s *SessionStore) Set(ctx context.Context, key, value interface{}) error { s.lock.Lock() defer s.lock.Unlock() s.values[key] = value @@ -155,7 +156,7 @@ func (s *SessionStore) Set(key, value interface{}) error { } // Get return the value by the key -func (s *SessionStore) Get(key interface{}) interface{} { +func (s *SessionStore) Get(ctx context.Context, key interface{}) interface{} { s.lock.Lock() defer s.lock.Unlock() if value, ok := s.values[key]; ok { @@ -165,7 +166,7 @@ func (s *SessionStore) Get(key interface{}) interface{} { } // Delete the key in session store -func (s *SessionStore) Delete(key interface{}) error { +func (s *SessionStore) Delete(ctx context.Context, key interface{}) error { s.lock.Lock() defer s.lock.Unlock() delete(s.values, key) @@ -173,7 +174,7 @@ func (s *SessionStore) Delete(key interface{}) error { } // Flush delete all keys and values -func (s *SessionStore) Flush() error { +func (s *SessionStore) Flush(context.Context) error { s.lock.Lock() defer s.lock.Unlock() s.values = make(map[interface{}]interface{}) @@ -181,12 +182,12 @@ func (s *SessionStore) Flush() error { } // SessionID return the sessionID -func (s *SessionStore) SessionID() string { +func (s *SessionStore) SessionID(context.Context) string { return s.sid } // SessionRelease Store the keyvalues into ssdb -func (s *SessionStore) SessionRelease(w http.ResponseWriter) { +func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(s.values) if err != nil { return diff --git a/pkg/server/web/context/input.go b/pkg/server/web/context/input.go index b8272f64..a6fec774 100644 --- a/pkg/server/web/context/input.go +++ b/pkg/server/web/context/input.go @@ -361,7 +361,7 @@ func (input *BeegoInput) Cookie(key string) string { // Session returns current session item value by a given key. // if non-existed, return nil. func (input *BeegoInput) Session(key interface{}) interface{} { - return input.CruSession.Get(key) + return input.CruSession.Get(nil, key) } // CopyBody returns the raw request body data as bytes. diff --git a/pkg/server/web/context/output.go b/pkg/server/web/context/output.go index 0a530244..a6e83681 100644 --- a/pkg/server/web/context/output.go +++ b/pkg/server/web/context/output.go @@ -404,5 +404,5 @@ func stringsToJSON(str string) string { // Session sets session item value with given key. func (output *BeegoOutput) Session(name interface{}, value interface{}) { - output.Context.Input.CruSession.Set(name, value) + output.Context.Input.CruSession.Set(nil, name, value) } diff --git a/pkg/server/web/controller.go b/pkg/server/web/controller.go index 6b71d617..2081e647 100644 --- a/pkg/server/web/controller.go +++ b/pkg/server/web/controller.go @@ -622,7 +622,7 @@ func (c *Controller) SetSession(name interface{}, value interface{}) { if c.CruSession == nil { c.StartSession() } - c.CruSession.Set(name, value) + c.CruSession.Set(nil, name, value) } // GetSession gets value from session. @@ -630,7 +630,7 @@ func (c *Controller) GetSession(name interface{}) interface{} { if c.CruSession == nil { c.StartSession() } - return c.CruSession.Get(name) + return c.CruSession.Get(nil, name) } // DelSession removes value from session. @@ -638,14 +638,14 @@ func (c *Controller) DelSession(name interface{}) { if c.CruSession == nil { c.StartSession() } - c.CruSession.Delete(name) + c.CruSession.Delete(nil, name) } // SessionRegenerateID regenerates session id for this session. // the session data have no changes. func (c *Controller) SessionRegenerateID() { if c.CruSession != nil { - c.CruSession.SessionRelease(c.Ctx.ResponseWriter) + c.CruSession.SessionRelease(nil, c.Ctx.ResponseWriter) } c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request) c.Ctx.Input.CruSession = c.CruSession @@ -653,7 +653,7 @@ func (c *Controller) SessionRegenerateID() { // DestroySession cleans session data and session cookie. func (c *Controller) DestroySession() { - c.Ctx.Input.CruSession.Flush() + c.Ctx.Input.CruSession.Flush(nil) c.Ctx.Input.CruSession = nil GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request) } diff --git a/pkg/server/web/router.go b/pkg/server/web/router.go index 9b70753e..c3eddd29 100644 --- a/pkg/server/web/router.go +++ b/pkg/server/web/router.go @@ -721,7 +721,7 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { } defer func() { if ctx.Input.CruSession != nil { - ctx.Input.CruSession.SessionRelease(rw) + ctx.Input.CruSession.SessionRelease(nil, rw) } }() } From c0462f75bf5ad34f868b18e1320254da7f3cad04 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 30 Aug 2020 16:18:59 +0000 Subject: [PATCH 130/207] Add ctx to Task module API --- pkg/server/web/admin.go | 11 +++++---- pkg/task/task.go | 53 +++++++++++++++++++++-------------------- pkg/task/task_test.go | 4 ++-- 3 files changed, 35 insertions(+), 33 deletions(-) diff --git a/pkg/server/web/admin.go b/pkg/server/web/admin.go index aace3d9e..f54ac9e5 100644 --- a/pkg/server/web/admin.go +++ b/pkg/server/web/admin.go @@ -16,6 +16,7 @@ package web import ( "bytes" + context2 "context" "encoding/json" "fmt" "net/http" @@ -378,10 +379,10 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { taskname := req.Form.Get("taskname") if taskname != "" { if t, ok := task.AdminTaskList[taskname]; ok { - if err := t.Run(); err != nil { + if err := t.Run(nil); err != nil { data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", err))} } - data["Message"] = []string{"success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus()))} + data["Message"] = []string{"success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus(nil)))} } else { data["Message"] = []string{"warning", template.HTMLEscapeString(fmt.Sprintf("there's no task which named: %s", taskname))} } @@ -400,9 +401,9 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { for tname, tk := range task.AdminTaskList { result := []string{ template.HTMLEscapeString(tname), - template.HTMLEscapeString(tk.GetSpec()), - template.HTMLEscapeString(tk.GetStatus()), - template.HTMLEscapeString(tk.GetPrev().String()), + template.HTMLEscapeString(tk.GetSpec(nil)), + template.HTMLEscapeString(tk.GetStatus(nil)), + template.HTMLEscapeString(tk.GetPrev(context2.Background()).String()), } *resultList = append(*resultList, result) } diff --git a/pkg/task/task.go b/pkg/task/task.go index 04185d8e..e2962000 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -15,6 +15,7 @@ package task import ( + "context" "log" "math" "sort" @@ -86,13 +87,13 @@ type TaskFunc func() error // Tasker task interface type Tasker interface { - GetSpec() string - GetStatus() string - Run() error - SetNext(time.Time) - GetNext() time.Time - SetPrev(time.Time) - GetPrev() time.Time + GetSpec(ctx context.Context) string + GetStatus(ctx context.Context) string + Run(ctx context.Context) error + SetNext(context.Context, time.Time) + GetNext(ctx context.Context) time.Time + SetPrev(context.Context, time.Time) + GetPrev(ctx context.Context) time.Time } // task error @@ -133,12 +134,12 @@ func NewTask(tname string, spec string, f TaskFunc) *Task { } // GetSpec get spec string -func (t *Task) GetSpec() string { +func (t *Task) GetSpec(context.Context) string { return t.SpecStr } // GetStatus get current task status -func (t *Task) GetStatus() string { +func (t *Task) GetStatus(context.Context) string { var str string for _, v := range t.Errlist { str += v.t.String() + ":" + v.errinfo + "
" @@ -147,7 +148,7 @@ func (t *Task) GetStatus() string { } // Run run all tasks -func (t *Task) Run() error { +func (t *Task) Run(context.Context) error { err := t.DoFunc() if err != nil { index := t.errCnt % t.ErrLimit @@ -158,22 +159,22 @@ func (t *Task) Run() error { } // SetNext set next time for this task -func (t *Task) SetNext(now time.Time) { +func (t *Task) SetNext(ctx context.Context, now time.Time) { t.Next = t.Spec.Next(now) } // GetNext get the next call time of this task -func (t *Task) GetNext() time.Time { +func (t *Task) GetNext(context.Context) time.Time { return t.Next } // SetPrev set prev time of this task -func (t *Task) SetPrev(now time.Time) { +func (t *Task) SetPrev(ctx context.Context, now time.Time) { t.Prev = now } // GetPrev get prev time of this task -func (t *Task) GetPrev() time.Time { +func (t *Task) GetPrev(context.Context) time.Time { return t.Prev } @@ -410,7 +411,7 @@ func StartTask() { func run() { now := time.Now().Local() for _, t := range AdminTaskList { - t.SetNext(now) + t.SetNext(nil, now) } for { @@ -420,30 +421,30 @@ func run() { taskLock.RUnlock() sortList.Sort() var effective time.Time - if len(AdminTaskList) == 0 || sortList.Vals[0].GetNext().IsZero() { + if len(AdminTaskList) == 0 || sortList.Vals[0].GetNext(context.Background()).IsZero() { // If there are no entries yet, just sleep - it still handles new entries // and stop requests. effective = now.AddDate(10, 0, 0) } else { - effective = sortList.Vals[0].GetNext() + effective = sortList.Vals[0].GetNext(context.Background()) } select { case now = <-time.After(effective.Sub(now)): // Run every entry whose next time was this effective time. for _, e := range sortList.Vals { - if e.GetNext() != effective { + if e.GetNext(context.Background()) != effective { break } - go e.Run() - e.SetPrev(e.GetNext()) - e.SetNext(effective) + go e.Run(nil) + e.SetPrev(context.Background(), e.GetNext(context.Background())) + e.SetNext(nil, effective) } continue case <-changed: now = time.Now().Local() taskLock.Lock() for _, t := range AdminTaskList { - t.SetNext(now) + t.SetNext(nil, now) } taskLock.Unlock() continue @@ -468,7 +469,7 @@ func StopTask() { func AddTask(taskname string, t Tasker) { taskLock.Lock() defer taskLock.Unlock() - t.SetNext(time.Now().Local()) + t.SetNext(nil, time.Now().Local()) AdminTaskList[taskname] = t if isstart { changed <- true @@ -511,13 +512,13 @@ func (ms *MapSorter) Sort() { func (ms *MapSorter) Len() int { return len(ms.Keys) } func (ms *MapSorter) Less(i, j int) bool { - if ms.Vals[i].GetNext().IsZero() { + if ms.Vals[i].GetNext(context.Background()).IsZero() { return false } - if ms.Vals[j].GetNext().IsZero() { + if ms.Vals[j].GetNext(context.Background()).IsZero() { return true } - return ms.Vals[i].GetNext().Before(ms.Vals[j].GetNext()) + return ms.Vals[i].GetNext(context.Background()).Before(ms.Vals[j].GetNext(context.Background())) } func (ms *MapSorter) Swap(i, j int) { ms.Vals[i], ms.Vals[j] = ms.Vals[j], ms.Vals[i] diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go index c7360b39..9f73ce46 100644 --- a/pkg/task/task_test.go +++ b/pkg/task/task_test.go @@ -26,7 +26,7 @@ import ( func TestParse(t *testing.T) { tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) - err := tk.Run() + err := tk.Run(nil) if err != nil { t.Fatal(err) } @@ -65,7 +65,7 @@ func TestTask_Run(t *testing.T) { } tk := NewTask("taska", "0/30 * * * * *", task) for i := 0; i < 200; i++ { - e := tk.Run() + e := tk.Run(nil) assert.NotNil(t, e) } From f4f200cf0407bb2ce4ec742263450c8f4849fe24 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 31 Aug 2020 13:02:22 +0000 Subject: [PATCH 131/207] enhance yaml --- pkg/infrastructure/config/ini.go | 4 ++ pkg/infrastructure/config/yaml/yaml.go | 60 ++++++++++++++++++++- pkg/infrastructure/config/yaml/yaml_test.go | 35 ++++++++++++ 3 files changed, 97 insertions(+), 2 deletions(-) diff --git a/pkg/infrastructure/config/ini.go b/pkg/infrastructure/config/ini.go index 92ed8df8..cc67e4cd 100644 --- a/pkg/infrastructure/config/ini.go +++ b/pkg/infrastructure/config/ini.go @@ -66,6 +66,10 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e keyComment: make(map[string]string), RWMutex: sync.RWMutex{}, } + + cfg.BaseConfiger = NewBaseConfiger(func(ctx context.Context, key string) (string, error) { + return cfg.getdata(key), nil + }) cfg.Lock() defer cfg.Unlock() diff --git a/pkg/infrastructure/config/yaml/yaml.go b/pkg/infrastructure/config/yaml/yaml.go index 1f4f1d23..61ea45b9 100644 --- a/pkg/infrastructure/config/yaml/yaml.go +++ b/pkg/infrastructure/config/yaml/yaml.go @@ -42,8 +42,10 @@ import ( "sync" "github.com/beego/goyaml2" + "gopkg.in/yaml.v2" "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/infrastructure/logs" ) // Config is a yaml config parser and implements Config interface. @@ -120,11 +122,61 @@ func parseYML(buf []byte) (cnf map[string]interface{}, err error) { // ConfigContainer is a config which represents the yaml configuration. type ConfigContainer struct { - config.BaseConfiger data map[string]interface{} sync.RWMutex } +// Unmarshaler is similar to Sub +func (c *ConfigContainer) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { + sub, err := c.sub(ctx, prefix) + if err != nil { + return err + } + + bytes, err := yaml.Marshal(sub) + if err != nil { + return err + } + return yaml.Unmarshal(bytes, obj) +} + +func (c *ConfigContainer) Sub(ctx context.Context, key string) (config.Configer, error) { + sub, err := c.sub(ctx, key) + if err != nil { + return nil, err + } + return &ConfigContainer{ + data: sub, + }, nil +} + +func (c *ConfigContainer) sub(ctx context.Context, key string) (map[string]interface{}, error) { + tmpData := c.data + keys := strings.Split(key, ".") + for idx, k := range keys { + if v, ok := tmpData[k]; ok { + switch v.(type) { + case map[string]interface{}: + { + tmpData = v.(map[string]interface{}) + if idx == len(keys)-1 { + return tmpData, nil + } + } + default: + return nil, errors.New(fmt.Sprintf("the key is invalid: %s", key)) + } + } + } + + return tmpData, nil +} + +func (c *ConfigContainer) OnChange(ctx context.Context, key string, fn func(value string)) { + // do nothing + logs.Warn("Unsupported operation: OnChange") +} + // Bool returns the boolean value for a given key. func (c *ConfigContainer) Bool(ctx context.Context, key string) (bool, error) { v, err := c.getData(key) @@ -291,7 +343,7 @@ func (c *ConfigContainer) getData(key string) (interface{}, error) { c.RLock() defer c.RUnlock() - keys := strings.Split(key, ".") + keys := strings.Split(c.key(key), ".") tmpData := c.data for idx, k := range keys { if v, ok := tmpData[k]; ok { @@ -314,6 +366,10 @@ func (c *ConfigContainer) getData(key string) (interface{}, error) { return nil, fmt.Errorf("not exist key %q", key) } +func (c *ConfigContainer) key(key string) string { + return key +} + func init() { config.Register("yaml", &Config{}) } diff --git a/pkg/infrastructure/config/yaml/yaml_test.go b/pkg/infrastructure/config/yaml/yaml_test.go index 197b68e4..1fd4e894 100644 --- a/pkg/infrastructure/config/yaml/yaml_test.go +++ b/pkg/infrastructure/config/yaml/yaml_test.go @@ -15,10 +15,13 @@ package yaml import ( + "context" "fmt" "os" "testing" + "github.com/stretchr/testify/assert" + "github.com/astaxie/beego/pkg/infrastructure/config" ) @@ -37,6 +40,9 @@ func TestYaml(t *testing.T) { "path1": ${GOPATH} "path2": ${GOPATH||/home/go} "empty": "" +"user": + "name": "tom" + "age": 13 ` keyValue = map[string]interface{}{ @@ -114,4 +120,33 @@ func TestYaml(t *testing.T) { t.Fatal("get name error") } + sub, err := yamlconf.Sub(context.Background(), "user") + assert.Nil(t, err) + assert.NotNil(t, sub) + name, err := sub.String(context.Background(), "name") + assert.Nil(t, err) + assert.Equal(t, "tom", name) + + age, err := sub.Int(context.Background(), "age") + assert.Nil(t, err) + assert.Equal(t, 13, age) + + user := &User{} + + err = sub.Unmarshaler(context.Background(), "", user) + assert.Nil(t, err) + assert.Equal(t, "tom", user.Name) + assert.Equal(t, 13, user.Age) + + user = &User{} + + err = yamlconf.Unmarshaler(context.Background(), "user", user) + assert.Nil(t, err) + assert.Equal(t, "tom", user.Name) + assert.Equal(t, 13, user.Age) +} + +type User struct { + Name string `yaml:"name"` + Age int `yaml:"age"` } From 087399c44a967029781a3a8ce682837c4369b95e Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 31 Aug 2020 13:57:26 +0000 Subject: [PATCH 132/207] support xml --- pkg/infrastructure/config/xml/xml.go | 52 +++++++++++++++++++++-- pkg/infrastructure/config/xml/xml_test.go | 33 +++++++++++++- 2 files changed, 81 insertions(+), 4 deletions(-) diff --git a/pkg/infrastructure/config/xml/xml.go b/pkg/infrastructure/config/xml/xml.go index e3e93b01..e5096b9b 100644 --- a/pkg/infrastructure/config/xml/xml.go +++ b/pkg/infrastructure/config/xml/xml.go @@ -26,7 +26,7 @@ // // cnf, err := config.NewConfig("xml", "config.xml") // -//More docs http://beego.me/docs/module/config.md +// More docs http://beego.me/docs/module/config.md package xml import ( @@ -40,7 +40,11 @@ import ( "strings" "sync" + "github.com/mitchellh/mapstructure" + "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/beego/x2j" ) @@ -75,11 +79,53 @@ func (xc *Config) ParseData(data []byte) (config.Configer, error) { // ConfigContainer is a Config which represents the xml configuration. type ConfigContainer struct { - config.BaseConfiger data map[string]interface{} sync.Mutex } +// Unmarshaler is a little be inconvenient since the xml library doesn't know type. +// So when you use +// 1 +// The "1" is a string, not int +func (c *ConfigContainer) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { + sub, err := c.sub(ctx, prefix) + if err != nil { + return err + } + return mapstructure.Decode(sub, obj) +} + +func (c *ConfigContainer) Sub(ctx context.Context, key string) (config.Configer, error) { + sub, err := c.sub(ctx, key) + if err != nil { + return nil, err + } + + return &ConfigContainer{ + data: sub, + }, nil + +} + +func (c *ConfigContainer) sub(ctx context.Context, key string) (map[string]interface{}, error) { + if key == "" { + return c.data, nil + } + value, ok := c.data[key] + if !ok { + return nil, errors.New(fmt.Sprintf("the key is not found: %s", key)) + } + res, ok := value.(map[string]interface{}) + if !ok { + return nil, errors.New(fmt.Sprintf("the value of this key is not a structure: %s", key)) + } + return res, nil +} + +func (c *ConfigContainer) OnChange(ctx context.Context, key string, fn func(value string)) { + logs.Warn("Unsupported operation") +} + // Bool returns the boolean value for a given key. func (c *ConfigContainer) Bool(ctx context.Context, key string) (bool, error) { if v := c.data[key]; v != nil { @@ -155,7 +201,7 @@ func (c *ConfigContainer) String(ctx context.Context, key string) (string, error // DefaultString returns the string value for a given key. // if err != nil return defaultVal func (c *ConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { - v, err := c.String(nil, key) + v, err := c.String(ctx, key) if v == "" || err != nil { return defaultVal } diff --git a/pkg/infrastructure/config/xml/xml_test.go b/pkg/infrastructure/config/xml/xml_test.go index 470280e0..0a3eb313 100644 --- a/pkg/infrastructure/config/xml/xml_test.go +++ b/pkg/infrastructure/config/xml/xml_test.go @@ -15,10 +15,13 @@ package xml import ( + "context" "fmt" "os" "testing" + "github.com/stretchr/testify/assert" + "github.com/astaxie/beego/pkg/infrastructure/config" ) @@ -120,8 +123,36 @@ func TestXML(t *testing.T) { t.Fatal(err) } - res, _ := xmlconf.String(nil, "name") + res, _ := xmlconf.String(context.Background(), "name") if res != "astaxie" { t.Fatal("get name error") } + + sub, err := xmlconf.Sub(context.Background(), "mysection") + assert.Nil(t, err) + assert.NotNil(t, sub) + name, err := sub.String(context.Background(), "name") + assert.Nil(t, err) + assert.Equal(t, "MySection", name) + + id, err := sub.Int(context.Background(), "id") + assert.Nil(t, err) + assert.Equal(t, 1, id) + + sec := &Section{} + + err = sub.Unmarshaler(context.Background(), "", sec) + assert.Nil(t, err) + assert.Equal(t, "MySection", sec.Name) + + sec = &Section{} + + err = xmlconf.Unmarshaler(context.Background(), "mysection", sec) + assert.Nil(t, err) + assert.Equal(t, "MySection", sec.Name) + +} + +type Section struct { + Name string `xml:"name"` } From 33b052bc7a8f46302101fe1074ea66a956df7224 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 31 Aug 2020 14:14:31 +0000 Subject: [PATCH 133/207] support json --- pkg/infrastructure/config/json/json.go | 42 ++++++++++++++++++++- pkg/infrastructure/config/json/json_test.go | 26 +++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/pkg/infrastructure/config/json/json.go b/pkg/infrastructure/config/json/json.go index dae55118..c65eff4d 100644 --- a/pkg/infrastructure/config/json/json.go +++ b/pkg/infrastructure/config/json/json.go @@ -25,7 +25,10 @@ import ( "strings" "sync" + "github.com/mitchellh/mapstructure" + "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/infrastructure/logs" ) // JSONConfig is a json config parser and implements Config interface. @@ -70,11 +73,48 @@ func (js *JSONConfig) ParseData(data []byte) (config.Configer, error) { // JSONConfigContainer is a config which represents the json configuration. // Only when get value, support key as section:name type. type JSONConfigContainer struct { - config.BaseConfiger data map[string]interface{} sync.RWMutex } +func (c *JSONConfigContainer) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { + sub, err := c.sub(ctx, prefix) + if err != nil { + return err + } + return mapstructure.Decode(sub, obj) +} + +func (c *JSONConfigContainer) Sub(ctx context.Context, key string) (config.Configer, error) { + sub, err := c.sub(ctx, key) + if err != nil { + return nil, err + } + return &JSONConfigContainer{ + data: sub, + }, nil +} + +func (c *JSONConfigContainer) sub(ctx context.Context, key string) (map[string]interface{}, error) { + if key == "" { + return c.data, nil + } + value, ok := c.data[key] + if !ok { + return nil, errors.New(fmt.Sprintf("key is not found: %s", key)) + } + + res, ok := value.(map[string]interface{}) + if !ok { + return nil, errors.New(fmt.Sprintf("the type of value is invalid, key: %s", key)) + } + return res, nil +} + +func (c *JSONConfigContainer) OnChange(ctx context.Context, key string, fn func(value string)) { + logs.Warn("unsupported operation") +} + // Bool returns the boolean value for a given key. func (c *JSONConfigContainer) Bool(ctx context.Context, key string) (bool, error) { val := c.getData(key) diff --git a/pkg/infrastructure/config/json/json_test.go b/pkg/infrastructure/config/json/json_test.go index 486d2b11..5275ee57 100644 --- a/pkg/infrastructure/config/json/json_test.go +++ b/pkg/infrastructure/config/json/json_test.go @@ -15,10 +15,13 @@ package json import ( + "context" "fmt" "os" "testing" + "github.com/stretchr/testify/assert" + "github.com/astaxie/beego/pkg/infrastructure/config" ) @@ -223,4 +226,27 @@ func TestJson(t *testing.T) { if !jsonconf.DefaultBool(nil, "unknown", true) { t.Error("unknown keys with default value wrong") } + + sub, err := jsonconf.Sub(context.Background(), "database") + assert.Nil(t, err) + assert.NotNil(t, sub) + + sub, err = sub.Sub(context.Background(), "conns") + assert.Nil(t, err) + + maxCon, _ := sub.Int(context.Background(), "maxconnection") + assert.Equal(t, 12, maxCon) + + dbCfg := &DatabaseConfig{} + err = sub.Unmarshaler(context.Background(), "", dbCfg) + assert.Nil(t, err) + assert.Equal(t, 12, dbCfg.MaxConnection) + assert.True(t, dbCfg.Autoconnect) + assert.Equal(t, "info", dbCfg.Connectioninfo) +} + +type DatabaseConfig struct { + MaxConnection int `json:"maxconnection"` + Autoconnect bool `json:"autoconnect"` + Connectioninfo string `json:"connectioninfo"` } From 185d55eb4638c6432a73d7ae94b3e3654ce134e6 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 1 Sep 2020 21:25:29 +0800 Subject: [PATCH 134/207] adapt config --- pkg/adapter/config/adapter.go | 193 +++++++++++++++++++++++ pkg/adapter/config/config.go | 151 ++++++++++++++++++ pkg/adapter/config/config_test.go | 55 +++++++ pkg/adapter/config/env/env.go | 50 ++++++ pkg/adapter/config/env/env_test.go | 75 +++++++++ pkg/adapter/config/fake.go | 25 +++ pkg/adapter/config/ini_test.go | 190 +++++++++++++++++++++++ pkg/adapter/config/json.go | 19 +++ pkg/adapter/config/json_test.go | 222 +++++++++++++++++++++++++++ pkg/adapter/config/xml/xml.go | 34 ++++ pkg/adapter/config/xml/xml_test.go | 125 +++++++++++++++ pkg/adapter/config/yaml/yaml.go | 34 ++++ pkg/adapter/config/yaml/yaml_test.go | 115 ++++++++++++++ 13 files changed, 1288 insertions(+) create mode 100644 pkg/adapter/config/adapter.go create mode 100644 pkg/adapter/config/config.go create mode 100644 pkg/adapter/config/config_test.go create mode 100644 pkg/adapter/config/env/env.go create mode 100644 pkg/adapter/config/env/env_test.go create mode 100644 pkg/adapter/config/fake.go create mode 100644 pkg/adapter/config/ini_test.go create mode 100644 pkg/adapter/config/json.go create mode 100644 pkg/adapter/config/json_test.go create mode 100644 pkg/adapter/config/xml/xml.go create mode 100644 pkg/adapter/config/xml/xml_test.go create mode 100644 pkg/adapter/config/yaml/yaml.go create mode 100644 pkg/adapter/config/yaml/yaml_test.go diff --git a/pkg/adapter/config/adapter.go b/pkg/adapter/config/adapter.go new file mode 100644 index 00000000..f74b3ff9 --- /dev/null +++ b/pkg/adapter/config/adapter.go @@ -0,0 +1,193 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "context" + + "github.com/pkg/errors" + + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +type newToOldConfigerAdapter struct { + delegate config.Configer +} + +func (c *newToOldConfigerAdapter) Set(key, val string) error { + return c.delegate.Set(context.Background(), key, val) +} + +func (c *newToOldConfigerAdapter) String(key string) string { + res, _ := c.delegate.String(context.Background(), key) + return res +} + +func (c *newToOldConfigerAdapter) Strings(key string) []string { + res, _ := c.delegate.Strings(context.Background(), key) + return res +} + +func (c *newToOldConfigerAdapter) Int(key string) (int, error) { + return c.delegate.Int(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) Int64(key string) (int64, error) { + return c.delegate.Int64(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) Bool(key string) (bool, error) { + return c.delegate.Bool(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) Float(key string) (float64, error) { + return c.delegate.Float(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) DefaultString(key string, defaultVal string) string { + return c.delegate.DefaultString(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultStrings(key string, defaultVal []string) []string { + return c.delegate.DefaultStrings(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultInt(key string, defaultVal int) int { + return c.delegate.DefaultInt(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultInt64(key string, defaultVal int64) int64 { + return c.delegate.DefaultInt64(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultBool(key string, defaultVal bool) bool { + return c.delegate.DefaultBool(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultFloat(key string, defaultVal float64) float64 { + return c.delegate.DefaultFloat(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DIY(key string) (interface{}, error) { + return c.delegate.DIY(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) GetSection(section string) (map[string]string, error) { + return c.delegate.GetSection(context.Background(), section) +} + +func (c *newToOldConfigerAdapter) SaveConfigFile(filename string) error { + return c.delegate.SaveConfigFile(context.Background(), filename) +} + +type oldToNewConfigerAdapter struct { + delegate Configer +} + +func (o *oldToNewConfigerAdapter) Set(ctx context.Context, key, val string) error { + return o.delegate.Set(key, val) +} + +func (o *oldToNewConfigerAdapter) String(ctx context.Context, key string) (string, error) { + return o.delegate.String(key), nil +} + +func (o *oldToNewConfigerAdapter) Strings(ctx context.Context, key string) ([]string, error) { + return o.delegate.Strings(key), nil +} + +func (o *oldToNewConfigerAdapter) Int(ctx context.Context, key string) (int, error) { + return o.delegate.Int(key) +} + +func (o *oldToNewConfigerAdapter) Int64(ctx context.Context, key string) (int64, error) { + return o.delegate.Int64(key) +} + +func (o *oldToNewConfigerAdapter) Bool(ctx context.Context, key string) (bool, error) { + return o.delegate.Bool(key) +} + +func (o *oldToNewConfigerAdapter) Float(ctx context.Context, key string) (float64, error) { + return o.delegate.Float(key) +} + +func (o *oldToNewConfigerAdapter) DefaultString(ctx context.Context, key string, defaultVal string) string { + return o.delegate.DefaultString(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + return o.delegate.DefaultStrings(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultInt(ctx context.Context, key string, defaultVal int) int { + return o.delegate.DefaultInt(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + return o.delegate.DefaultInt64(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + return o.delegate.DefaultBool(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + return o.delegate.DefaultFloat(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DIY(ctx context.Context, key string) (interface{}, error) { + return o.delegate.DIY(key) +} + +func (o *oldToNewConfigerAdapter) GetSection(ctx context.Context, section string) (map[string]string, error) { + return o.delegate.GetSection(section) +} + +func (o *oldToNewConfigerAdapter) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { + return errors.New("unsupported operation, please use actual config.Configer") +} + +func (o *oldToNewConfigerAdapter) Sub(ctx context.Context, key string) (config.Configer, error) { + return nil, errors.New("unsupported operation, please use actual config.Configer") +} + +func (o *oldToNewConfigerAdapter) OnChange(ctx context.Context, key string, fn func(value string)) { + // do nothing +} + +func (o *oldToNewConfigerAdapter) SaveConfigFile(ctx context.Context, filename string) error { + return o.delegate.SaveConfigFile(filename) +} + +type oldToNewConfigAdapter struct { + delegate Config +} + +func (o *oldToNewConfigAdapter) Parse(key string) (config.Configer, error) { + old, err := o.delegate.Parse(key) + if err != nil { + return nil, err + } + return &oldToNewConfigerAdapter{delegate: old}, nil +} + +func (o *oldToNewConfigAdapter) ParseData(data []byte) (config.Configer, error) { + old, err := o.delegate.ParseData(data) + if err != nil { + return nil, err + } + return &oldToNewConfigerAdapter{delegate: old}, nil +} diff --git a/pkg/adapter/config/config.go b/pkg/adapter/config/config.go new file mode 100644 index 00000000..c870a15a --- /dev/null +++ b/pkg/adapter/config/config.go @@ -0,0 +1,151 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package config is used to parse config. +// Usage: +// import "github.com/astaxie/beego/config" +// Examples. +// +// cnf, err := config.NewConfig("ini", "config.conf") +// +// cnf APIS: +// +// cnf.Set(key, val string) error +// cnf.String(key string) string +// cnf.Strings(key string) []string +// cnf.Int(key string) (int, error) +// cnf.Int64(key string) (int64, error) +// cnf.Bool(key string) (bool, error) +// cnf.Float(key string) (float64, error) +// cnf.DefaultString(key string, defaultVal string) string +// cnf.DefaultStrings(key string, defaultVal []string) []string +// cnf.DefaultInt(key string, defaultVal int) int +// cnf.DefaultInt64(key string, defaultVal int64) int64 +// cnf.DefaultBool(key string, defaultVal bool) bool +// cnf.DefaultFloat(key string, defaultVal float64) float64 +// cnf.DIY(key string) (interface{}, error) +// cnf.GetSection(section string) (map[string]string, error) +// cnf.SaveConfigFile(filename string) error +// More docs http://beego.me/docs/module/config.md +package config + +import ( + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +// Configer defines how to get and set value from configuration raw data. +type Configer interface { + Set(key, val string) error // support section::key type in given key when using ini type. + String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + Strings(key string) []string // get string slice + Int(key string) (int, error) + Int64(key string) (int64, error) + Bool(key string) (bool, error) + Float(key string) (float64, error) + DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + DefaultStrings(key string, defaultVal []string) []string // get string slice + DefaultInt(key string, defaultVal int) int + DefaultInt64(key string, defaultVal int64) int64 + DefaultBool(key string, defaultVal bool) bool + DefaultFloat(key string, defaultVal float64) float64 + DIY(key string) (interface{}, error) + GetSection(section string) (map[string]string, error) + SaveConfigFile(filename string) error +} + +// Config is the adapter interface for parsing config file to get raw data to Configer. +type Config interface { + Parse(key string) (Configer, error) + ParseData(data []byte) (Configer, error) +} + +var adapters = make(map[string]Config) + +// Register makes a config adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Config) { + config.Register(name, &oldToNewConfigAdapter{delegate: adapter}) +} + +// NewConfig adapterName is ini/json/xml/yaml. +// filename is the config file path. +func NewConfig(adapterName, filename string) (Configer, error) { + cfg, err := config.NewConfig(adapterName, filename) + if err != nil { + return nil, err + } + + // it was registered by using Register method + res, ok := cfg.(*oldToNewConfigerAdapter) + if ok { + return res.delegate, nil + } + + return &newToOldConfigerAdapter{ + delegate: cfg, + }, nil +} + +// NewConfigData adapterName is ini/json/xml/yaml. +// data is the config data. +func NewConfigData(adapterName string, data []byte) (Configer, error) { + cfg, err := config.NewConfigData(adapterName, data) + if err != nil { + return nil, err + } + + // it was registered by using Register method + res, ok := cfg.(*oldToNewConfigerAdapter) + if ok { + return res.delegate, nil + } + + return &newToOldConfigerAdapter{ + delegate: cfg, + }, nil +} + +// ExpandValueEnvForMap convert all string value with environment variable. +func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} { + return config.ExpandValueEnvForMap(m) +} + +// ExpandValueEnv returns value of convert with environment variable. +// +// Return environment variable if value start with "${" and end with "}". +// Return default value if environment variable is empty or not exist. +// +// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue". +// Examples: +// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable. +// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/". +// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie". +func ExpandValueEnv(value string) string { + return config.ExpandValueEnv(value) +} + +// ParseBool returns the boolean value represented by the string. +// +// It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On, +// 0, 0.0, f, F, FALSE, false, False, NO, no, No, N,n, OFF, off, Off. +// Any other value returns an error. +func ParseBool(val interface{}) (value bool, err error) { + return config.ParseBool(val) +} + +// ToString converts values of any type to string. +func ToString(x interface{}) string { + return config.ToString(x) +} diff --git a/pkg/adapter/config/config_test.go b/pkg/adapter/config/config_test.go new file mode 100644 index 00000000..15d6ffa6 --- /dev/null +++ b/pkg/adapter/config/config_test.go @@ -0,0 +1,55 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "os" + "testing" +) + +func TestExpandValueEnv(t *testing.T) { + + testCases := []struct { + item string + want string + }{ + {"", ""}, + {"$", "$"}, + {"{", "{"}, + {"{}", "{}"}, + {"${}", ""}, + {"${|}", ""}, + {"${}", ""}, + {"${{}}", ""}, + {"${{||}}", "}"}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}}", "}"}, + {"${pwd||{{||}}}", "{{||}}"}, + {"${GOPATH}", os.Getenv("GOPATH")}, + {"${GOPATH||}", os.Getenv("GOPATH")}, + {"${GOPATH||root}", os.Getenv("GOPATH")}, + {"${GOPATH_NOT||root}", "root"}, + {"${GOPATH_NOT||||root}", "||root"}, + } + + for _, c := range testCases { + if got := ExpandValueEnv(c.item); got != c.want { + t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got) + } + } + +} diff --git a/pkg/adapter/config/env/env.go b/pkg/adapter/config/env/env.go new file mode 100644 index 00000000..77d7b53c --- /dev/null +++ b/pkg/adapter/config/env/env.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package env is used to parse environment. +package env + +import ( + "github.com/astaxie/beego/pkg/infrastructure/config/env" +) + +// Get returns a value by key. +// If the key does not exist, the default value will be returned. +func Get(key string, defVal string) string { + return env.Get(key, defVal) +} + +// MustGet returns a value by key. +// If the key does not exist, it will return an error. +func MustGet(key string) (string, error) { + return env.MustGet(key) +} + +// Set sets a value in the ENV copy. +// This does not affect the child process environment. +func Set(key string, value string) { + env.Set(key, value) +} + +// MustSet sets a value in the ENV copy and the child process environment. +// It returns an error in case the set operation failed. +func MustSet(key string, value string) error { + return env.MustSet(key, value) +} + +// GetAll returns all keys/values in the current child process environment. +func GetAll() map[string]string { + return env.GetAll() +} diff --git a/pkg/adapter/config/env/env_test.go b/pkg/adapter/config/env/env_test.go new file mode 100644 index 00000000..3f1d4dba --- /dev/null +++ b/pkg/adapter/config/env/env_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package env + +import ( + "os" + "testing" +) + +func TestEnvGet(t *testing.T) { + gopath := Get("GOPATH", "") + if gopath != os.Getenv("GOPATH") { + t.Error("expected GOPATH not empty.") + } + + noExistVar := Get("NOEXISTVAR", "foo") + if noExistVar != "foo" { + t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar) + } +} + +func TestEnvMustGet(t *testing.T) { + gopath, err := MustGet("GOPATH") + if err != nil { + t.Error(err) + } + + if gopath != os.Getenv("GOPATH") { + t.Errorf("expected GOPATH to be the same, got %s.", gopath) + } + + _, err = MustGet("NOEXISTVAR") + if err == nil { + t.Error("expected error to be non-nil") + } +} + +func TestEnvSet(t *testing.T) { + Set("MYVAR", "foo") + myVar := Get("MYVAR", "bar") + if myVar != "foo" { + t.Errorf("expected MYVAR to equal foo, got %s.", myVar) + } +} + +func TestEnvMustSet(t *testing.T) { + err := MustSet("FOO", "bar") + if err != nil { + t.Error(err) + } + + fooVar := os.Getenv("FOO") + if fooVar != "bar" { + t.Errorf("expected FOO variable to equal bar, got %s.", fooVar) + } +} + +func TestEnvGetAll(t *testing.T) { + envMap := GetAll() + if len(envMap) == 0 { + t.Error("expected environment not empty.") + } +} diff --git a/pkg/adapter/config/fake.go b/pkg/adapter/config/fake.go new file mode 100644 index 00000000..fac96b41 --- /dev/null +++ b/pkg/adapter/config/fake.go @@ -0,0 +1,25 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +// NewFakeConfig return a fake Configer +func NewFakeConfig() Configer { + new := config.NewFakeConfig() + return &newToOldConfigerAdapter{delegate: new} +} diff --git a/pkg/adapter/config/ini_test.go b/pkg/adapter/config/ini_test.go new file mode 100644 index 00000000..ffcdb294 --- /dev/null +++ b/pkg/adapter/config/ini_test.go @@ -0,0 +1,190 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" +) + +func TestIni(t *testing.T) { + + var ( + inicontext = ` +;comment one +#comment two +appname = beeapi +httpport = 8080 +mysqlport = 3600 +PI = 3.1415976 +runmode = "dev" +autorender = false +copyrequestbody = true +session= on +cookieon= off +newreg = OFF +needlogin = ON +enableSession = Y +enableCookie = N +flag = 1 +path1 = ${GOPATH} +path2 = ${GOPATH||/home/go} +[demo] +key1="asta" +key2 = "xie" +CaseInsensitive = true +peers = one;two;three +password = ${GOPATH} +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "pi": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "demo::key1": "asta", + "demo::key2": "xie", + "demo::CaseInsensitive": true, + "demo::peers": []string{"one", "two", "three"}, + "demo::password": os.Getenv("GOPATH"), + "null": "", + "demo2::key1": "", + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testini.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(inicontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testini.conf") + iniconf, err := NewConfig("ini", "testini.conf") + if err != nil { + t.Fatal(err) + } + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = iniconf.Int(k) + case int64: + value, err = iniconf.Int64(k) + case float64: + value, err = iniconf.Float(k) + case bool: + value, err = iniconf.Bool(k) + case []string: + value = iniconf.Strings(k) + case string: + value = iniconf.String(k) + default: + value, err = iniconf.DIY(k) + } + if err != nil { + t.Fatalf("get key %q value fail,err %s", k, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = iniconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if iniconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + +} + +func TestIniSave(t *testing.T) { + + const ( + inicontext = ` +app = app +;comment one +#comment two +# comment three +appname = beeapi +httpport = 8080 +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name = mysql +` + + saveResult = ` +app=app +#comment one +#comment two +# comment three +appname=beeapi +httpport=8080 + +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name=mysql +` + ) + cfg, err := NewConfigData("ini", []byte(inicontext)) + if err != nil { + t.Fatal(err) + } + name := "newIniConfig.ini" + if err := cfg.SaveConfigFile(name); err != nil { + t.Fatal(err) + } + defer os.Remove(name) + + if data, err := ioutil.ReadFile(name); err != nil { + t.Fatal(err) + } else { + cfgData := string(data) + datas := strings.Split(saveResult, "\n") + for _, line := range datas { + if !strings.Contains(cfgData, line+"\n") { + t.Fatalf("different after save ini config file. need contains %q", line) + } + } + + } +} diff --git a/pkg/adapter/config/json.go b/pkg/adapter/config/json.go new file mode 100644 index 00000000..d0fe4d09 --- /dev/null +++ b/pkg/adapter/config/json.go @@ -0,0 +1,19 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + _ "github.com/astaxie/beego/pkg/infrastructure/config/json" +) diff --git a/pkg/adapter/config/json_test.go b/pkg/adapter/config/json_test.go new file mode 100644 index 00000000..16f42409 --- /dev/null +++ b/pkg/adapter/config/json_test.go @@ -0,0 +1,222 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "os" + "testing" +) + +func TestJsonStartsWithArray(t *testing.T) { + + const jsoncontextwitharray = `[ + { + "url": "user", + "serviceAPI": "http://www.test.com/user" + }, + { + "url": "employee", + "serviceAPI": "http://www.test.com/employee" + } +]` + f, err := os.Create("testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontextwitharray) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjsonWithArray.conf") + jsonconf, err := NewConfig("json", "testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + rootArray, err := jsonconf.DIY("rootArray") + if err != nil { + t.Error("array does not exist as element") + } + rootArrayCasted := rootArray.([]interface{}) + if rootArrayCasted == nil { + t.Error("array from root is nil") + } else { + elem := rootArrayCasted[0].(map[string]interface{}) + if elem["url"] != "user" || elem["serviceAPI"] != "http://www.test.com/user" { + t.Error("array[0] values are not valid") + } + + elem2 := rootArrayCasted[1].(map[string]interface{}) + if elem2["url"] != "employee" || elem2["serviceAPI"] != "http://www.test.com/employee" { + t.Error("array[1] values are not valid") + } + } +} + +func TestJson(t *testing.T) { + + var ( + jsoncontext = `{ +"appname": "beeapi", +"testnames": "foo;bar", +"httpport": 8080, +"mysqlport": 3600, +"PI": 3.1415976, +"runmode": "dev", +"autorender": false, +"copyrequestbody": true, +"session": "on", +"cookieon": "off", +"newreg": "OFF", +"needlogin": "ON", +"enableSession": "Y", +"enableCookie": "N", +"flag": 1, +"path1": "${GOPATH}", +"path2": "${GOPATH||/home/go}", +"database": { + "host": "host", + "port": "port", + "database": "database", + "username": "username", + "password": "${GOPATH}", + "conns":{ + "maxconnection":12, + "autoconnect":true, + "connectioninfo":"info", + "root": "${GOPATH}" + } + } +}` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "testnames": []string{"foo", "bar"}, + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "database::host": "host", + "database::port": "port", + "database::database": "database", + "database::password": os.Getenv("GOPATH"), + "database::conns::maxconnection": 12, + "database::conns::autoconnect": true, + "database::conns::connectioninfo": "info", + "database::conns::root": os.Getenv("GOPATH"), + "unknown": "", + } + ) + + f, err := os.Create("testjson.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjson.conf") + jsonconf, err := NewConfig("json", "testjson.conf") + if err != nil { + t.Fatal(err) + } + + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = jsonconf.Int(k) + case int64: + value, err = jsonconf.Int64(k) + case float64: + value, err = jsonconf.Float(k) + case bool: + value, err = jsonconf.Bool(k) + case []string: + value = jsonconf.Strings(k) + case string: + value = jsonconf.String(k) + default: + value, err = jsonconf.DIY(k) + } + if err != nil { + t.Fatalf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = jsonconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if jsonconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + + if db, err := jsonconf.DIY("database"); err != nil { + t.Fatal(err) + } else if m, ok := db.(map[string]interface{}); !ok { + t.Log(db) + t.Fatal("db not map[string]interface{}") + } else { + if m["host"].(string) != "host" { + t.Fatal("get host err") + } + } + + if _, err := jsonconf.Int("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int") + } + + if _, err := jsonconf.Int64("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int64") + } + + if _, err := jsonconf.Float("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Float") + } + + if _, err := jsonconf.DIY("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an interface{}") + } + + if val := jsonconf.String("unknown"); val != "" { + t.Error("unknown keys should return an empty string when expecting a String") + } + + if _, err := jsonconf.Bool("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Bool") + } + + if !jsonconf.DefaultBool("unknown", true) { + t.Error("unknown keys with default value wrong") + } +} diff --git a/pkg/adapter/config/xml/xml.go b/pkg/adapter/config/xml/xml.go new file mode 100644 index 00000000..f96cdcd6 --- /dev/null +++ b/pkg/adapter/config/xml/xml.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package xml for config provider. +// +// depend on github.com/beego/x2j. +// +// go install github.com/beego/x2j. +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/xml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("xml", "config.xml") +// +// More docs http://beego.me/docs/module/config.md +package xml + +import ( + _ "github.com/astaxie/beego/pkg/infrastructure/config/xml" +) diff --git a/pkg/adapter/config/xml/xml_test.go b/pkg/adapter/config/xml/xml_test.go new file mode 100644 index 00000000..122c5027 --- /dev/null +++ b/pkg/adapter/config/xml/xml_test.go @@ -0,0 +1,125 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xml + +import ( + "fmt" + "os" + "testing" + + "github.com/astaxie/beego/pkg/adapter/config" +) + +func TestXML(t *testing.T) { + + var ( + //xml parse should incluce in tags + xmlcontext = ` + +beeapi +8080 +3600 +3.1415976 +dev +false +true +${GOPATH} +${GOPATH||/home/go} + +1 +MySection + + +` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testxml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(xmlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testxml.conf") + + xmlconf, err := config.NewConfig("xml", "testxml.conf") + if err != nil { + t.Fatal(err) + } + + var xmlsection map[string]string + xmlsection, err = xmlconf.GetSection("mysection") + if err != nil { + t.Fatal(err) + } + + if len(xmlsection) == 0 { + t.Error("section should not be empty") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = xmlconf.Int(k) + case int64: + value, err = xmlconf.Int64(k) + case float64: + value, err = xmlconf.Float(k) + case bool: + value, err = xmlconf.Bool(k) + case []string: + value = xmlconf.Strings(k) + case string: + value = xmlconf.String(k) + default: + value, err = xmlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = xmlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if xmlconf.String("name") != "astaxie" { + t.Fatal("get name error") + } +} diff --git a/pkg/adapter/config/yaml/yaml.go b/pkg/adapter/config/yaml/yaml.go new file mode 100644 index 00000000..bc2398e9 --- /dev/null +++ b/pkg/adapter/config/yaml/yaml.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package yaml for config provider +// +// depend on github.com/beego/goyaml2 +// +// go install github.com/beego/goyaml2 +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/yaml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("yaml", "config.yaml") +// +// More docs http://beego.me/docs/module/config.md +package yaml + +import ( + _ "github.com/astaxie/beego/pkg/infrastructure/config/yaml" +) diff --git a/pkg/adapter/config/yaml/yaml_test.go b/pkg/adapter/config/yaml/yaml_test.go new file mode 100644 index 00000000..e4e309a2 --- /dev/null +++ b/pkg/adapter/config/yaml/yaml_test.go @@ -0,0 +1,115 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package yaml + +import ( + "fmt" + "os" + "testing" + + "github.com/astaxie/beego/pkg/adapter/config" +) + +func TestYaml(t *testing.T) { + + var ( + yamlcontext = ` +"appname": beeapi +"httpport": 8080 +"mysqlport": 3600 +"PI": 3.1415976 +"runmode": dev +"autorender": false +"copyrequestbody": true +"PATH": GOPATH +"path1": ${GOPATH} +"path2": ${GOPATH||/home/go} +"empty": "" +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "PATH": "GOPATH", + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + f, err := os.Create("testyaml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(yamlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testyaml.conf") + yamlconf, err := config.NewConfig("yaml", "testyaml.conf") + if err != nil { + t.Fatal(err) + } + + if yamlconf.String("appname") != "beeapi" { + t.Fatal("appname not equal to beeapi") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = yamlconf.Int(k) + case int64: + value, err = yamlconf.Int64(k) + case float64: + value, err = yamlconf.Float(k) + case bool: + value, err = yamlconf.Bool(k) + case []string: + value = yamlconf.Strings(k) + case string: + value = yamlconf.String(k) + default: + value, err = yamlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = yamlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if yamlconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + +} From e54dbabf0b6311847a1e6af92af6deac8014b965 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Tue, 1 Sep 2020 21:56:48 +0800 Subject: [PATCH 135/207] movement for global modelCache --- pkg/client/orm/cmd.go | 22 +- pkg/client/orm/cmd_utils.go | 156 ------------ pkg/client/orm/models.go | 456 ++++++++++++++++++++++++++++++++++ pkg/client/orm/models_boot.go | 310 +---------------------- 4 files changed, 477 insertions(+), 467 deletions(-) diff --git a/pkg/client/orm/cmd.go b/pkg/client/orm/cmd.go index f03382e9..e03fc0ee 100644 --- a/pkg/client/orm/cmd.go +++ b/pkg/client/orm/cmd.go @@ -99,13 +99,17 @@ func (d *commandSyncDb) Parse(args []string) { // Run orm line command. func (d *commandSyncDb) Run() error { var drops []string + var err error if d.force { - drops = getDbDropSQL(d.al) + drops, err = modelCache.getDbDropSQL(d.al) + if err != nil { + return err + } } db := d.al.DB - if d.force { + if d.force && len(drops) > 0 { for i, mi := range modelCache.allOrdered() { query := drops[i] if !d.noInfo { @@ -124,7 +128,10 @@ func (d *commandSyncDb) Run() error { } } - sqls, indexes := getDbCreateSQL(d.al) + createQueries, indexes, err := modelCache.getDbCreateSQL(d.al) + if err != nil { + return err + } tables, err := d.al.DbBaser.GetTables(db) if err != nil { @@ -201,7 +208,7 @@ func (d *commandSyncDb) Run() error { fmt.Printf("create table `%s` \n", mi.table) } - queries := []string{sqls[i]} + queries := []string{createQueries[i]} for _, idx := range indexes[mi.table] { queries = append(queries, idx.SQL) } @@ -245,10 +252,13 @@ func (d *commandSQLAll) Parse(args []string) { // Run orm line command. func (d *commandSQLAll) Run() error { - sqls, indexes := getDbCreateSQL(d.al) + createQueries, indexes, err := modelCache.getDbCreateSQL(d.al) + if err != nil { + return err + } var all []string for i, mi := range modelCache.allOrdered() { - queries := []string{sqls[i]} + queries := []string{createQueries[i]} for _, idx := range indexes[mi.table] { queries = append(queries, idx.SQL) } diff --git a/pkg/client/orm/cmd_utils.go b/pkg/client/orm/cmd_utils.go index e045e847..8d6c0c33 100644 --- a/pkg/client/orm/cmd_utils.go +++ b/pkg/client/orm/cmd_utils.go @@ -16,7 +16,6 @@ package orm import ( "fmt" - "os" "strings" ) @@ -26,21 +25,6 @@ type dbIndex struct { SQL string } -// create database drop sql. -func getDbDropSQL(al *alias) (sqls []string) { - if len(modelCache.cache) == 0 { - fmt.Println("no Model found, need register your model") - os.Exit(2) - } - - Q := al.DbBaser.TableQuote() - - for _, mi := range modelCache.allOrdered() { - sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) - } - return sqls -} - // get database column type string. func getColumnTyp(al *alias, fi *fieldInfo) (col string) { T := al.DbBaser.DbTypes() @@ -140,146 +124,6 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string { ) } -// create database creation string. -func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { - if len(modelCache.cache) == 0 { - fmt.Println("no Model found, need register your model") - os.Exit(2) - } - - Q := al.DbBaser.TableQuote() - T := al.DbBaser.DbTypes() - sep := fmt.Sprintf("%s, %s", Q, Q) - - tableIndexes = make(map[string][]dbIndex) - - for _, mi := range modelCache.allOrdered() { - sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) - sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) - sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) - - sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) - - columns := make([]string, 0, len(mi.fields.fieldsDB)) - - sqlIndexes := [][]string{} - - for _, fi := range mi.fields.fieldsDB { - - column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) - col := getColumnTyp(al, fi) - - if fi.auto { - switch al.Driver { - case DRSqlite, DRPostgres: - column += T["auto"] - default: - column += col + " " + T["auto"] - } - } else if fi.pk { - column += col + " " + T["pk"] - } else { - column += col - - if !fi.null { - column += " " + "NOT NULL" - } - - //if fi.initial.String() != "" { - // column += " DEFAULT " + fi.initial.String() - //} - - // Append attribute DEFAULT - column += getColumnDefault(fi) - - if fi.unique { - column += " " + "UNIQUE" - } - - if fi.index { - sqlIndexes = append(sqlIndexes, []string{fi.column}) - } - } - - if strings.Contains(column, "%COL%") { - column = strings.Replace(column, "%COL%", fi.column, -1) - } - - if fi.description != "" && al.Driver != DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) - } - - columns = append(columns, column) - } - - if mi.model != nil { - allnames := getTableUnique(mi.addrField) - if !mi.manual && len(mi.uniques) > 0 { - allnames = append(allnames, mi.uniques) - } - for _, names := range allnames { - cols := make([]string, 0, len(names)) - for _, name := range names { - if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { - cols = append(cols, fi.column) - } else { - panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) - } - } - column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) - columns = append(columns, column) - } - } - - sql += strings.Join(columns, ",\n") - sql += "\n)" - - if al.Driver == DRMySQL { - var engine string - if mi.model != nil { - engine = getTableEngine(mi.addrField) - } - if engine == "" { - engine = al.Engine - } - sql += " ENGINE=" + engine - } - - sql += ";" - sqls = append(sqls, sql) - - if mi.model != nil { - for _, names := range getTableIndex(mi.addrField) { - cols := make([]string, 0, len(names)) - for _, name := range names { - if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { - cols = append(cols, fi.column) - } else { - panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) - } - } - sqlIndexes = append(sqlIndexes, cols) - } - } - - for _, names := range sqlIndexes { - name := mi.table + "_" + strings.Join(names, "_") - cols := strings.Join(names, sep) - sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) - - index := dbIndex{} - index.Table = mi.table - index.Name = name - index.SQL = sql - - tableIndexes[mi.table] = append(tableIndexes[mi.table], index) - } - - } - - return -} - // Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands func getColumnDefault(fi *fieldInfo) string { var ( diff --git a/pkg/client/orm/models.go b/pkg/client/orm/models.go index c8fbcced..a7de10f7 100644 --- a/pkg/client/orm/models.go +++ b/pkg/client/orm/models.go @@ -15,7 +15,11 @@ package orm import ( + "errors" + "fmt" "reflect" + "runtime/debug" + "strings" "sync" ) @@ -95,12 +99,464 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { // clean all model info. func (mc *_modelCache) clean() { + mc.Lock() + defer mc.Unlock() + mc.orders = make([]string, 0) mc.cache = make(map[string]*modelInfo) mc.cacheByFullName = make(map[string]*modelInfo) mc.done = false } +//bootstrap bootstrap for models +func (mc *_modelCache) bootstrap() { + mc.Lock() + defer mc.Unlock() + if mc.done { + return + } + var ( + err error + models map[string]*modelInfo + ) + if dataBaseCache.getDefault() == nil { + err = fmt.Errorf("must have one register DataBase alias named `default`") + goto end + } + + // set rel and reverse model + // RelManyToMany set the relTable + models = mc.all() + for _, mi := range models { + for _, fi := range mi.fields.columns { + if fi.rel || fi.reverse { + elm := fi.addrValue.Type().Elem() + if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { + elm = elm.Elem() + } + // check the rel or reverse model already register + name := getFullName(elm) + mii, ok := mc.getByFullName(name) + if !ok || mii.pkg != elm.PkgPath() { + err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) + goto end + } + fi.relModelInfo = mii + + switch fi.fieldType { + case RelManyToMany: + if fi.relThrough != "" { + if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { + pn := fi.relThrough[:i] + rmi, ok := mc.getByFullName(fi.relThrough) + if !ok || pn != rmi.pkg { + err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) + goto end + } + fi.relThroughModelInfo = rmi + fi.relTable = rmi.table + } else { + err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) + goto end + } + } else { + i := newM2MModelInfo(mi, mii) + if fi.relTable != "" { + i.table = fi.relTable + } + if v := mc.set(i.table, i); v != nil { + err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) + goto end + } + fi.relTable = i.table + fi.relThroughModelInfo = i + } + + fi.relThroughModelInfo.isThrough = true + } + } + } + } + + // check the rel filed while the relModelInfo also has filed point to current model + // if not exist, add a new field to the relModelInfo + models = mc.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsRel { + switch fi.fieldType { + case RelForeignKey, RelOneToOne, RelManyToMany: + inModel := false + for _, ffi := range fi.relModelInfo.fields.fieldsReverse { + if ffi.relModelInfo == mi { + inModel = true + break + } + } + if !inModel { + rmi := fi.relModelInfo + ffi := new(fieldInfo) + ffi.name = mi.name + ffi.column = ffi.name + ffi.fullName = rmi.fullName + "." + ffi.name + ffi.reverse = true + ffi.relModelInfo = mi + ffi.mi = rmi + if fi.fieldType == RelOneToOne { + ffi.fieldType = RelReverseOne + } else { + ffi.fieldType = RelReverseMany + } + if !rmi.fields.Add(ffi) { + added := false + for cnt := 0; cnt < 5; cnt++ { + ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) + ffi.column = ffi.name + ffi.fullName = rmi.fullName + "." + ffi.name + if added = rmi.fields.Add(ffi); added { + break + } + } + if !added { + panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) + } + } + } + } + } + } + + models = mc.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsRel { + switch fi.fieldType { + case RelManyToMany: + for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { + switch ffi.fieldType { + case RelOneToOne, RelForeignKey: + if ffi.relModelInfo == fi.relModelInfo { + fi.reverseFieldInfoTwo = ffi + } + if ffi.relModelInfo == mi { + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + } + } + } + if fi.reverseFieldInfoTwo == nil { + err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", + fi.relThroughModelInfo.fullName) + goto end + } + } + } + } + + models = mc.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsReverse { + switch fi.fieldType { + case RelReverseOne: + found := false + mForA: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + + ffi.reverseField = fi.name + ffi.reverseFieldInfo = fi + break mForA + } + } + if !found { + err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + goto end + } + case RelReverseMany: + found := false + mForB: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + + ffi.reverseField = fi.name + ffi.reverseFieldInfo = fi + + break mForB + } + } + if !found { + mForC: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { + conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || + fi.relTable != "" && fi.relTable == ffi.relTable || + fi.relThrough == "" && fi.relTable == "" + if ffi.relModelInfo == mi && conditions { + found = true + + fi.reverseField = ffi.reverseFieldInfoTwo.name + fi.reverseFieldInfo = ffi.reverseFieldInfoTwo + fi.relThroughModelInfo = ffi.relThroughModelInfo + fi.reverseFieldInfoTwo = ffi.reverseFieldInfo + fi.reverseFieldInfoM2M = ffi + ffi.reverseFieldInfoM2M = fi + + break mForC + } + } + } + if !found { + err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + goto end + } + } + } + } + +end: + if err != nil { + fmt.Println(err) + debug.PrintStack() + } + modelCache.done = true + return +} + +// register register models to model cache +func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) { + if mc.done { + err = fmt.Errorf("register must be run before BootStrap") + return + } + + for _, model := range models { + val := reflect.ValueOf(model) + typ := reflect.Indirect(val).Type() + + if val.Kind() != reflect.Ptr { + err = fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ)) + return + } + // For this case: + // u := &User{} + // registerModel(&u) + if typ.Kind() == reflect.Ptr { + err = fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ) + return + } + + table := getTableName(val) + + if prefixOrSuffixStr != "" { + if prefixOrSuffix { + table = prefixOrSuffixStr + table + } else { + table = table + prefixOrSuffixStr + } + } + + // models's fullname is pkgpath + struct name + name := getFullName(typ) + if _, ok := mc.getByFullName(name); ok { + err = fmt.Errorf(" model `%s` repeat register, must be unique\n", name) + return + } + + if _, ok := mc.get(table); ok { + err = fmt.Errorf(" table name `%s` repeat register, must be unique\n", table) + return + } + + mi := newModelInfo(val) + if mi.fields.pk == nil { + outFor: + for _, fi := range mi.fields.fieldsDB { + if strings.ToLower(fi.name) == "id" { + switch fi.addrValue.Elem().Kind() { + case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + fi.auto = true + fi.pk = true + mi.fields.pk = fi + break outFor + } + } + } + + if mi.fields.pk == nil { + err = fmt.Errorf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) + return + } + + } + + mi.table = table + mi.pkg = typ.PkgPath() + mi.model = model + mi.manual = true + + mc.set(table, mi) + } + return +} + +//getDbDropSQL get database scheme drop sql queries +func (mc *_modelCache) getDbDropSQL(al *alias) (queries []string, err error) { + if len(modelCache.cache) == 0 { + err = errors.New("no Model found, need register your model") + return + } + + Q := al.DbBaser.TableQuote() + + for _, mi := range modelCache.allOrdered() { + queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) + } + return queries,nil +} + +//getDbCreateSQL get database scheme creation sql queries +func (mc *_modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes map[string][]dbIndex, err error) { + if len(modelCache.cache) == 0 { + err = errors.New("no Model found, need register your model") + return + } + + Q := al.DbBaser.TableQuote() + T := al.DbBaser.DbTypes() + sep := fmt.Sprintf("%s, %s", Q, Q) + + tableIndexes = make(map[string][]dbIndex) + + for _, mi := range modelCache.allOrdered() { + sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) + sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) + sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) + + sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) + + columns := make([]string, 0, len(mi.fields.fieldsDB)) + + sqlIndexes := [][]string{} + + for _, fi := range mi.fields.fieldsDB { + + column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) + col := getColumnTyp(al, fi) + + if fi.auto { + switch al.Driver { + case DRSqlite, DRPostgres: + column += T["auto"] + default: + column += col + " " + T["auto"] + } + } else if fi.pk { + column += col + " " + T["pk"] + } else { + column += col + + if !fi.null { + column += " " + "NOT NULL" + } + + //if fi.initial.String() != "" { + // column += " DEFAULT " + fi.initial.String() + //} + + // Append attribute DEFAULT + column += getColumnDefault(fi) + + if fi.unique { + column += " " + "UNIQUE" + } + + if fi.index { + sqlIndexes = append(sqlIndexes, []string{fi.column}) + } + } + + if strings.Contains(column, "%COL%") { + column = strings.Replace(column, "%COL%", fi.column, -1) + } + + if fi.description != "" && al.Driver != DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) + } + + columns = append(columns, column) + } + + if mi.model != nil { + allnames := getTableUnique(mi.addrField) + if !mi.manual && len(mi.uniques) > 0 { + allnames = append(allnames, mi.uniques) + } + for _, names := range allnames { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { + cols = append(cols, fi.column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) + } + } + column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) + columns = append(columns, column) + } + } + + sql += strings.Join(columns, ",\n") + sql += "\n)" + + if al.Driver == DRMySQL { + var engine string + if mi.model != nil { + engine = getTableEngine(mi.addrField) + } + if engine == "" { + engine = al.Engine + } + sql += " ENGINE=" + engine + } + + sql += ";" + queries = append(queries, sql) + + if mi.model != nil { + for _, names := range getTableIndex(mi.addrField) { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { + cols = append(cols, fi.column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) + } + } + sqlIndexes = append(sqlIndexes, cols) + } + } + + for _, names := range sqlIndexes { + name := mi.table + "_" + strings.Join(names, "_") + cols := strings.Join(names, sep) + sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) + + index := dbIndex{} + index.Table = mi.table + index.Name = name + index.SQL = sql + + tableIndexes[mi.table] = append(tableIndexes[mi.table], index) + } + + } + + return +} + // ResetModelCache Clean model cache. Then you can re-RegisterModel. // Common use this api for test case. func ResetModelCache() { diff --git a/pkg/client/orm/models_boot.go b/pkg/client/orm/models_boot.go index 8c56b3c4..407cf536 100644 --- a/pkg/client/orm/models_boot.go +++ b/pkg/client/orm/models_boot.go @@ -16,294 +16,8 @@ package orm import ( "fmt" - "os" - "reflect" - "runtime/debug" - "strings" ) -// register models. -// PrefixOrSuffix means table name prefix or suffix. -// isPrefix whether the prefix is prefix or suffix -func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) { - val := reflect.ValueOf(model) - typ := reflect.Indirect(val).Type() - - if val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) - } - // For this case: - // u := &User{} - // registerModel(&u) - if typ.Kind() == reflect.Ptr { - panic(fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)) - } - - table := getTableName(val) - - if PrefixOrSuffix != "" { - if isPrefix { - table = PrefixOrSuffix + table - } else { - table = table + PrefixOrSuffix - } - } - // models's fullname is pkgpath + struct name - name := getFullName(typ) - if _, ok := modelCache.getByFullName(name); ok { - fmt.Printf(" model `%s` repeat register, must be unique\n", name) - os.Exit(2) - } - - if _, ok := modelCache.get(table); ok { - fmt.Printf(" table name `%s` repeat register, must be unique\n", table) - os.Exit(2) - } - - mi := newModelInfo(val) - if mi.fields.pk == nil { - outFor: - for _, fi := range mi.fields.fieldsDB { - if strings.ToLower(fi.name) == "id" { - switch fi.addrValue.Elem().Kind() { - case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: - fi.auto = true - fi.pk = true - mi.fields.pk = fi - break outFor - } - } - } - - if mi.fields.pk == nil { - fmt.Printf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) - os.Exit(2) - } - - } - - mi.table = table - mi.pkg = typ.PkgPath() - mi.model = model - mi.manual = true - - modelCache.set(table, mi) -} - -// bootstrap models -func bootStrap() { - if modelCache.done { - return - } - var ( - err error - models map[string]*modelInfo - ) - if dataBaseCache.getDefault() == nil { - err = fmt.Errorf("must have one register DataBase alias named `default`") - goto end - } - - // set rel and reverse model - // RelManyToMany set the relTable - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.columns { - if fi.rel || fi.reverse { - elm := fi.addrValue.Type().Elem() - if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { - elm = elm.Elem() - } - // check the rel or reverse model already register - name := getFullName(elm) - mii, ok := modelCache.getByFullName(name) - if !ok || mii.pkg != elm.PkgPath() { - err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) - goto end - } - fi.relModelInfo = mii - - switch fi.fieldType { - case RelManyToMany: - if fi.relThrough != "" { - if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { - pn := fi.relThrough[:i] - rmi, ok := modelCache.getByFullName(fi.relThrough) - if !ok || pn != rmi.pkg { - err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) - goto end - } - fi.relThroughModelInfo = rmi - fi.relTable = rmi.table - } else { - err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) - goto end - } - } else { - i := newM2MModelInfo(mi, mii) - if fi.relTable != "" { - i.table = fi.relTable - } - if v := modelCache.set(i.table, i); v != nil { - err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) - goto end - } - fi.relTable = i.table - fi.relThroughModelInfo = i - } - - fi.relThroughModelInfo.isThrough = true - } - } - } - } - - // check the rel filed while the relModelInfo also has filed point to current model - // if not exist, add a new field to the relModelInfo - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsRel { - switch fi.fieldType { - case RelForeignKey, RelOneToOne, RelManyToMany: - inModel := false - for _, ffi := range fi.relModelInfo.fields.fieldsReverse { - if ffi.relModelInfo == mi { - inModel = true - break - } - } - if !inModel { - rmi := fi.relModelInfo - ffi := new(fieldInfo) - ffi.name = mi.name - ffi.column = ffi.name - ffi.fullName = rmi.fullName + "." + ffi.name - ffi.reverse = true - ffi.relModelInfo = mi - ffi.mi = rmi - if fi.fieldType == RelOneToOne { - ffi.fieldType = RelReverseOne - } else { - ffi.fieldType = RelReverseMany - } - if !rmi.fields.Add(ffi) { - added := false - for cnt := 0; cnt < 5; cnt++ { - ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) - ffi.column = ffi.name - ffi.fullName = rmi.fullName + "." + ffi.name - if added = rmi.fields.Add(ffi); added { - break - } - } - if !added { - panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) - } - } - } - } - } - } - - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsRel { - switch fi.fieldType { - case RelManyToMany: - for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { - switch ffi.fieldType { - case RelOneToOne, RelForeignKey: - if ffi.relModelInfo == fi.relModelInfo { - fi.reverseFieldInfoTwo = ffi - } - if ffi.relModelInfo == mi { - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - } - } - } - if fi.reverseFieldInfoTwo == nil { - err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", - fi.relThroughModelInfo.fullName) - goto end - } - } - } - } - - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsReverse { - switch fi.fieldType { - case RelReverseOne: - found := false - mForA: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { - if ffi.relModelInfo == mi { - found = true - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - - ffi.reverseField = fi.name - ffi.reverseFieldInfo = fi - break mForA - } - } - if !found { - err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) - goto end - } - case RelReverseMany: - found := false - mForB: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { - if ffi.relModelInfo == mi { - found = true - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - - ffi.reverseField = fi.name - ffi.reverseFieldInfo = fi - - break mForB - } - } - if !found { - mForC: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { - conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || - fi.relTable != "" && fi.relTable == ffi.relTable || - fi.relThrough == "" && fi.relTable == "" - if ffi.relModelInfo == mi && conditions { - found = true - - fi.reverseField = ffi.reverseFieldInfoTwo.name - fi.reverseFieldInfo = ffi.reverseFieldInfoTwo - fi.relThroughModelInfo = ffi.relThroughModelInfo - fi.reverseFieldInfoTwo = ffi.reverseFieldInfo - fi.reverseFieldInfoM2M = ffi - ffi.reverseFieldInfoM2M = fi - - break mForC - } - } - } - if !found { - err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) - goto end - } - } - } - } - -end: - if err != nil { - fmt.Println(err) - debug.PrintStack() - os.Exit(2) - } -} - // RegisterModel register models func RegisterModel(models ...interface{}) { if modelCache.done { @@ -314,34 +28,20 @@ func RegisterModel(models ...interface{}) { // RegisterModelWithPrefix register models with a prefix func RegisterModelWithPrefix(prefix string, models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModelWithPrefix must be run before BootStrap")) - } - - for _, model := range models { - registerModel(prefix, model, true) + if err := modelCache.register(prefix, true, models...); err != nil { + panic(err) } } // RegisterModelWithSuffix register models with a suffix func RegisterModelWithSuffix(suffix string, models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModelWithSuffix must be run before BootStrap")) - } - - for _, model := range models { - registerModel(suffix, model, false) + if err := modelCache.register(suffix, false, models...); err != nil { + panic(err) } } // BootStrap bootstrap models. // make all model parsed and can not add more models func BootStrap() { - modelCache.Lock() - defer modelCache.Unlock() - if modelCache.done { - return - } - bootStrap() - modelCache.done = true + modelCache.bootstrap() } From 78d91062c911d53ec9bc5ce278ce5dee14020e15 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 1 Sep 2020 22:16:49 +0800 Subject: [PATCH 136/207] Adapt new API to old API: httplib --- pkg/adapter/httplib/httplib.go | 300 ++++++++++++++++++++++++++++ pkg/adapter/httplib/httplib_test.go | 286 ++++++++++++++++++++++++++ 2 files changed, 586 insertions(+) create mode 100644 pkg/adapter/httplib/httplib.go create mode 100644 pkg/adapter/httplib/httplib_test.go diff --git a/pkg/adapter/httplib/httplib.go b/pkg/adapter/httplib/httplib.go new file mode 100644 index 00000000..d2ef36c1 --- /dev/null +++ b/pkg/adapter/httplib/httplib.go @@ -0,0 +1,300 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package httplib is used as http.Client +// Usage: +// +// import "github.com/astaxie/beego/httplib" +// +// b := httplib.Post("http://beego.me/") +// b.Param("username","astaxie") +// b.Param("password","123456") +// b.PostFile("uploadfile1", "httplib.pdf") +// b.PostFile("uploadfile2", "httplib.txt") +// str, err := b.String() +// if err != nil { +// t.Fatal(err) +// } +// fmt.Println(str) +// +// more docs http://beego.me/docs/module/httplib.md +package httplib + +import ( + "crypto/tls" + "net" + "net/http" + "net/url" + "time" + + "github.com/astaxie/beego/pkg/client/httplib" +) + +// SetDefaultSetting Overwrite default settings +func SetDefaultSetting(setting BeegoHTTPSettings) { + httplib.SetDefaultSetting(httplib.BeegoHTTPSettings(setting)) +} + +// NewBeegoRequest return *BeegoHttpRequest with specific method +func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { + return &BeegoHTTPRequest{ + delegate: httplib.NewBeegoRequest(rawurl, method), + } +} + +// Get returns *BeegoHttpRequest with GET method. +func Get(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "GET") +} + +// Post returns *BeegoHttpRequest with POST method. +func Post(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "POST") +} + +// Put returns *BeegoHttpRequest with PUT method. +func Put(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "PUT") +} + +// Delete returns *BeegoHttpRequest DELETE method. +func Delete(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "DELETE") +} + +// Head returns *BeegoHttpRequest with HEAD method. +func Head(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "HEAD") +} + +// BeegoHTTPSettings is the http.Client setting +type BeegoHTTPSettings httplib.BeegoHTTPSettings + +// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. +type BeegoHTTPRequest struct { + delegate *httplib.BeegoHTTPRequest +} + +// GetRequest return the request object +func (b *BeegoHTTPRequest) GetRequest() *http.Request { + return b.delegate.GetRequest() +} + +// Setting Change request settings +func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest { + b.delegate.Setting(httplib.BeegoHTTPSettings(setting)) + return b +} + +// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. +func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest { + b.delegate.SetBasicAuth(username, password) + return b +} + +// SetEnableCookie sets enable/disable cookiejar +func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest { + b.delegate.SetEnableCookie(enable) + return b +} + +// SetUserAgent sets User-Agent header field +func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { + b.delegate.SetUserAgent(useragent) + return b +} + +// Debug sets show debug or not when executing request. +func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { + b.delegate.Debug(isdebug) + return b +} + +// Retries sets Retries times. +// default is 0 means no retried. +// -1 means retried forever. +// others means retried times. +func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { + b.delegate.Retries(times) + return b +} + +func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { + b.delegate.RetryDelay(delay) + return b +} + +// DumpBody setting whether need to Dump the Body. +func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { + b.delegate.DumpBody(isdump) + return b +} + +// DumpRequest return the DumpRequest +func (b *BeegoHTTPRequest) DumpRequest() []byte { + return b.delegate.DumpRequest() +} + +// SetTimeout sets connect time out and read-write time out for BeegoRequest. +func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { + b.delegate.SetTimeout(connectTimeout, readWriteTimeout) + return b +} + +// SetTLSClientConfig sets tls connection configurations if visiting https url. +func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest { + b.delegate.SetTLSClientConfig(config) + return b +} + +// Header add header item string in request. +func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest { + b.delegate.Header(key, value) + return b +} + +// SetHost set the request host +func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { + b.delegate.SetHost(host) + return b +} + +// SetProtocolVersion Set the protocol version for incoming requests. +// Client requests always use HTTP/1.1. +func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { + b.delegate.SetProtocolVersion(vers) + return b +} + +// SetCookie add cookie into request. +func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest { + b.delegate.SetCookie(cookie) + return b +} + +// SetTransport set the setting transport +func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest { + b.delegate.SetTransport(transport) + return b +} + +// SetProxy set the http proxy +// example: +// +// func(req *http.Request) (*url.URL, error) { +// u, _ := url.ParseRequestURI("http://127.0.0.1:8118") +// return u, nil +// } +func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest { + b.delegate.SetProxy(proxy) + return b +} + +// SetCheckRedirect specifies the policy for handling redirects. +// +// If CheckRedirect is nil, the Client uses its default policy, +// which is to stop after 10 consecutive requests. +func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest { + b.delegate.SetCheckRedirect(redirect) + return b +} + +// Param adds query param in to request. +// params build query string as ?key1=value1&key2=value2... +func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { + b.delegate.Param(key, value) + return b +} + +// PostFile add a post file to the request +func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest { + b.delegate.PostFile(formname, filename) + return b +} + +// Body adds request raw body. +// it supports string and []byte. +func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { + b.delegate.Body(data) + return b +} + +// XMLBody adds request raw body encoding by XML. +func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + _, err := b.delegate.XMLBody(obj) + return b, err +} + +// YAMLBody adds request raw body encoding by YAML. +func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + _, err := b.delegate.YAMLBody(obj) + return b, err +} + +// JSONBody adds request raw body encoding by JSON. +func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { + _, err := b.delegate.JSONBody(obj) + return b, err +} + +// DoRequest will do the client.Do +func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { + return b.delegate.DoRequest() +} + +// String returns the body string in response. +// it calls Response inner. +func (b *BeegoHTTPRequest) String() (string, error) { + return b.delegate.String() +} + +// Bytes returns the body []byte in response. +// it calls Response inner. +func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { + return b.delegate.Bytes() +} + +// ToFile saves the body data in response to one file. +// it calls Response inner. +func (b *BeegoHTTPRequest) ToFile(filename string) error { + return b.delegate.ToFile(filename) +} + +// ToJSON returns the map that marshals from the body bytes as json in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { + return b.delegate.ToJSON(v) +} + +// ToXML returns the map that marshals from the body bytes as xml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToXML(v interface{}) error { + return b.delegate.ToXML(v) +} + +// ToYAML returns the map that marshals from the body bytes as yaml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { + return b.delegate.ToYAML(v) +} + +// Response executes request client gets response mannually. +func (b *BeegoHTTPRequest) Response() (*http.Response, error) { + return b.delegate.Response() +} + +// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. +func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { + return httplib.TimeoutDialer(cTimeout, rwTimeout) +} diff --git a/pkg/adapter/httplib/httplib_test.go b/pkg/adapter/httplib/httplib_test.go new file mode 100644 index 00000000..e7605c87 --- /dev/null +++ b/pkg/adapter/httplib/httplib_test.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "errors" + "io/ioutil" + "net" + "net/http" + "os" + "strings" + "testing" + "time" +) + +func TestResponse(t *testing.T) { + req := Get("http://httpbin.org/get") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) +} + +func TestDoRequest(t *testing.T) { + req := Get("https://goolnk.com/33BD2j") + retryAmount := 1 + req.Retries(1) + req.RetryDelay(1400 * time.Millisecond) + retryDelay := 1400 * time.Millisecond + + req.SetCheckRedirect(func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + }) + + startTime := time.Now().UnixNano() / int64(time.Millisecond) + + _, err := req.Response() + if err == nil { + t.Fatal("Response should have yielded an error") + } + + endTime := time.Now().UnixNano() / int64(time.Millisecond) + elapsedTime := endTime - startTime + delayedTime := int64(retryAmount) * retryDelay.Milliseconds() + + if elapsedTime < delayedTime { + t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) + } + +} + +func TestGet(t *testing.T) { + req := Get("http://httpbin.org/get") + b, err := req.Bytes() + if err != nil { + t.Fatal(err) + } + t.Log(b) + + s, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(s) + + if string(b) != s { + t.Fatal("request data not match") + } +} + +func TestSimplePost(t *testing.T) { + v := "smallfish" + req := Post("http://httpbin.org/post") + req.Param("username", v) + + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in post") + } +} + +// func TestPostFile(t *testing.T) { +// v := "smallfish" +// req := Post("http://httpbin.org/post") +// req.Debug(true) +// req.Param("username", v) +// req.PostFile("uploadfile", "httplib_test.go") + +// str, err := req.String() +// if err != nil { +// t.Fatal(err) +// } +// t.Log(str) + +// n := strings.Index(str, v) +// if n == -1 { +// t.Fatal(v + " not found in post") +// } +// } + +func TestSimplePut(t *testing.T) { + str, err := Put("http://httpbin.org/put").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDelete(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDeleteParam(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").Param("key", "val").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestWithCookie(t *testing.T) { + v := "smallfish" + str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + str, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in cookie") + } +} + +func TestWithBasicAuth(t *testing.T) { + str, err := Get("http://httpbin.org/basic-auth/user/passwd").SetBasicAuth("user", "passwd").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + n := strings.Index(str, "authenticated") + if n == -1 { + t.Fatal("authenticated not found in response") + } +} + +func TestWithUserAgent(t *testing.T) { + v := "beego" + str, err := Get("http://httpbin.org/headers").SetUserAgent(v).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestWithSetting(t *testing.T) { + v := "beego" + var setting BeegoHTTPSettings + setting.EnableCookie = true + setting.UserAgent = v + setting.Transport = &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 50, + IdleConnTimeout: 90 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + setting.ReadWriteTimeout = 5 * time.Second + SetDefaultSetting(setting) + + str, err := Get("http://httpbin.org/get").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestToJson(t *testing.T) { + req := Get("http://httpbin.org/ip") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) + + // httpbin will return http remote addr + type IP struct { + Origin string `json:"origin"` + } + var ip IP + err = req.ToJSON(&ip) + if err != nil { + t.Fatal(err) + } + t.Log(ip.Origin) + ips := strings.Split(ip.Origin, ",") + if len(ips) == 0 { + t.Fatal("response is not valid ip") + } + for i := range ips { + if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil { + t.Fatal("response is not valid ip") + } + } + +} + +func TestToFile(t *testing.T) { + f := "beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.Remove(f) + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestToFileDir(t *testing.T) { + f := "./files/beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll("./files") + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestHeader(t *testing.T) { + req := Get("http://httpbin.org/headers") + req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36") + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} From 7574b91760309df810c6b506ff1d9fda877736d4 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Wed, 2 Sep 2020 00:26:25 +0800 Subject: [PATCH 137/207] add type modelRegister interface into Ormer --- pkg/client/orm/do_nothing_orm.go | 12 ++++++++++++ pkg/client/orm/filter_orm_decorator.go | 13 +++++++++++++ pkg/client/orm/models.go | 23 +++++++++++++++++++++++ pkg/client/orm/models_boot.go | 7 ------- pkg/client/orm/orm.go | 13 +++++++++++++ pkg/client/orm/types.go | 1 + 6 files changed, 62 insertions(+), 7 deletions(-) diff --git a/pkg/client/orm/do_nothing_orm.go b/pkg/client/orm/do_nothing_orm.go index e27e7f3a..07c7fd74 100644 --- a/pkg/client/orm/do_nothing_orm.go +++ b/pkg/client/orm/do_nothing_orm.go @@ -30,6 +30,18 @@ var _ Ormer = new(DoNothingOrm) type DoNothingOrm struct { } +func (d *DoNothingOrm) RegisterModels(models ...interface{}) (err error) { + return nil +} + +func (d *DoNothingOrm) RegisterModelsWithPrefix(prefix string, models ...interface{}) (err error) { + return nil +} + +func (d *DoNothingOrm) RegisterModelsWithSuffix(suffix string, models ...interface{}) (err error) { + return nil +} + func (d *DoNothingOrm) Read(md interface{}, cols ...string) error { return nil } diff --git a/pkg/client/orm/filter_orm_decorator.go b/pkg/client/orm/filter_orm_decorator.go index d0c5c537..095c8485 100644 --- a/pkg/client/orm/filter_orm_decorator.go +++ b/pkg/client/orm/filter_orm_decorator.go @@ -32,6 +32,7 @@ var _ TxOrmer = new(filterOrmDecorator) type filterOrmDecorator struct { ormer + modelRegister TxBeginner TxCommitter @@ -42,6 +43,18 @@ type filterOrmDecorator struct { txName string } +func (f *filterOrmDecorator) RegisterModels(models ...interface{}) (err error) { + return f.modelRegister.RegisterModels(models...) +} + +func (f *filterOrmDecorator) RegisterModelsWithPrefix(prefix string, models ...interface{}) (err error) { + return f.modelRegister.RegisterModelsWithPrefix(prefix, models...) +} + +func (f *filterOrmDecorator) RegisterModelsWithSuffix(suffix string, models ...interface{}) (err error) { + return f.modelRegister.RegisterModelsWithSuffix(suffix, models...) +} + func NewFilterOrmDecorator(delegate Ormer, filterChains ...FilterChain) Ormer { res := &filterOrmDecorator{ ormer: delegate, diff --git a/pkg/client/orm/models.go b/pkg/client/orm/models.go index a7de10f7..97faa00a 100644 --- a/pkg/client/orm/models.go +++ b/pkg/client/orm/models.go @@ -39,6 +39,15 @@ var ( } ) +type modelRegister interface { + //RegisterModels register models without prefix or suffix + RegisterModels(models ...interface{}) (err error) + //RegisterModelsWithPrefix register models with prefix + RegisterModelsWithPrefix(prefix string, models ...interface{}) (err error) + //RegisterModelsWithSuffix register models with suffix + RegisterModelsWithSuffix(suffix string, models ...interface{}) (err error) +} + // model info collection type _modelCache struct { sync.RWMutex // only used outsite for bootStrap @@ -48,6 +57,20 @@ type _modelCache struct { done bool } +var _ modelRegister = new(_modelCache) + +func (mc *_modelCache) RegisterModels(models ...interface{}) (err error) { + return mc.register(``, true, models...) +} + +func (mc *_modelCache) RegisterModelsWithPrefix(prefix string, models ...interface{}) (err error) { + return mc.register(prefix, true, models...) +} + +func (mc *_modelCache) RegisterModelsWithSuffix(suffix string, models ...interface{}) (err error) { + return mc.register(suffix, false, models...) +} + // get all model info func (mc *_modelCache) all() map[string]*modelInfo { m := make(map[string]*modelInfo, len(mc.cache)) diff --git a/pkg/client/orm/models_boot.go b/pkg/client/orm/models_boot.go index 407cf536..9a0ce893 100644 --- a/pkg/client/orm/models_boot.go +++ b/pkg/client/orm/models_boot.go @@ -14,15 +14,8 @@ package orm -import ( - "fmt" -) - // RegisterModel register models func RegisterModel(models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModel must be run before BootStrap")) - } RegisterModelWithPrefix("", models...) } diff --git a/pkg/client/orm/orm.go b/pkg/client/orm/orm.go index 634b1892..d82f7e05 100644 --- a/pkg/client/orm/orm.go +++ b/pkg/client/orm/orm.go @@ -498,10 +498,23 @@ func (o *ormBase) DBStats() *sql.DBStats { type orm struct { ormBase + modelRegister } var _ Ormer = new(orm) +func (o *orm) RegisterModels(models ...interface{}) (err error) { + return o.modelRegister.RegisterModels(models) +} + +func (o *orm) RegisterModelsWithPrefix(prefix string, models ...interface{}) (err error) { + return o.modelRegister.RegisterModelsWithPrefix(prefix, models...) +} + +func (o *orm) RegisterModelsWithSuffix(suffix string, models ...interface{}) (err error) { + return o.modelRegister.RegisterModelsWithSuffix(suffix, models...) +} + func (o *orm) Begin() (TxOrmer, error) { return o.BeginWithCtx(context.Background()) } diff --git a/pkg/client/orm/types.go b/pkg/client/orm/types.go index eb34e759..584f0f8a 100644 --- a/pkg/client/orm/types.go +++ b/pkg/client/orm/types.go @@ -214,6 +214,7 @@ type ormer interface { type Ormer interface { ormer + modelRegister TxBeginner } From 7a53baaf9b4badb2572abb4e081210a17cee1d68 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Wed, 2 Sep 2020 00:33:46 +0800 Subject: [PATCH 138/207] rename modelRegister to modelCacheHandler --- pkg/client/orm/filter_orm_decorator.go | 8 ++++---- pkg/client/orm/models.go | 17 +++++++++++------ pkg/client/orm/orm.go | 10 ++++++---- pkg/client/orm/types.go | 2 +- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/pkg/client/orm/filter_orm_decorator.go b/pkg/client/orm/filter_orm_decorator.go index 095c8485..5a49e395 100644 --- a/pkg/client/orm/filter_orm_decorator.go +++ b/pkg/client/orm/filter_orm_decorator.go @@ -32,7 +32,7 @@ var _ TxOrmer = new(filterOrmDecorator) type filterOrmDecorator struct { ormer - modelRegister + modelCacheHandler TxBeginner TxCommitter @@ -44,15 +44,15 @@ type filterOrmDecorator struct { } func (f *filterOrmDecorator) RegisterModels(models ...interface{}) (err error) { - return f.modelRegister.RegisterModels(models...) + return f.modelCacheHandler.RegisterModels(models...) } func (f *filterOrmDecorator) RegisterModelsWithPrefix(prefix string, models ...interface{}) (err error) { - return f.modelRegister.RegisterModelsWithPrefix(prefix, models...) + return f.modelCacheHandler.RegisterModelsWithPrefix(prefix, models...) } func (f *filterOrmDecorator) RegisterModelsWithSuffix(suffix string, models ...interface{}) (err error) { - return f.modelRegister.RegisterModelsWithSuffix(suffix, models...) + return f.modelCacheHandler.RegisterModelsWithSuffix(suffix, models...) } func NewFilterOrmDecorator(delegate Ormer, filterChains ...FilterChain) Ormer { diff --git a/pkg/client/orm/models.go b/pkg/client/orm/models.go index 97faa00a..55ba5a73 100644 --- a/pkg/client/orm/models.go +++ b/pkg/client/orm/models.go @@ -33,13 +33,10 @@ const ( ) var ( - modelCache = &_modelCache{ - cache: make(map[string]*modelInfo), - cacheByFullName: make(map[string]*modelInfo), - } + modelCache = NewModelCacheHandler() ) -type modelRegister interface { +type modelCacheHandler interface { //RegisterModels register models without prefix or suffix RegisterModels(models ...interface{}) (err error) //RegisterModelsWithPrefix register models with prefix @@ -57,7 +54,15 @@ type _modelCache struct { done bool } -var _ modelRegister = new(_modelCache) +//NewModelCacheHandler generator of _modelCache +func NewModelCacheHandler() *_modelCache { + return &_modelCache{ + cache: make(map[string]*modelInfo), + cacheByFullName: make(map[string]*modelInfo), + } +} + +var _ modelCacheHandler = new(_modelCache) func (mc *_modelCache) RegisterModels(models ...interface{}) (err error) { return mc.register(``, true, models...) diff --git a/pkg/client/orm/orm.go b/pkg/client/orm/orm.go index d82f7e05..a18dae3c 100644 --- a/pkg/client/orm/orm.go +++ b/pkg/client/orm/orm.go @@ -498,21 +498,21 @@ func (o *ormBase) DBStats() *sql.DBStats { type orm struct { ormBase - modelRegister + modelCacheHandler } var _ Ormer = new(orm) func (o *orm) RegisterModels(models ...interface{}) (err error) { - return o.modelRegister.RegisterModels(models) + return o.modelCacheHandler.RegisterModels(models) } func (o *orm) RegisterModelsWithPrefix(prefix string, models ...interface{}) (err error) { - return o.modelRegister.RegisterModelsWithPrefix(prefix, models...) + return o.modelCacheHandler.RegisterModelsWithPrefix(prefix, models...) } func (o *orm) RegisterModelsWithSuffix(suffix string, models ...interface{}) (err error) { - return o.modelRegister.RegisterModelsWithSuffix(suffix, models...) + return o.modelCacheHandler.RegisterModelsWithSuffix(suffix, models...) } func (o *orm) Begin() (TxOrmer, error) { @@ -635,6 +635,8 @@ func newDBWithAlias(al *alias) Ormer { o.db = al.DB } + o.modelCacheHandler = NewModelCacheHandler() + if len(globalFilterChains) > 0 { return NewFilterOrmDecorator(o, globalFilterChains...) } diff --git a/pkg/client/orm/types.go b/pkg/client/orm/types.go index 584f0f8a..cee570af 100644 --- a/pkg/client/orm/types.go +++ b/pkg/client/orm/types.go @@ -214,7 +214,7 @@ type ormer interface { type Ormer interface { ormer - modelRegister + modelCacheHandler TxBeginner } From 3bf5cde38c840383b75ab7873fdb062aa2abe7ad Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 2 Sep 2020 20:36:53 +0800 Subject: [PATCH 139/207] adapt context --- pkg/adapter/context/acceptencoder.go | 45 +++++ pkg/adapter/context/context.go | 146 ++++++++++++++ pkg/adapter/context/input.go | 282 +++++++++++++++++++++++++++ pkg/adapter/context/output.go | 154 +++++++++++++++ pkg/adapter/context/renderer.go | 9 + pkg/adapter/context/response.go | 26 +++ 6 files changed, 662 insertions(+) create mode 100644 pkg/adapter/context/acceptencoder.go create mode 100644 pkg/adapter/context/context.go create mode 100644 pkg/adapter/context/input.go create mode 100644 pkg/adapter/context/output.go create mode 100644 pkg/adapter/context/renderer.go create mode 100644 pkg/adapter/context/response.go diff --git a/pkg/adapter/context/acceptencoder.go b/pkg/adapter/context/acceptencoder.go new file mode 100644 index 00000000..e578de45 --- /dev/null +++ b/pkg/adapter/context/acceptencoder.go @@ -0,0 +1,45 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "io" + "net/http" + "os" + + "github.com/astaxie/beego/pkg/server/web/context" +) + +// InitGzip init the gzipcompress +func InitGzip(minLength, compressLevel int, methods []string) { + context.InitGzip(minLength, compressLevel, methods) +} + +// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) +func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { + return context.WriteFile(encoding, writer, file) +} + +// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) +func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { + return context.WriteBody(encoding, writer, content) +} + +// ParseEncoding will extract the right encoding for response +// the Accept-Encoding's sec is here: +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3 +func ParseEncoding(r *http.Request) string { + return context.ParseEncoding(r) +} diff --git a/pkg/adapter/context/context.go b/pkg/adapter/context/context.go new file mode 100644 index 00000000..f9d8c624 --- /dev/null +++ b/pkg/adapter/context/context.go @@ -0,0 +1,146 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package context provide the context utils +// Usage: +// +// import "github.com/astaxie/beego/context" +// +// ctx := context.Context{Request:req,ResponseWriter:rw} +// +// more docs http://beego.me/docs/module/context.md +package context + +import ( + "bufio" + "net" + "net/http" + + "github.com/astaxie/beego/pkg/server/web/context" +) + +// commonly used mime-types +const ( + ApplicationJSON = context.ApplicationJSON + ApplicationXML = context.ApplicationXML + ApplicationYAML = context.ApplicationYAML + TextXML = context.TextXML +) + +// NewContext return the Context with Input and Output +func NewContext() *Context { + return (*Context)(context.NewContext()) +} + +// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. +// BeegoInput and BeegoOutput provides some api to operate request and response more easily. +type Context context.Context + +// Reset init Context, BeegoInput and BeegoOutput +func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { + (*context.Context)(ctx).Reset(rw, r) +} + +// Redirect does redirection to localurl with http header status code. +func (ctx *Context) Redirect(status int, localurl string) { + (*context.Context)(ctx).Redirect(status, localurl) +} + +// Abort stops this request. +// if beego.ErrorMaps exists, panic body. +func (ctx *Context) Abort(status int, body string) { + (*context.Context)(ctx).Abort(status, body) +} + +// WriteString Write string to response body. +// it sends response body. +func (ctx *Context) WriteString(content string) { + (*context.Context)(ctx).WriteString(content) +} + +// GetCookie Get cookie from request by a given key. +// It's alias of BeegoInput.Cookie. +func (ctx *Context) GetCookie(key string) string { + return (*context.Context)(ctx).GetCookie(key) +} + +// SetCookie Set cookie for response. +// It's alias of BeegoOutput.Cookie. +func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { + (*context.Context)(ctx).SetCookie(name, value, others) +} + +// GetSecureCookie Get secure cookie from request by a given key. +func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { + return (*context.Context)(ctx).GetSecureCookie(Secret, key) +} + +// SetSecureCookie Set Secure cookie for response. +func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { + (*context.Context)(ctx).SetSecureCookie(Secret, name, value, others) +} + +// XSRFToken creates a xsrf token string and returns. +func (ctx *Context) XSRFToken(key string, expire int64) string { + return (*context.Context)(ctx).XSRFToken(key, expire) +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (ctx *Context) CheckXSRFCookie() bool { + return (*context.Context)(ctx).CheckXSRFCookie() +} + +// RenderMethodResult renders the return value of a controller method to the output +func (ctx *Context) RenderMethodResult(result interface{}) { + (*context.Context)(ctx).RenderMethodResult(result) +} + +// Response is a wrapper for the http.ResponseWriter +// started set to true if response was written to then don't execute other handler +type Response context.Response + +// Write writes the data to the connection as part of an HTTP reply, +// and sets `started` to true. +// started means the response has sent out. +func (r *Response) Write(p []byte) (int, error) { + return (*context.Response)(r).Write(p) +} + +// WriteHeader sends an HTTP response header with status code, +// and sets `started` to true. +func (r *Response) WriteHeader(code int) { + (*context.Response)(r).WriteHeader(code) +} + +// Hijack hijacker for http +func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return (*context.Response)(r).Hijack() +} + +// Flush http.Flusher +func (r *Response) Flush() { + (*context.Response)(r).Flush() +} + +// CloseNotify http.CloseNotifier +func (r *Response) CloseNotify() <-chan bool { + return (*context.Response)(r).CloseNotify() +} + +// Pusher http.Pusher +func (r *Response) Pusher() (pusher http.Pusher) { + return (*context.Response)(r).Pusher() +} diff --git a/pkg/adapter/context/input.go b/pkg/adapter/context/input.go new file mode 100644 index 00000000..a1d08855 --- /dev/null +++ b/pkg/adapter/context/input.go @@ -0,0 +1,282 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "github.com/astaxie/beego/pkg/server/web/context" +) + +// BeegoInput operates the http request header, data, cookie and body. +// it also contains router params and current session. +type BeegoInput context.BeegoInput + +// NewInput return BeegoInput generated by Context. +func NewInput() *BeegoInput { + return (*BeegoInput)(context.NewInput()) +} + +// Reset init the BeegoInput +func (input *BeegoInput) Reset(ctx *Context) { + (*context.BeegoInput)(input).Reset((*context.Context)(ctx)) +} + +// Protocol returns request protocol name, such as HTTP/1.1 . +func (input *BeegoInput) Protocol() string { + return (*context.BeegoInput)(input).Protocol() +} + +// URI returns full request url with query string, fragment. +func (input *BeegoInput) URI() string { + return input.Context.Request.RequestURI +} + +// URL returns request url path (without query string, fragment). +func (input *BeegoInput) URL() string { + return (*context.BeegoInput)(input).URL() +} + +// Site returns base site url as scheme://domain type. +func (input *BeegoInput) Site() string { + return (*context.BeegoInput)(input).Site() +} + +// Scheme returns request scheme as "http" or "https". +func (input *BeegoInput) Scheme() string { + return (*context.BeegoInput)(input).Scheme() +} + +// Domain returns host name. +// Alias of Host method. +func (input *BeegoInput) Domain() string { + return (*context.BeegoInput)(input).Domain() +} + +// Host returns host name. +// if no host info in request, return localhost. +func (input *BeegoInput) Host() string { + return (*context.BeegoInput)(input).Host() +} + +// Method returns http request method. +func (input *BeegoInput) Method() string { + return (*context.BeegoInput)(input).Method() +} + +// Is returns boolean of this request is on given method, such as Is("POST"). +func (input *BeegoInput) Is(method string) bool { + return (*context.BeegoInput)(input).Is(method) +} + +// IsGet Is this a GET method request? +func (input *BeegoInput) IsGet() bool { + return (*context.BeegoInput)(input).IsGet() +} + +// IsPost Is this a POST method request? +func (input *BeegoInput) IsPost() bool { + return (*context.BeegoInput)(input).IsPost() +} + +// IsHead Is this a Head method request? +func (input *BeegoInput) IsHead() bool { + return (*context.BeegoInput)(input).IsHead() +} + +// IsOptions Is this a OPTIONS method request? +func (input *BeegoInput) IsOptions() bool { + return (*context.BeegoInput)(input).IsOptions() +} + +// IsPut Is this a PUT method request? +func (input *BeegoInput) IsPut() bool { + return (*context.BeegoInput)(input).IsPut() +} + +// IsDelete Is this a DELETE method request? +func (input *BeegoInput) IsDelete() bool { + return (*context.BeegoInput)(input).IsDelete() +} + +// IsPatch Is this a PATCH method request? +func (input *BeegoInput) IsPatch() bool { + return (*context.BeegoInput)(input).IsPatch() +} + +// IsAjax returns boolean of this request is generated by ajax. +func (input *BeegoInput) IsAjax() bool { + return (*context.BeegoInput)(input).IsAjax() +} + +// IsSecure returns boolean of this request is in https. +func (input *BeegoInput) IsSecure() bool { + return (*context.BeegoInput)(input).IsSecure() +} + +// IsWebsocket returns boolean of this request is in webSocket. +func (input *BeegoInput) IsWebsocket() bool { + return (*context.BeegoInput)(input).IsWebsocket() +} + +// IsUpload returns boolean of whether file uploads in this request or not.. +func (input *BeegoInput) IsUpload() bool { + return (*context.BeegoInput)(input).IsUpload() +} + +// AcceptsHTML Checks if request accepts html response +func (input *BeegoInput) AcceptsHTML() bool { + return (*context.BeegoInput)(input).AcceptsHTML() +} + +// AcceptsXML Checks if request accepts xml response +func (input *BeegoInput) AcceptsXML() bool { + return (*context.BeegoInput)(input).AcceptsXML() +} + +// AcceptsJSON Checks if request accepts json response +func (input *BeegoInput) AcceptsJSON() bool { + return (*context.BeegoInput)(input).AcceptsJSON() +} + +// AcceptsYAML Checks if request accepts json response +func (input *BeegoInput) AcceptsYAML() bool { + return (*context.BeegoInput)(input).AcceptsYAML() +} + +// IP returns request client ip. +// if in proxy, return first proxy id. +// if error, return RemoteAddr. +func (input *BeegoInput) IP() string { + return (*context.BeegoInput)(input).IP() +} + +// Proxy returns proxy client ips slice. +func (input *BeegoInput) Proxy() []string { + return (*context.BeegoInput)(input).Proxy() +} + +// Referer returns http referer header. +func (input *BeegoInput) Referer() string { + return (*context.BeegoInput)(input).Referer() +} + +// Refer returns http referer header. +func (input *BeegoInput) Refer() string { + return (*context.BeegoInput)(input).Refer() +} + +// SubDomains returns sub domain string. +// if aa.bb.domain.com, returns aa.bb . +func (input *BeegoInput) SubDomains() string { + return (*context.BeegoInput)(input).SubDomains() +} + +// Port returns request client port. +// when error or empty, return 80. +func (input *BeegoInput) Port() int { + return (*context.BeegoInput)(input).Port() +} + +// UserAgent returns request client user agent string. +func (input *BeegoInput) UserAgent() string { + return (*context.BeegoInput)(input).UserAgent() +} + +// ParamsLen return the length of the params +func (input *BeegoInput) ParamsLen() int { + return (*context.BeegoInput)(input).ParamsLen() +} + +// Param returns router param by a given key. +func (input *BeegoInput) Param(key string) string { + return (*context.BeegoInput)(input).Param(key) +} + +// Params returns the map[key]value. +func (input *BeegoInput) Params() map[string]string { + return (*context.BeegoInput)(input).Params() +} + +// SetParam will set the param with key and value +func (input *BeegoInput) SetParam(key, val string) { + (*context.BeegoInput)(input).SetParam(key, val) +} + +// ResetParams clears any of the input's Params +// This function is used to clear parameters so they may be reset between filter +// passes. +func (input *BeegoInput) ResetParams() { + (*context.BeegoInput)(input).ResetParams() +} + +// Query returns input data item string by a given string. +func (input *BeegoInput) Query(key string) string { + return (*context.BeegoInput)(input).Query(key) +} + +// Header returns request header item string by a given string. +// if non-existed, return empty string. +func (input *BeegoInput) Header(key string) string { + return (*context.BeegoInput)(input).Header(key) +} + +// Cookie returns request cookie item string by a given key. +// if non-existed, return empty string. +func (input *BeegoInput) Cookie(key string) string { + return (*context.BeegoInput)(input).Cookie(key) +} + +// Session returns current session item value by a given key. +// if non-existed, return nil. +func (input *BeegoInput) Session(key interface{}) interface{} { + return (*context.BeegoInput)(input).Session(key) +} + +// CopyBody returns the raw request body data as bytes. +func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { + return (*context.BeegoInput)(input).CopyBody(MaxMemory) +} + +// Data return the implicit data in the input +func (input *BeegoInput) Data() map[interface{}]interface{} { + return (*context.BeegoInput)(input).Data() +} + +// GetData returns the stored data in this context. +func (input *BeegoInput) GetData(key interface{}) interface{} { + return (*context.BeegoInput)(input).GetData(key) +} + +// SetData stores data with given key in this context. +// This data are only available in this context. +func (input *BeegoInput) SetData(key, val interface{}) { + (*context.BeegoInput)(input).SetData(key, val) +} + +// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type +func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { + return (*context.BeegoInput)(input).ParseFormOrMulitForm(maxMemory) +} + +// Bind data from request.Form[key] to dest +// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie +// var id int beegoInput.Bind(&id, "id") id ==123 +// var isok bool beegoInput.Bind(&isok, "isok") isok ==true +// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2 +// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2] +// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array] +// user struct{Name} beegoInput.Bind(&user, "user") user == {Name:"astaxie"} +func (input *BeegoInput) Bind(dest interface{}, key string) error { + return (*context.BeegoInput)(input).Bind(dest, key) +} diff --git a/pkg/adapter/context/output.go b/pkg/adapter/context/output.go new file mode 100644 index 00000000..8e2a7f7d --- /dev/null +++ b/pkg/adapter/context/output.go @@ -0,0 +1,154 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "github.com/astaxie/beego/pkg/server/web/context" +) + +// BeegoOutput does work for sending response header. +type BeegoOutput context.BeegoOutput + +// NewOutput returns new BeegoOutput. +// it contains nothing now. +func NewOutput() *BeegoOutput { + return (*BeegoOutput)(context.NewOutput()) +} + +// Reset init BeegoOutput +func (output *BeegoOutput) Reset(ctx *Context) { + (*context.BeegoOutput)(output).Reset((*context.Context)(ctx)) +} + +// Header sets response header item string via given key. +func (output *BeegoOutput) Header(key, val string) { + (*context.BeegoOutput)(output).Header(key, val) +} + +// Body sets response body content. +// if EnableGzip, compress content string. +// it sends out response body directly. +func (output *BeegoOutput) Body(content []byte) error { + return (*context.BeegoOutput)(output).Body(content) +} + +// Cookie sets cookie value via given key. +// others are ordered as cookie's max age time, path,domain, secure and httponly. +func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) { + (*context.BeegoOutput)(output).Cookie(name, value, others) +} + +// JSON writes json to response body. +// if encoding is true, it converts utf-8 to \u0000 type. +func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error { + return (*context.BeegoOutput)(output).JSON(data, hasIndent, encoding) +} + +// YAML writes yaml to response body. +func (output *BeegoOutput) YAML(data interface{}) error { + return (*context.BeegoOutput)(output).YAML(data) +} + +// JSONP writes jsonp to response body. +func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { + return (*context.BeegoOutput)(output).JSONP(data, hasIndent) +} + +// XML writes xml string to response body. +func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { + return (*context.BeegoOutput)(output).XML(data, hasIndent) +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) { + (*context.BeegoOutput)(output).ServeFormatted(data, hasIndent, hasEncode...) +} + +// Download forces response for download file. +// it prepares the download response header automatically. +func (output *BeegoOutput) Download(file string, filename ...string) { + (*context.BeegoOutput)(output).Download(file, filename...) +} + +// ContentType sets the content type from ext string. +// MIME type is given in mime package. +func (output *BeegoOutput) ContentType(ext string) { + (*context.BeegoOutput)(output).ContentType(ext) +} + +// SetStatus sets response status code. +// It writes response header directly. +func (output *BeegoOutput) SetStatus(status int) { + (*context.BeegoOutput)(output).SetStatus(status) +} + +// IsCachable returns boolean of this request is cached. +// HTTP 304 means cached. +func (output *BeegoOutput) IsCachable() bool { + return (*context.BeegoOutput)(output).IsCachable() +} + +// IsEmpty returns boolean of this request is empty. +// HTTP 201,204 and 304 means empty. +func (output *BeegoOutput) IsEmpty() bool { + return (*context.BeegoOutput)(output).IsEmpty() +} + +// IsOk returns boolean of this request runs well. +// HTTP 200 means ok. +func (output *BeegoOutput) IsOk() bool { + return (*context.BeegoOutput)(output).IsOk() +} + +// IsSuccessful returns boolean of this request runs successfully. +// HTTP 2xx means ok. +func (output *BeegoOutput) IsSuccessful() bool { + return (*context.BeegoOutput)(output).IsSuccessful() +} + +// IsRedirect returns boolean of this request is redirection header. +// HTTP 301,302,307 means redirection. +func (output *BeegoOutput) IsRedirect() bool { + return (*context.BeegoOutput)(output).IsRedirect() +} + +// IsForbidden returns boolean of this request is forbidden. +// HTTP 403 means forbidden. +func (output *BeegoOutput) IsForbidden() bool { + return (*context.BeegoOutput)(output).IsForbidden() +} + +// IsNotFound returns boolean of this request is not found. +// HTTP 404 means not found. +func (output *BeegoOutput) IsNotFound() bool { + return (*context.BeegoOutput)(output).IsNotFound() +} + +// IsClientError returns boolean of this request client sends error data. +// HTTP 4xx means client error. +func (output *BeegoOutput) IsClientError() bool { + return (*context.BeegoOutput)(output).IsClientError() +} + +// IsServerError returns boolean of this server handler errors. +// HTTP 5xx means server internal error. +func (output *BeegoOutput) IsServerError() bool { + return (*context.BeegoOutput)(output).IsServerError() +} + +// Session sets session item value with given key. +func (output *BeegoOutput) Session(name interface{}, value interface{}) { + (*context.BeegoOutput)(output).Session(name, value) +} diff --git a/pkg/adapter/context/renderer.go b/pkg/adapter/context/renderer.go new file mode 100644 index 00000000..7e352007 --- /dev/null +++ b/pkg/adapter/context/renderer.go @@ -0,0 +1,9 @@ +package context + +import ( + "github.com/astaxie/beego/pkg/server/web/context" +) + +// Renderer defines an http response renderer +type Renderer context.Renderer + diff --git a/pkg/adapter/context/response.go b/pkg/adapter/context/response.go new file mode 100644 index 00000000..24e196a4 --- /dev/null +++ b/pkg/adapter/context/response.go @@ -0,0 +1,26 @@ +package context + +import ( + "net/http" + "strconv" +) + +const ( + // BadRequest indicates http error 400 + BadRequest StatusCode = http.StatusBadRequest + + // NotFound indicates http error 404 + NotFound StatusCode = http.StatusNotFound +) + +// StatusCode sets the http response status code +type StatusCode int + +func (s StatusCode) Error() string { + return strconv.Itoa(int(s)) +} + +// Render sets the http status code +func (s StatusCode) Render(ctx *Context) { + ctx.Output.SetStatus(int(s)) +} From 8fc4f8847c4f9d887605ab4c56461a2feb0549de Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 2 Sep 2020 20:43:35 +0800 Subject: [PATCH 140/207] adapt grace and metric --- pkg/adapter/context/renderer.go | 1 - pkg/adapter/grace/grace.go | 96 ++++++++++++++++++++++++++ pkg/adapter/grace/server.go | 48 +++++++++++++ pkg/adapter/metric/prometheus.go | 99 +++++++++++++++++++++++++++ pkg/adapter/metric/prometheus_test.go | 42 ++++++++++++ 5 files changed, 285 insertions(+), 1 deletion(-) create mode 100644 pkg/adapter/grace/grace.go create mode 100644 pkg/adapter/grace/server.go create mode 100644 pkg/adapter/metric/prometheus.go create mode 100644 pkg/adapter/metric/prometheus_test.go diff --git a/pkg/adapter/context/renderer.go b/pkg/adapter/context/renderer.go index 7e352007..763fb9c4 100644 --- a/pkg/adapter/context/renderer.go +++ b/pkg/adapter/context/renderer.go @@ -6,4 +6,3 @@ import ( // Renderer defines an http response renderer type Renderer context.Renderer - diff --git a/pkg/adapter/grace/grace.go b/pkg/adapter/grace/grace.go new file mode 100644 index 00000000..67cd4a1e --- /dev/null +++ b/pkg/adapter/grace/grace.go @@ -0,0 +1,96 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package grace use to hot reload +// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ +// +// Usage: +// +// import( +// "log" +// "net/http" +// "os" +// +// "github.com/astaxie/beego/grace" +// ) +// +// func handler(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("WORLD!")) +// } +// +// func main() { +// mux := http.NewServeMux() +// mux.HandleFunc("/hello", handler) +// +// err := grace.ListenAndServe("localhost:8080", mux) +// if err != nil { +// log.Println(err) +// } +// log.Println("Server on 8080 stopped") +// os.Exit(0) +// } +package grace + +import ( + "net/http" + "time" + + "github.com/astaxie/beego/pkg/server/web/grace" +) + +const ( + // PreSignal is the position to add filter before signal + PreSignal = iota + // PostSignal is the position to add filter after signal + PostSignal + // StateInit represent the application inited + StateInit + // StateRunning represent the application is running + StateRunning + // StateShuttingDown represent the application is shutting down + StateShuttingDown + // StateTerminate represent the application is killed + StateTerminate +) + +var ( + + + // DefaultReadTimeOut is the HTTP read timeout + DefaultReadTimeOut time.Duration + // DefaultWriteTimeOut is the HTTP Write timeout + DefaultWriteTimeOut time.Duration + // DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit + DefaultMaxHeaderBytes int + // DefaultTimeout is the shutdown server's timeout. default is 60s + DefaultTimeout = grace.DefaultTimeout + +) + +// NewServer returns a new graceServer. +func NewServer(addr string, handler http.Handler) (srv *Server) { + return (*Server)(grace.NewServer(addr, handler)) +} + +// ListenAndServe refer http.ListenAndServe +func ListenAndServe(addr string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServe() +} + +// ListenAndServeTLS refer http.ListenAndServeTLS +func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServeTLS(certFile, keyFile) +} diff --git a/pkg/adapter/grace/server.go b/pkg/adapter/grace/server.go new file mode 100644 index 00000000..31c13f18 --- /dev/null +++ b/pkg/adapter/grace/server.go @@ -0,0 +1,48 @@ +package grace + +import ( + "os" + + "github.com/astaxie/beego/pkg/server/web/grace" +) + +// Server embedded http.Server +type Server grace.Server + +// Serve accepts incoming connections on the Listener l, +// creating a new service goroutine for each. +// The service goroutines read requests and then call srv.Handler to reply to them. +func (srv *Server) Serve() (err error) { + return (*grace.Server)(srv).Serve() +} + +// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve +// to handle requests on incoming connections. If srv.Addr is blank, ":http" is +// used. +func (srv *Server) ListenAndServe() (err error) { + return (*grace.Server)(srv).ListenAndServe() +} + +// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming TLS connections. +// +// Filenames containing a certificate and matching private key for the server must +// be provided. If the certificate is signed by a certificate authority, the +// certFile should be the concatenation of the server's certificate followed by the +// CA's certificate. +// +// If srv.Addr is blank, ":https" is used. +func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { + return (*grace.Server)(srv).ListenAndServeTLS(certFile, keyFile) +} + +// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming mutual TLS connections. +func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) error { + return (*grace.Server)(srv).ListenAndServeMutualTLS(certFile, keyFile, trustFile) +} + +// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. +func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) error { + return (*grace.Server)(srv).RegisterSignalHook(ppFlag, sig, f) +} diff --git a/pkg/adapter/metric/prometheus.go b/pkg/adapter/metric/prometheus.go new file mode 100644 index 00000000..1d3488c6 --- /dev/null +++ b/pkg/adapter/metric/prometheus.go @@ -0,0 +1,99 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "net/http" + "reflect" + "strconv" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/server/web" +) + +func PrometheusMiddleWare(next http.Handler) http.Handler { + summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "beego", + Subsystem: "http_request", + ConstLabels: map[string]string{ + "server": web.BConfig.ServerName, + "env": web.BConfig.RunMode, + "appname": web.BConfig.AppName, + }, + Help: "The statics info for http request", + }, []string{"pattern", "method", "status", "duration"}) + + prometheus.MustRegister(summaryVec) + + registerBuildInfo() + + return http.HandlerFunc(func(writer http.ResponseWriter, q *http.Request) { + start := time.Now() + next.ServeHTTP(writer, q) + end := time.Now() + go report(end.Sub(start), writer, q, summaryVec) + }) +} + +func registerBuildInfo() { + buildInfo := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "beego", + Subsystem: "build_info", + Help: "The building information", + ConstLabels: map[string]string{ + "appname": web.BConfig.AppName, + "build_version": web.BuildVersion, + "build_revision": web.BuildGitRevision, + "build_status": web.BuildStatus, + "build_tag": web.BuildTag, + "build_time": strings.Replace(web.BuildTime, "--", " ", 1), + "go_version": web.GoVersion, + "git_branch": web.GitBranch, + "start_time": time.Now().Format("2006-01-02 15:04:05"), + }, + }, []string{}) + + prometheus.MustRegister(buildInfo) + buildInfo.WithLabelValues().Set(1) +} + +func report(dur time.Duration, writer http.ResponseWriter, q *http.Request, vec *prometheus.SummaryVec) { + ctrl := web.BeeApp.Handlers + ctx := ctrl.GetContext() + ctx.Reset(writer, q) + defer ctrl.GiveBackContext(ctx) + + // We cannot read the status code from q.Response.StatusCode + // since the http server does not set q.Response. So q.Response is nil + // Thus, we use reflection to read the status from writer whose concrete type is http.response + responseVal := reflect.ValueOf(writer).Elem() + field := responseVal.FieldByName("status") + status := -1 + if field.IsValid() && field.Kind() == reflect.Int { + status = int(field.Int()) + } + ptn := "UNKNOWN" + if rt, found := ctrl.FindRouter(ctx); found { + ptn = rt.GetPattern() + } else { + logs.Warn("we can not find the router info for this request, so request will be recorded as UNKNOWN: " + q.URL.String()) + } + ms := dur / time.Millisecond + vec.WithLabelValues(ptn, q.Method, strconv.Itoa(status), strconv.Itoa(int(ms))).Observe(float64(ms)) +} diff --git a/pkg/adapter/metric/prometheus_test.go b/pkg/adapter/metric/prometheus_test.go new file mode 100644 index 00000000..d82a6dec --- /dev/null +++ b/pkg/adapter/metric/prometheus_test.go @@ -0,0 +1,42 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "net/http" + "net/url" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/astaxie/beego/context" +) + +func TestPrometheusMiddleWare(t *testing.T) { + middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + writer := &context.Response{} + request := &http.Request{ + URL: &url.URL{ + Host: "localhost", + RawPath: "/a/b/c", + }, + Method: "POST", + } + vec := prometheus.NewSummaryVec(prometheus.SummaryOpts{}, []string{"pattern", "method", "status", "duration"}) + + report(time.Second, writer, request, vec) + middleware.ServeHTTP(writer, request) +} From bdd8df675135f0c3b716130cbdf363c0ddf79567 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 2 Sep 2020 21:01:54 +0800 Subject: [PATCH 141/207] adapt migration --- pkg/adapter/grace/grace.go | 2 - pkg/adapter/migration/ddl.go | 198 +++++++++++++++++++++++++++++ pkg/adapter/migration/doc.go | 32 +++++ pkg/adapter/migration/migration.go | 111 ++++++++++++++++ pkg/client/orm/migration/ddl.go | 52 ++++---- 5 files changed, 367 insertions(+), 28 deletions(-) create mode 100644 pkg/adapter/migration/ddl.go create mode 100644 pkg/adapter/migration/doc.go create mode 100644 pkg/adapter/migration/migration.go diff --git a/pkg/adapter/grace/grace.go b/pkg/adapter/grace/grace.go index 67cd4a1e..3775e395 100644 --- a/pkg/adapter/grace/grace.go +++ b/pkg/adapter/grace/grace.go @@ -66,7 +66,6 @@ const ( var ( - // DefaultReadTimeOut is the HTTP read timeout DefaultReadTimeOut time.Duration // DefaultWriteTimeOut is the HTTP Write timeout @@ -75,7 +74,6 @@ var ( DefaultMaxHeaderBytes int // DefaultTimeout is the shutdown server's timeout. default is 60s DefaultTimeout = grace.DefaultTimeout - ) // NewServer returns a new graceServer. diff --git a/pkg/adapter/migration/ddl.go b/pkg/adapter/migration/ddl.go new file mode 100644 index 00000000..97e45dec --- /dev/null +++ b/pkg/adapter/migration/ddl.go @@ -0,0 +1,198 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package migration + +import ( + "github.com/astaxie/beego/pkg/client/orm/migration" +) + +// Index struct defines the structure of Index Columns +type Index migration.Index + +// Unique struct defines a single unique key combination +type Unique migration.Unique + +// Column struct defines a single column of a table +type Column migration.Column + +// Foreign struct defines a single foreign relationship +type Foreign migration.Foreign + +// RenameColumn struct allows renaming of columns +type RenameColumn migration.RenameColumn + +// CreateTable creates the table on system +func (m *Migration) CreateTable(tablename, engine, charset string, p ...func()) { + (*migration.Migration)(m).CreateTable(tablename, engine, charset, p...) +} + +// AlterTable set the ModifyType to alter +func (m *Migration) AlterTable(tablename string) { + (*migration.Migration)(m).AlterTable(tablename) +} + +// NewCol creates a new standard column and attaches it to m struct +func (m *Migration) NewCol(name string) *Column { + return (*Column)((*migration.Migration)(m).NewCol(name)) +} + +// PriCol creates a new primary column and attaches it to m struct +func (m *Migration) PriCol(name string) *Column { + return (*Column)((*migration.Migration)(m).PriCol(name)) +} + +// UniCol creates / appends columns to specified unique key and attaches it to m struct +func (m *Migration) UniCol(uni, name string) *Column { + return (*Column)((*migration.Migration)(m).UniCol(uni, name)) +} + +// ForeignCol creates a new foreign column and returns the instance of column +func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { + return (*Foreign)((*migration.Migration)(m).ForeignCol(colname, foreigncol, foreigntable)) +} + +// SetOnDelete sets the on delete of foreign +func (foreign *Foreign) SetOnDelete(del string) *Foreign { + (*migration.Foreign)(foreign).SetOnDelete(del) + return foreign +} + +// SetOnUpdate sets the on update of foreign +func (foreign *Foreign) SetOnUpdate(update string) *Foreign { + (*migration.Foreign)(foreign).SetOnUpdate(update) + return foreign +} + +// Remove marks the columns to be removed. +// it allows reverse m to create the column. +func (c *Column) Remove() { + (*migration.Column)(c).Remove() +} + +// SetAuto enables auto_increment of column (can be used once) +func (c *Column) SetAuto(inc bool) *Column { + (*migration.Column)(c).SetAuto(inc) + return c +} + +// SetNullable sets the column to be null +func (c *Column) SetNullable(null bool) *Column { + (*migration.Column)(c).SetNullable(null) + return c +} + +// SetDefault sets the default value, prepend with "DEFAULT " +func (c *Column) SetDefault(def string) *Column { + (*migration.Column)(c).SetDefault(def) + return c +} + +// SetUnsigned sets the column to be unsigned int +func (c *Column) SetUnsigned(unsign bool) *Column { + (*migration.Column)(c).SetUnsigned(unsign) + return c +} + +// SetDataType sets the dataType of the column +func (c *Column) SetDataType(dataType string) *Column { + (*migration.Column)(c).SetDataType(dataType) + return c +} + +// SetOldNullable allows reverting to previous nullable on reverse ms +func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { + (*migration.RenameColumn)(c).SetOldNullable(null) + return c +} + +// SetOldDefault allows reverting to previous default on reverse ms +func (c *RenameColumn) SetOldDefault(def string) *RenameColumn { + (*migration.RenameColumn)(c).SetOldDefault(def) + return c +} + +// SetOldUnsigned allows reverting to previous unsgined on reverse ms +func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { + (*migration.RenameColumn)(c).SetOldUnsigned(unsign) + return c +} + +// SetOldDataType allows reverting to previous datatype on reverse ms +func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn { + (*migration.RenameColumn)(c).SetOldDataType(dataType) + return c +} + +// SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) +func (c *Column) SetPrimary(m *Migration) *Column { + (*migration.Column)(c).SetPrimary((*migration.Migration)(m)) + return c +} + +// AddColumnsToUnique adds the columns to Unique Struct +func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { + cls := toNewColumnsArray(columns) + (*migration.Unique)(unique).AddColumnsToUnique(cls...) + return unique +} + +// AddColumns adds columns to m struct +func (m *Migration) AddColumns(columns ...*Column) *Migration { + cls := toNewColumnsArray(columns) + (*migration.Migration)(m).AddColumns(cls...) + return m +} + +func toNewColumnsArray(columns []*Column) []*migration.Column { + cls := make([]*migration.Column, 0, len(columns)) + for _, c := range columns { + cls = append(cls, (*migration.Column)(c)) + } + return cls +} + +// AddPrimary adds the column to primary in m struct +func (m *Migration) AddPrimary(primary *Column) *Migration { + (*migration.Migration)(m).AddPrimary((*migration.Column)(primary)) + return m +} + +// AddUnique adds the column to unique in m struct +func (m *Migration) AddUnique(unique *Unique) *Migration { + (*migration.Migration)(m).AddUnique((*migration.Unique)(unique)) + return m +} + +// AddForeign adds the column to foreign in m struct +func (m *Migration) AddForeign(foreign *Foreign) *Migration { + (*migration.Migration)(m).AddForeign((*migration.Foreign)(foreign)) + return m +} + +// AddIndex adds the column to index in m struct +func (m *Migration) AddIndex(index *Index) *Migration { + (*migration.Migration)(m).AddIndex((*migration.Index)(index)) + return m +} + +// RenameColumn allows renaming of columns +func (m *Migration) RenameColumn(from, to string) *RenameColumn { + return (*RenameColumn)((*migration.Migration)(m).RenameColumn(from, to)) +} + +// GetSQL returns the generated sql depending on ModifyType +func (m *Migration) GetSQL() (sql string) { + return (*migration.Migration)(m).GetSQL() +} diff --git a/pkg/adapter/migration/doc.go b/pkg/adapter/migration/doc.go new file mode 100644 index 00000000..0c6564d4 --- /dev/null +++ b/pkg/adapter/migration/doc.go @@ -0,0 +1,32 @@ +// Package migration enables you to generate migrations back and forth. It generates both migrations. +// +// //Creates a table +// m.CreateTable("tablename","InnoDB","utf8"); +// +// //Alter a table +// m.AlterTable("tablename") +// +// Standard Column Methods +// * SetDataType +// * SetNullable +// * SetDefault +// * SetUnsigned (use only on integer types unless produces error) +// +// //Sets a primary column, multiple calls allowed, standard column methods available +// m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true) +// +// //UniCol Can be used multiple times, allows standard Column methods. Use same "index" string to add to same index +// m.UniCol("index","column") +// +// //Standard Column Initialisation, can call .Remove() after NewCol("") on alter to remove +// m.NewCol("name").SetDataType("VARCHAR(255) COLLATE utf8_unicode_ci").SetNullable(false) +// m.NewCol("value").SetDataType("DOUBLE(8,2)").SetNullable(false) +// +// //Rename Columns , only use with Alter table, doesn't works with Create, prefix standard column methods with "Old" to +// //create a true reversible migration eg: SetOldDataType("DOUBLE(12,3)") +// m.RenameColumn("from","to")... +// +// //Foreign Columns, single columns are only supported, SetOnDelete & SetOnUpdate are available, call appropriately. +// //Supports standard column methods, automatic reverse. +// m.ForeignCol("local_col","foreign_col","foreign_table") +package migration diff --git a/pkg/adapter/migration/migration.go b/pkg/adapter/migration/migration.go new file mode 100644 index 00000000..4ee22e5a --- /dev/null +++ b/pkg/adapter/migration/migration.go @@ -0,0 +1,111 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package migration is used for migration +// +// The table structure is as follow: +// +// CREATE TABLE `migrations` ( +// `id_migration` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'surrogate key', +// `name` varchar(255) DEFAULT NULL COMMENT 'migration name, unique', +// `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'date migrated or rolled back', +// `statements` longtext COMMENT 'SQL statements for this migration', +// `rollback_statements` longtext, +// `status` enum('update','rollback') DEFAULT NULL COMMENT 'update indicates it is a normal migration while rollback means this migration is rolled back', +// PRIMARY KEY (`id_migration`) +// ) ENGINE=InnoDB DEFAULT CHARSET=utf8; +package migration + +import ( + "github.com/astaxie/beego/pkg/client/orm/migration" +) + +// const the data format for the bee generate migration datatype +const ( + DateFormat = "20060102_150405" + DBDateFormat = "2006-01-02 15:04:05" +) + +// Migrationer is an interface for all Migration struct +type Migrationer interface { + Up() + Down() + Reset() + Exec(name, status string) error + GetCreated() int64 +} + +// Migration defines the migrations by either SQL or DDL +type Migration migration.Migration + +// Up implement in the Inheritance struct for upgrade +func (m *Migration) Up() { + (*migration.Migration)(m).Up() +} + +// Down implement in the Inheritance struct for down +func (m *Migration) Down() { + (*migration.Migration)(m).Down() +} + +// Migrate adds the SQL to the execution list +func (m *Migration) Migrate(migrationType string) { + (*migration.Migration)(m).Migrate(migrationType) +} + +// SQL add sql want to execute +func (m *Migration) SQL(sql string) { + (*migration.Migration)(m).SQL(sql) +} + +// Reset the sqls +func (m *Migration) Reset() { + (*migration.Migration)(m).Reset() +} + +// Exec execute the sql already add in the sql +func (m *Migration) Exec(name, status string) error { + return (*migration.Migration)(m).Exec(name, status) +} + +// GetCreated get the unixtime from the Created +func (m *Migration) GetCreated() int64 { + return (*migration.Migration)(m).GetCreated() +} + +// Register register the Migration in the map +func Register(name string, m Migrationer) error { + return migration.Register(name, m) +} + +// Upgrade upgrade the migration from lasttime +func Upgrade(lasttime int64) error { + return migration.Upgrade(lasttime) +} + +// Rollback rollback the migration by the name +func Rollback(name string) error { + return migration.Rollback(name) +} + +// Reset reset all migration +// run all migration's down function +func Reset() error { + return migration.Reset() +} + +// Refresh first Reset, then Upgrade +func Refresh() error { + return migration.Refresh() +} diff --git a/pkg/client/orm/migration/ddl.go b/pkg/client/orm/migration/ddl.go index c21352a8..e8b13212 100644 --- a/pkg/client/orm/migration/ddl.go +++ b/pkg/client/orm/migration/ddl.go @@ -31,7 +31,7 @@ type Unique struct { Columns []*Column } -//Column struct defines a single column of a table +// Column struct defines a single column of a table type Column struct { Name string Inc string @@ -84,7 +84,7 @@ func (m *Migration) NewCol(name string) *Column { return col } -//PriCol creates a new primary column and attaches it to m struct +// PriCol creates a new primary column and attaches it to m struct func (m *Migration) PriCol(name string) *Column { col := &Column{Name: name} m.AddColumns(col) @@ -92,7 +92,7 @@ func (m *Migration) PriCol(name string) *Column { return col } -//UniCol creates / appends columns to specified unique key and attaches it to m struct +// UniCol creates / appends columns to specified unique key and attaches it to m struct func (m *Migration) UniCol(uni, name string) *Column { col := &Column{Name: name} m.AddColumns(col) @@ -114,7 +114,7 @@ func (m *Migration) UniCol(uni, name string) *Column { return col } -//ForeignCol creates a new foreign column and returns the instance of column +// ForeignCol creates a new foreign column and returns the instance of column func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable} @@ -123,25 +123,25 @@ func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreig return foreign } -//SetOnDelete sets the on delete of foreign +// SetOnDelete sets the on delete of foreign func (foreign *Foreign) SetOnDelete(del string) *Foreign { foreign.OnDelete = "ON DELETE" + del return foreign } -//SetOnUpdate sets the on update of foreign +// SetOnUpdate sets the on update of foreign func (foreign *Foreign) SetOnUpdate(update string) *Foreign { foreign.OnUpdate = "ON UPDATE" + update return foreign } -//Remove marks the columns to be removed. -//it allows reverse m to create the column. +// Remove marks the columns to be removed. +// it allows reverse m to create the column. func (c *Column) Remove() { c.remove = true } -//SetAuto enables auto_increment of column (can be used once) +// SetAuto enables auto_increment of column (can be used once) func (c *Column) SetAuto(inc bool) *Column { if inc { c.Inc = "auto_increment" @@ -149,7 +149,7 @@ func (c *Column) SetAuto(inc bool) *Column { return c } -//SetNullable sets the column to be null +// SetNullable sets the column to be null func (c *Column) SetNullable(null bool) *Column { if null { c.Null = "" @@ -160,13 +160,13 @@ func (c *Column) SetNullable(null bool) *Column { return c } -//SetDefault sets the default value, prepend with "DEFAULT " +// SetDefault sets the default value, prepend with "DEFAULT " func (c *Column) SetDefault(def string) *Column { c.Default = "DEFAULT " + def return c } -//SetUnsigned sets the column to be unsigned int +// SetUnsigned sets the column to be unsigned int func (c *Column) SetUnsigned(unsign bool) *Column { if unsign { c.Unsign = "UNSIGNED" @@ -174,13 +174,13 @@ func (c *Column) SetUnsigned(unsign bool) *Column { return c } -//SetDataType sets the dataType of the column +// SetDataType sets the dataType of the column func (c *Column) SetDataType(dataType string) *Column { c.DataType = dataType return c } -//SetOldNullable allows reverting to previous nullable on reverse ms +// SetOldNullable allows reverting to previous nullable on reverse ms func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { if null { c.OldNull = "" @@ -191,13 +191,13 @@ func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { return c } -//SetOldDefault allows reverting to previous default on reverse ms +// SetOldDefault allows reverting to previous default on reverse ms func (c *RenameColumn) SetOldDefault(def string) *RenameColumn { c.OldDefault = def return c } -//SetOldUnsigned allows reverting to previous unsgined on reverse ms +// SetOldUnsigned allows reverting to previous unsgined on reverse ms func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { if unsign { c.OldUnsign = "UNSIGNED" @@ -205,19 +205,19 @@ func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { return c } -//SetOldDataType allows reverting to previous datatype on reverse ms +// SetOldDataType allows reverting to previous datatype on reverse ms func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn { c.OldDataType = dataType return c } -//SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) +// SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) func (c *Column) SetPrimary(m *Migration) *Column { m.Primary = append(m.Primary, c) return c } -//AddColumnsToUnique adds the columns to Unique Struct +// AddColumnsToUnique adds the columns to Unique Struct func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { unique.Columns = append(unique.Columns, columns...) @@ -225,7 +225,7 @@ func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { return unique } -//AddColumns adds columns to m struct +// AddColumns adds columns to m struct func (m *Migration) AddColumns(columns ...*Column) *Migration { m.Columns = append(m.Columns, columns...) @@ -233,38 +233,38 @@ func (m *Migration) AddColumns(columns ...*Column) *Migration { return m } -//AddPrimary adds the column to primary in m struct +// AddPrimary adds the column to primary in m struct func (m *Migration) AddPrimary(primary *Column) *Migration { m.Primary = append(m.Primary, primary) return m } -//AddUnique adds the column to unique in m struct +// AddUnique adds the column to unique in m struct func (m *Migration) AddUnique(unique *Unique) *Migration { m.Uniques = append(m.Uniques, unique) return m } -//AddForeign adds the column to foreign in m struct +// AddForeign adds the column to foreign in m struct func (m *Migration) AddForeign(foreign *Foreign) *Migration { m.Foreigns = append(m.Foreigns, foreign) return m } -//AddIndex adds the column to index in m struct +// AddIndex adds the column to index in m struct func (m *Migration) AddIndex(index *Index) *Migration { m.Indexes = append(m.Indexes, index) return m } -//RenameColumn allows renaming of columns +// RenameColumn allows renaming of columns func (m *Migration) RenameColumn(from, to string) *RenameColumn { rename := &RenameColumn{OldName: from, NewName: to} m.Renames = append(m.Renames, rename) return rename } -//GetSQL returns the generated sql depending on ModifyType +// GetSQL returns the generated sql depending on ModifyType func (m *Migration) GetSQL() (sql string) { sql = "" switch m.ModifyType { From cbd51616f17361706060c8e7d1dab4265e519d8c Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 2 Sep 2020 23:23:48 +0800 Subject: [PATCH 142/207] adapter: validation module --- pkg/adapter/validation/util.go | 62 +++ pkg/adapter/validation/validation.go | 274 ++++++++++ pkg/adapter/validation/validation_test.go | 609 ++++++++++++++++++++++ pkg/adapter/validation/validators.go | 512 ++++++++++++++++++ 4 files changed, 1457 insertions(+) create mode 100644 pkg/adapter/validation/util.go create mode 100644 pkg/adapter/validation/validation.go create mode 100644 pkg/adapter/validation/validation_test.go create mode 100644 pkg/adapter/validation/validators.go diff --git a/pkg/adapter/validation/util.go b/pkg/adapter/validation/util.go new file mode 100644 index 00000000..729712e0 --- /dev/null +++ b/pkg/adapter/validation/util.go @@ -0,0 +1,62 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "reflect" + + "github.com/astaxie/beego/pkg/infrastructure/validation" +) + +const ( + // ValidTag struct tag + ValidTag = validation.ValidTag + + LabelTag = validation.LabelTag +) + +var ( + ErrInt64On32 = validation.ErrInt64On32 +) + +// CustomFunc is for custom validate function +type CustomFunc func(v *Validation, obj interface{}, key string) + +// AddCustomFunc Add a custom function to validation +// The name can not be: +// Clear +// HasErrors +// ErrorMap +// Error +// Check +// Valid +// NoMatch +// If the name is same with exists function, it will replace the origin valid function +func AddCustomFunc(name string, f CustomFunc) error { + return validation.AddCustomFunc(name, func(v *validation.Validation, obj interface{}, key string) { + f((*Validation)(v), obj, key) + }) +} + +// ValidFunc Valid function type +type ValidFunc validation.ValidFunc + +// Funcs Validate function map +type Funcs validation.Funcs + +// Call validate values with named type string +func (f Funcs) Call(name string, params ...interface{}) (result []reflect.Value, err error) { + return (validation.Funcs(f)).Call(name, params...) +} diff --git a/pkg/adapter/validation/validation.go b/pkg/adapter/validation/validation.go new file mode 100644 index 00000000..1cdb8dda --- /dev/null +++ b/pkg/adapter/validation/validation.go @@ -0,0 +1,274 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package validation for validations +// +// import ( +// "github.com/astaxie/beego/validation" +// "log" +// ) +// +// type User struct { +// Name string +// Age int +// } +// +// func main() { +// u := User{"man", 40} +// valid := validation.Validation{} +// valid.Required(u.Name, "name") +// valid.MaxSize(u.Name, 15, "nameMax") +// valid.Range(u.Age, 0, 140, "age") +// if valid.HasErrors() { +// // validation does not pass +// // print invalid message +// for _, err := range valid.Errors { +// log.Println(err.Key, err.Message) +// } +// } +// // or use like this +// if v := valid.Max(u.Age, 140, "ageMax"); !v.Ok { +// log.Println(v.Error.Key, v.Error.Message) +// } +// } +// +// more info: http://beego.me/docs/mvc/controller/validation.md +package validation + +import ( + "fmt" + "regexp" + + "github.com/astaxie/beego/pkg/infrastructure/validation" +) + +// ValidFormer valid interface +type ValidFormer interface { + Valid(*Validation) +} + +// Error show the error +type Error validation.Error + +// String Returns the Message. +func (e *Error) String() string { + if e == nil { + return "" + } + return e.Message +} + +// Implement Error interface. +// Return e.String() +func (e *Error) Error() string { return e.String() } + +// Result is returned from every validation method. +// It provides an indication of success, and a pointer to the Error (if any). +type Result validation.Result + +// Key Get Result by given key string. +func (r *Result) Key(key string) *Result { + if r.Error != nil { + r.Error.Key = key + } + return r +} + +// Message Set Result message by string or format string with args +func (r *Result) Message(message string, args ...interface{}) *Result { + if r.Error != nil { + if len(args) == 0 { + r.Error.Message = message + } else { + r.Error.Message = fmt.Sprintf(message, args...) + } + } + return r +} + +// A Validation context manages data validation and error messages. +type Validation validation.Validation + +// Clear Clean all ValidationError. +func (v *Validation) Clear() { + (*validation.Validation)(v).Clear() +} + +// HasErrors Has ValidationError nor not. +func (v *Validation) HasErrors() bool { + return (*validation.Validation)(v).HasErrors() +} + +// ErrorMap Return the errors mapped by key. +// If there are multiple validation errors associated with a single key, the +// first one "wins". (Typically the first validation will be the more basic). +func (v *Validation) ErrorMap() map[string][]*Error { + newErrors := (*validation.Validation)(v).ErrorMap() + res := make(map[string][]*Error, len(newErrors)) + for n, es := range newErrors { + errs := make([]*Error, 0, len(es)) + + for _, e := range es { + errs = append(errs, (*Error)(e)) + } + + res[n] = errs + } + return res +} + +// Error Add an error to the validation context. +func (v *Validation) Error(message string, args ...interface{}) *Result { + return (*Result)((*validation.Validation)(v).Error(message, args...)) +} + +// Required Test that the argument is non-nil and non-empty (if string or list) +func (v *Validation) Required(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Required(obj, key)) +} + +// Min Test that the obj is greater than min if obj's type is int +func (v *Validation) Min(obj interface{}, min int, key string) *Result { + return (*Result)((*validation.Validation)(v).Min(obj, min, key)) +} + +// Max Test that the obj is less than max if obj's type is int +func (v *Validation) Max(obj interface{}, max int, key string) *Result { + return (*Result)((*validation.Validation)(v).Max(obj, max, key)) +} + +// Range Test that the obj is between mni and max if obj's type is int +func (v *Validation) Range(obj interface{}, min, max int, key string) *Result { + return (*Result)((*validation.Validation)(v).Range(obj, min, max, key)) +} + +// MinSize Test that the obj is longer than min size if type is string or slice +func (v *Validation) MinSize(obj interface{}, min int, key string) *Result { + return (*Result)((*validation.Validation)(v).MinSize(obj, min, key)) +} + +// MaxSize Test that the obj is shorter than max size if type is string or slice +func (v *Validation) MaxSize(obj interface{}, max int, key string) *Result { + return (*Result)((*validation.Validation)(v).MaxSize(obj, max, key)) +} + +// Length Test that the obj is same length to n if type is string or slice +func (v *Validation) Length(obj interface{}, n int, key string) *Result { + return (*Result)((*validation.Validation)(v).Length(obj, n, key)) +} + +// Alpha Test that the obj is [a-zA-Z] if type is string +func (v *Validation) Alpha(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Alpha(obj, key)) +} + +// Numeric Test that the obj is [0-9] if type is string +func (v *Validation) Numeric(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Numeric(obj, key)) +} + +// AlphaNumeric Test that the obj is [0-9a-zA-Z] if type is string +func (v *Validation) AlphaNumeric(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).AlphaNumeric(obj, key)) +} + +// Match Test that the obj matches regexp if type is string +func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *Result { + return (*Result)((*validation.Validation)(v).Match(obj, regex, key)) +} + +// NoMatch Test that the obj doesn't match regexp if type is string +func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *Result { + return (*Result)((*validation.Validation)(v).NoMatch(obj, regex, key)) +} + +// AlphaDash Test that the obj is [0-9a-zA-Z_-] if type is string +func (v *Validation) AlphaDash(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).AlphaDash(obj, key)) +} + +// Email Test that the obj is email address if type is string +func (v *Validation) Email(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Email(obj, key)) +} + +// IP Test that the obj is IP address if type is string +func (v *Validation) IP(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).IP(obj, key)) +} + +// Base64 Test that the obj is base64 encoded if type is string +func (v *Validation) Base64(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Base64(obj, key)) +} + +// Mobile Test that the obj is chinese mobile number if type is string +func (v *Validation) Mobile(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Mobile(obj, key)) +} + +// Tel Test that the obj is chinese telephone number if type is string +func (v *Validation) Tel(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Tel(obj, key)) +} + +// Phone Test that the obj is chinese mobile or telephone number if type is string +func (v *Validation) Phone(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Phone(obj, key)) +} + +// ZipCode Test that the obj is chinese zip code if type is string +func (v *Validation) ZipCode(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).ZipCode(obj, key)) +} + +// key must like aa.bb.cc or aa.bb. +// AddError adds independent error message for the provided key +func (v *Validation) AddError(key, message string) { + (*validation.Validation)(v).AddError(key, message) +} + +// SetError Set error message for one field in ValidationError +func (v *Validation) SetError(fieldName string, errMsg string) *Error { + return (*Error)((*validation.Validation)(v).SetError(fieldName, errMsg)) +} + +// Check Apply a group of validators to a field, in order, and return the +// ValidationResult from the first one that fails, or the last one that +// succeeds. +func (v *Validation) Check(obj interface{}, checks ...Validator) *Result { + vldts := make([]validation.Validator, 0, len(checks)) + for _, v := range checks { + vldts = append(vldts, validation.Validator(v)) + } + return (*Result)((*validation.Validation)(v).Check(obj, vldts...)) +} + +// Valid Validate a struct. +// the obj parameter must be a struct or a struct pointer +func (v *Validation) Valid(obj interface{}) (b bool, err error) { + return (*validation.Validation)(v).Valid(obj) +} + +// RecursiveValid Recursively validate a struct. +// Step1: Validate by v.Valid +// Step2: If pass on step1, then reflect obj's fields +// Step3: Do the Recursively validation to all struct or struct pointer fields +func (v *Validation) RecursiveValid(objc interface{}) (bool, error) { + return (*validation.Validation)(v).RecursiveValid(objc) +} + +func (v *Validation) CanSkipAlso(skipFunc string) { + (*validation.Validation)(v).CanSkipAlso(skipFunc) +} diff --git a/pkg/adapter/validation/validation_test.go b/pkg/adapter/validation/validation_test.go new file mode 100644 index 00000000..b4b5b1b6 --- /dev/null +++ b/pkg/adapter/validation/validation_test.go @@ -0,0 +1,609 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "regexp" + "testing" + "time" +) + +func TestRequired(t *testing.T) { + valid := Validation{} + + if valid.Required(nil, "nil").Ok { + t.Error("nil object should be false") + } + if !valid.Required(true, "bool").Ok { + t.Error("Bool value should always return true") + } + if !valid.Required(false, "bool").Ok { + t.Error("Bool value should always return true") + } + if valid.Required("", "string").Ok { + t.Error("\"'\" string should be false") + } + if valid.Required(" ", "string").Ok { + t.Error("\" \" string should be false") // For #2361 + } + if valid.Required("\n", "string").Ok { + t.Error("new line string should be false") // For #2361 + } + if !valid.Required("astaxie", "string").Ok { + t.Error("string should be true") + } + if valid.Required(0, "zero").Ok { + t.Error("Integer should not be equal 0") + } + if !valid.Required(1, "int").Ok { + t.Error("Integer except 0 should be true") + } + if !valid.Required(time.Now(), "time").Ok { + t.Error("time should be true") + } + if valid.Required([]string{}, "emptySlice").Ok { + t.Error("empty slice should be false") + } + if !valid.Required([]interface{}{"ok"}, "slice").Ok { + t.Error("slice should be true") + } +} + +func TestMin(t *testing.T) { + valid := Validation{} + + if valid.Min(-1, 0, "min0").Ok { + t.Error("-1 is less than the minimum value of 0 should be false") + } + if !valid.Min(1, 0, "min0").Ok { + t.Error("1 is greater or equal than the minimum value of 0 should be true") + } +} + +func TestMax(t *testing.T) { + valid := Validation{} + + if valid.Max(1, 0, "max0").Ok { + t.Error("1 is greater than the minimum value of 0 should be false") + } + if !valid.Max(-1, 0, "max0").Ok { + t.Error("-1 is less or equal than the maximum value of 0 should be true") + } +} + +func TestRange(t *testing.T) { + valid := Validation{} + + if valid.Range(-1, 0, 1, "range0_1").Ok { + t.Error("-1 is between 0 and 1 should be false") + } + if !valid.Range(1, 0, 1, "range0_1").Ok { + t.Error("1 is between 0 and 1 should be true") + } +} + +func TestMinSize(t *testing.T) { + valid := Validation{} + + if valid.MinSize("", 1, "minSize1").Ok { + t.Error("the length of \"\" is less than the minimum value of 1 should be false") + } + if !valid.MinSize("ok", 1, "minSize1").Ok { + t.Error("the length of \"ok\" is greater or equal than the minimum value of 1 should be true") + } + if valid.MinSize([]string{}, 1, "minSize1").Ok { + t.Error("the length of empty slice is less than the minimum value of 1 should be false") + } + if !valid.MinSize([]interface{}{"ok"}, 1, "minSize1").Ok { + t.Error("the length of [\"ok\"] is greater or equal than the minimum value of 1 should be true") + } +} + +func TestMaxSize(t *testing.T) { + valid := Validation{} + + if valid.MaxSize("ok", 1, "maxSize1").Ok { + t.Error("the length of \"ok\" is greater than the maximum value of 1 should be false") + } + if !valid.MaxSize("", 1, "maxSize1").Ok { + t.Error("the length of \"\" is less or equal than the maximum value of 1 should be true") + } + if valid.MaxSize([]interface{}{"ok", false}, 1, "maxSize1").Ok { + t.Error("the length of [\"ok\", false] is greater than the maximum value of 1 should be false") + } + if !valid.MaxSize([]string{}, 1, "maxSize1").Ok { + t.Error("the length of empty slice is less or equal than the maximum value of 1 should be true") + } +} + +func TestLength(t *testing.T) { + valid := Validation{} + + if valid.Length("", 1, "length1").Ok { + t.Error("the length of \"\" must equal 1 should be false") + } + if !valid.Length("1", 1, "length1").Ok { + t.Error("the length of \"1\" must equal 1 should be true") + } + if valid.Length([]string{}, 1, "length1").Ok { + t.Error("the length of empty slice must equal 1 should be false") + } + if !valid.Length([]interface{}{"ok"}, 1, "length1").Ok { + t.Error("the length of [\"ok\"] must equal 1 should be true") + } +} + +func TestAlpha(t *testing.T) { + valid := Validation{} + + if valid.Alpha("a,1-@ $", "alpha").Ok { + t.Error("\"a,1-@ $\" are valid alpha characters should be false") + } + if !valid.Alpha("abCD", "alpha").Ok { + t.Error("\"abCD\" are valid alpha characters should be true") + } +} + +func TestNumeric(t *testing.T) { + valid := Validation{} + + if valid.Numeric("a,1-@ $", "numeric").Ok { + t.Error("\"a,1-@ $\" are valid numeric characters should be false") + } + if !valid.Numeric("1234", "numeric").Ok { + t.Error("\"1234\" are valid numeric characters should be true") + } +} + +func TestAlphaNumeric(t *testing.T) { + valid := Validation{} + + if valid.AlphaNumeric("a,1-@ $", "alphaNumeric").Ok { + t.Error("\"a,1-@ $\" are valid alpha or numeric characters should be false") + } + if !valid.AlphaNumeric("1234aB", "alphaNumeric").Ok { + t.Error("\"1234aB\" are valid alpha or numeric characters should be true") + } +} + +func TestMatch(t *testing.T) { + valid := Validation{} + + if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false") + } + if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true") + } +} + +func TestNoMatch(t *testing.T) { + valid := Validation{} + + if valid.NoMatch("123@gmail", regexp.MustCompile(`[^\w\d]`), "nomatch").Ok { + t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false") + } + if !valid.NoMatch("123gmail", regexp.MustCompile(`[^\w\d]`), "match").Ok { + t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true") + } +} + +func TestAlphaDash(t *testing.T) { + valid := Validation{} + + if valid.AlphaDash("a,1-@ $", "alphaDash").Ok { + t.Error("\"a,1-@ $\" are valid alpha or numeric or dash(-_) characters should be false") + } + if !valid.AlphaDash("1234aB-_", "alphaDash").Ok { + t.Error("\"1234aB\" are valid alpha or numeric or dash(-_) characters should be true") + } +} + +func TestEmail(t *testing.T) { + valid := Validation{} + + if valid.Email("not@a email", "email").Ok { + t.Error("\"not@a email\" is a valid email address should be false") + } + if !valid.Email("suchuangji@gmail.com", "email").Ok { + t.Error("\"suchuangji@gmail.com\" is a valid email address should be true") + } + if valid.Email("@suchuangji@gmail.com", "email").Ok { + t.Error("\"@suchuangji@gmail.com\" is a valid email address should be false") + } + if valid.Email("suchuangji@gmail.com ok", "email").Ok { + t.Error("\"suchuangji@gmail.com ok\" is a valid email address should be false") + } +} + +func TestIP(t *testing.T) { + valid := Validation{} + + if valid.IP("11.255.255.256", "IP").Ok { + t.Error("\"11.255.255.256\" is a valid ip address should be false") + } + if !valid.IP("01.11.11.11", "IP").Ok { + t.Error("\"suchuangji@gmail.com\" is a valid ip address should be true") + } +} + +func TestBase64(t *testing.T) { + valid := Validation{} + + if valid.Base64("suchuangji@gmail.com", "base64").Ok { + t.Error("\"suchuangji@gmail.com\" are a valid base64 characters should be false") + } + if !valid.Base64("c3VjaHVhbmdqaUBnbWFpbC5jb20=", "base64").Ok { + t.Error("\"c3VjaHVhbmdqaUBnbWFpbC5jb20=\" are a valid base64 characters should be true") + } +} + +func TestMobile(t *testing.T) { + valid := Validation{} + + validMobiles := []string{ + "19800008888", + "18800008888", + "18000008888", + "8618300008888", + "+8614700008888", + "17300008888", + "+8617100008888", + "8617500008888", + "8617400008888", + "16200008888", + "16500008888", + "16600008888", + "16700008888", + "13300008888", + "14900008888", + "15300008888", + "17300008888", + "17700008888", + "18000008888", + "18900008888", + "19100008888", + "19900008888", + "19300008888", + "13000008888", + "13100008888", + "13200008888", + "14500008888", + "15500008888", + "15600008888", + "16600008888", + "17100008888", + "17500008888", + "17600008888", + "18500008888", + "18600008888", + "13400008888", + "13500008888", + "13600008888", + "13700008888", + "13800008888", + "13900008888", + "14700008888", + "15000008888", + "15100008888", + "15200008888", + "15800008888", + "15900008888", + "17200008888", + "17800008888", + "18200008888", + "18300008888", + "18400008888", + "18700008888", + "18800008888", + "19800008888", + } + + for _, m := range validMobiles { + if !valid.Mobile(m, "mobile").Ok { + t.Error(m + " is a valid mobile phone number should be true") + } + } +} + +func TestTel(t *testing.T) { + valid := Validation{} + + if valid.Tel("222-00008888", "telephone").Ok { + t.Error("\"222-00008888\" is a valid telephone number should be false") + } + if !valid.Tel("022-70008888", "telephone").Ok { + t.Error("\"022-70008888\" is a valid telephone number should be true") + } + if !valid.Tel("02270008888", "telephone").Ok { + t.Error("\"02270008888\" is a valid telephone number should be true") + } + if !valid.Tel("70008888", "telephone").Ok { + t.Error("\"70008888\" is a valid telephone number should be true") + } +} + +func TestPhone(t *testing.T) { + valid := Validation{} + + if valid.Phone("222-00008888", "phone").Ok { + t.Error("\"222-00008888\" is a valid phone number should be false") + } + if !valid.Mobile("+8614700008888", "phone").Ok { + t.Error("\"+8614700008888\" is a valid phone number should be true") + } + if !valid.Tel("02270008888", "phone").Ok { + t.Error("\"02270008888\" is a valid phone number should be true") + } +} + +func TestZipCode(t *testing.T) { + valid := Validation{} + + if valid.ZipCode("", "zipcode").Ok { + t.Error("\"00008888\" is a valid zipcode should be false") + } + if !valid.ZipCode("536000", "zipcode").Ok { + t.Error("\"536000\" is a valid zipcode should be true") + } +} + +func TestValid(t *testing.T) { + type user struct { + ID int + Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + valid := Validation{} + + u := user{Name: "test@/test/;com", Age: 40} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Error("validation should be passed") + } + + uptr := &user{Name: "test", Age: 40} + valid.Clear() + b, err = valid.Valid(uptr) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } + if len(valid.Errors) != 1 { + t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) + } + if valid.Errors[0].Key != "Name.Match" { + t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key) + } + + u = user{Name: "test@/test/;com", Age: 180} + valid.Clear() + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } + if len(valid.Errors) != 1 { + t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) + } + if valid.Errors[0].Key != "Age.Range." { + t.Errorf("Message key should be `Age.Range` but got %s", valid.Errors[0].Key) + } +} + +func TestRecursiveValid(t *testing.T) { + type User struct { + ID int + Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + + type AnonymouseUser struct { + ID2 int + Name2 string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age2 int `valid:"Required;Range(1, 140)"` + } + + type Account struct { + Password string `valid:"Required"` + U User + AnonymouseUser + } + valid := Validation{} + + u := Account{Password: "abc123_", U: User{}} + b, err := valid.RecursiveValid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } +} + +func TestSkipValid(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + + IP string `valid:"IP"` + ReqIP string `valid:"Required;IP"` + + Mobile string `valid:"Mobile"` + ReqMobile string `valid:"Required;Mobile"` + + Tel string `valid:"Tel"` + ReqTel string `valid:"Required;Tel"` + + Phone string `valid:"Phone"` + ReqPhone string `valid:"Required;Phone"` + + ZipCode string `valid:"ZipCode"` + ReqZipCode string `valid:"Required;ZipCode"` + } + + u := User{ + ReqEmail: "a@a.com", + ReqIP: "127.0.0.1", + ReqMobile: "18888888888", + ReqTel: "02088888888", + ReqPhone: "02088888888", + ReqZipCode: "510000", + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } +} + +func TestPointer(t *testing.T) { + type User struct { + ID int + + Email *string `valid:"Email"` + ReqEmail *string `valid:"Required;Email"` + } + + u := User{ + ReqEmail: nil, + Email: nil, + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + validEmail := "a@a.com" + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + invalidEmail := "a@a" + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } +} + +func TestCanSkipAlso(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + MatchRange int `valid:"Range(10, 20)"` + } + + u := User{ + ReqEmail: "a@a.com", + Email: "", + MatchRange: 0, + } + + valid := Validation{RequiredFirst: true} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + valid.CanSkipAlso("Range") + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + +} diff --git a/pkg/adapter/validation/validators.go b/pkg/adapter/validation/validators.go new file mode 100644 index 00000000..1a063749 --- /dev/null +++ b/pkg/adapter/validation/validators.go @@ -0,0 +1,512 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "sync" + + "github.com/astaxie/beego/pkg/infrastructure/validation" +) + +// CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty +var CanSkipFuncs = validation.CanSkipFuncs + +// MessageTmpls store commond validate template +var MessageTmpls = map[string]string{ + "Required": "Can not be empty", + "Min": "Minimum is %d", + "Max": "Maximum is %d", + "Range": "Range is %d to %d", + "MinSize": "Minimum size is %d", + "MaxSize": "Maximum size is %d", + "Length": "Required length is %d", + "Alpha": "Must be valid alpha characters", + "Numeric": "Must be valid numeric characters", + "AlphaNumeric": "Must be valid alpha or numeric characters", + "Match": "Must match %s", + "NoMatch": "Must not match %s", + "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", + "Email": "Must be a valid email address", + "IP": "Must be a valid ip address", + "Base64": "Must be valid base64 characters", + "Mobile": "Must be valid mobile number", + "Tel": "Must be valid telephone number", + "Phone": "Must be valid telephone or mobile phone number", + "ZipCode": "Must be valid zipcode", +} + +var once sync.Once + +// SetDefaultMessage set default messages +// if not set, the default messages are +// "Required": "Can not be empty", +// "Min": "Minimum is %d", +// "Max": "Maximum is %d", +// "Range": "Range is %d to %d", +// "MinSize": "Minimum size is %d", +// "MaxSize": "Maximum size is %d", +// "Length": "Required length is %d", +// "Alpha": "Must be valid alpha characters", +// "Numeric": "Must be valid numeric characters", +// "AlphaNumeric": "Must be valid alpha or numeric characters", +// "Match": "Must match %s", +// "NoMatch": "Must not match %s", +// "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", +// "Email": "Must be a valid email address", +// "IP": "Must be a valid ip address", +// "Base64": "Must be valid base64 characters", +// "Mobile": "Must be valid mobile number", +// "Tel": "Must be valid telephone number", +// "Phone": "Must be valid telephone or mobile phone number", +// "ZipCode": "Must be valid zipcode", +func SetDefaultMessage(msg map[string]string) { + validation.SetDefaultMessage(msg) +} + +// Validator interface +type Validator interface { + IsSatisfied(interface{}) bool + DefaultMessage() string + GetKey() string + GetLimitValue() interface{} +} + +// Required struct +type Required validation.Required + +// IsSatisfied judge whether obj has value +func (r Required) IsSatisfied(obj interface{}) bool { + return validation.Required(r).IsSatisfied(obj) +} + +// DefaultMessage return the default error message +func (r Required) DefaultMessage() string { + return validation.Required(r).DefaultMessage() +} + +// GetKey return the r.Key +func (r Required) GetKey() string { + return validation.Required(r).GetKey() +} + +// GetLimitValue return nil now +func (r Required) GetLimitValue() interface{} { + return validation.Required(r).GetLimitValue() +} + +// Min check struct +type Min validation.Min + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (m Min) IsSatisfied(obj interface{}) bool { + return validation.Min(m).IsSatisfied(obj) +} + +// DefaultMessage return the default min error message +func (m Min) DefaultMessage() string { + return validation.Min(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Min) GetKey() string { + return validation.Min(m).GetKey() +} + +// GetLimitValue return the limit value, Min +func (m Min) GetLimitValue() interface{} { + return validation.Min(m).GetLimitValue() +} + +// Max validate struct +type Max validation.Max + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (m Max) IsSatisfied(obj interface{}) bool { + return validation.Max(m).IsSatisfied(obj) +} + +// DefaultMessage return the default max error message +func (m Max) DefaultMessage() string { + return validation.Max(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Max) GetKey() string { + return validation.Max(m).GetKey() +} + +// GetLimitValue return the limit value, Max +func (m Max) GetLimitValue() interface{} { + return validation.Max(m).GetLimitValue() +} + +// Range Requires an integer to be within Min, Max inclusive. +type Range validation.Range + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (r Range) IsSatisfied(obj interface{}) bool { + return validation.Range(r).IsSatisfied(obj) +} + +// DefaultMessage return the default Range error message +func (r Range) DefaultMessage() string { + return validation.Range(r).DefaultMessage() +} + +// GetKey return the m.Key +func (r Range) GetKey() string { + return validation.Range(r).GetKey() +} + +// GetLimitValue return the limit value, Max +func (r Range) GetLimitValue() interface{} { + return validation.Range(r).GetLimitValue() +} + +// MinSize Requires an array or string to be at least a given length. +type MinSize validation.MinSize + +// IsSatisfied judge whether obj is valid +func (m MinSize) IsSatisfied(obj interface{}) bool { + return validation.MinSize(m).IsSatisfied(obj) +} + +// DefaultMessage return the default MinSize error message +func (m MinSize) DefaultMessage() string { + return validation.MinSize(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m MinSize) GetKey() string { + return validation.MinSize(m).GetKey() +} + +// GetLimitValue return the limit value +func (m MinSize) GetLimitValue() interface{} { + return validation.MinSize(m).GetLimitValue() +} + +// MaxSize Requires an array or string to be at most a given length. +type MaxSize validation.MaxSize + +// IsSatisfied judge whether obj is valid +func (m MaxSize) IsSatisfied(obj interface{}) bool { + return validation.MaxSize(m).IsSatisfied(obj) +} + +// DefaultMessage return the default MaxSize error message +func (m MaxSize) DefaultMessage() string { + return validation.MaxSize(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m MaxSize) GetKey() string { + return validation.MaxSize(m).GetKey() +} + +// GetLimitValue return the limit value +func (m MaxSize) GetLimitValue() interface{} { + return validation.MaxSize(m).GetLimitValue() +} + +// Length Requires an array or string to be exactly a given length. +type Length validation.Length + +// IsSatisfied judge whether obj is valid +func (l Length) IsSatisfied(obj interface{}) bool { + return validation.Length(l).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (l Length) DefaultMessage() string { + return validation.Length(l).DefaultMessage() +} + +// GetKey return the m.Key +func (l Length) GetKey() string { + return validation.Length(l).GetKey() +} + +// GetLimitValue return the limit value +func (l Length) GetLimitValue() interface{} { + return validation.Length(l).GetLimitValue() +} + +// Alpha check the alpha +type Alpha validation.Alpha + +// IsSatisfied judge whether obj is valid +func (a Alpha) IsSatisfied(obj interface{}) bool { + return validation.Alpha(a).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (a Alpha) DefaultMessage() string { + return validation.Alpha(a).DefaultMessage() +} + +// GetKey return the m.Key +func (a Alpha) GetKey() string { + return validation.Alpha(a).GetKey() +} + +// GetLimitValue return the limit value +func (a Alpha) GetLimitValue() interface{} { + return validation.Alpha(a).GetLimitValue() +} + +// Numeric check number +type Numeric validation.Numeric + +// IsSatisfied judge whether obj is valid +func (n Numeric) IsSatisfied(obj interface{}) bool { + return validation.Numeric(n).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (n Numeric) DefaultMessage() string { + return validation.Numeric(n).DefaultMessage() +} + +// GetKey return the n.Key +func (n Numeric) GetKey() string { + return validation.Numeric(n).GetKey() +} + +// GetLimitValue return the limit value +func (n Numeric) GetLimitValue() interface{} { + return validation.Numeric(n).GetLimitValue() +} + +// AlphaNumeric check alpha and number +type AlphaNumeric validation.AlphaNumeric + +// IsSatisfied judge whether obj is valid +func (a AlphaNumeric) IsSatisfied(obj interface{}) bool { + return validation.AlphaNumeric(a).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (a AlphaNumeric) DefaultMessage() string { + return validation.AlphaNumeric(a).DefaultMessage() +} + +// GetKey return the a.Key +func (a AlphaNumeric) GetKey() string { + return validation.AlphaNumeric(a).GetKey() +} + +// GetLimitValue return the limit value +func (a AlphaNumeric) GetLimitValue() interface{} { + return validation.AlphaNumeric(a).GetLimitValue() +} + +// Match Requires a string to match a given regex. +type Match validation.Match + +// IsSatisfied judge whether obj is valid +func (m Match) IsSatisfied(obj interface{}) bool { + return validation.Match(m).IsSatisfied(obj) +} + +// DefaultMessage return the default Match error message +func (m Match) DefaultMessage() string { + return validation.Match(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Match) GetKey() string { + return validation.Match(m).GetKey() +} + +// GetLimitValue return the limit value +func (m Match) GetLimitValue() interface{} { + return validation.Match(m).GetLimitValue() +} + +// NoMatch Requires a string to not match a given regex. +type NoMatch validation.NoMatch + +// IsSatisfied judge whether obj is valid +func (n NoMatch) IsSatisfied(obj interface{}) bool { + return validation.NoMatch(n).IsSatisfied(obj) +} + +// DefaultMessage return the default NoMatch error message +func (n NoMatch) DefaultMessage() string { + return validation.NoMatch(n).DefaultMessage() +} + +// GetKey return the n.Key +func (n NoMatch) GetKey() string { + return validation.NoMatch(n).GetKey() +} + +// GetLimitValue return the limit value +func (n NoMatch) GetLimitValue() interface{} { + return validation.NoMatch(n).GetLimitValue() +} + +// AlphaDash check not Alpha +type AlphaDash validation.AlphaDash + +// DefaultMessage return the default AlphaDash error message +func (a AlphaDash) DefaultMessage() string { + return validation.AlphaDash(a).DefaultMessage() +} + +// GetKey return the n.Key +func (a AlphaDash) GetKey() string { + return validation.AlphaDash(a).GetKey() +} + +// GetLimitValue return the limit value +func (a AlphaDash) GetLimitValue() interface{} { + return validation.AlphaDash(a).GetLimitValue() +} + +// Email check struct +type Email validation.Email + +// DefaultMessage return the default Email error message +func (e Email) DefaultMessage() string { + return validation.Email(e).DefaultMessage() +} + +// GetKey return the n.Key +func (e Email) GetKey() string { + return validation.Email(e).GetKey() +} + +// GetLimitValue return the limit value +func (e Email) GetLimitValue() interface{} { + return validation.Email(e).GetLimitValue() +} + +// IP check struct +type IP validation.IP + +// DefaultMessage return the default IP error message +func (i IP) DefaultMessage() string { + return validation.IP(i).DefaultMessage() +} + +// GetKey return the i.Key +func (i IP) GetKey() string { + return validation.IP(i).GetKey() +} + +// GetLimitValue return the limit value +func (i IP) GetLimitValue() interface{} { + return validation.IP(i).GetLimitValue() +} + +// Base64 check struct +type Base64 validation.Base64 + +// DefaultMessage return the default Base64 error message +func (b Base64) DefaultMessage() string { + return validation.Base64(b).DefaultMessage() +} + +// GetKey return the b.Key +func (b Base64) GetKey() string { + return validation.Base64(b).GetKey() +} + +// GetLimitValue return the limit value +func (b Base64) GetLimitValue() interface{} { + return validation.Base64(b).GetLimitValue() +} + +// Mobile check struct +type Mobile validation.Mobile + +// DefaultMessage return the default Mobile error message +func (m Mobile) DefaultMessage() string { + return validation.Mobile(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Mobile) GetKey() string { + return validation.Mobile(m).GetKey() +} + +// GetLimitValue return the limit value +func (m Mobile) GetLimitValue() interface{} { + return validation.Mobile(m).GetLimitValue() +} + +// Tel check telephone struct +type Tel validation.Tel + +// DefaultMessage return the default Tel error message +func (t Tel) DefaultMessage() string { + return validation.Tel(t).DefaultMessage() +} + +// GetKey return the t.Key +func (t Tel) GetKey() string { + return validation.Tel(t).GetKey() +} + +// GetLimitValue return the limit value +func (t Tel) GetLimitValue() interface{} { + return validation.Tel(t).GetLimitValue() +} + +// Phone just for chinese telephone or mobile phone number +type Phone validation.Phone + +// IsSatisfied judge whether obj is valid +func (p Phone) IsSatisfied(obj interface{}) bool { + return validation.Phone(p).IsSatisfied(obj) +} + +// DefaultMessage return the default Phone error message +func (p Phone) DefaultMessage() string { + return validation.Phone(p).DefaultMessage() +} + +// GetKey return the p.Key +func (p Phone) GetKey() string { + return validation.Phone(p).GetKey() +} + +// GetLimitValue return the limit value +func (p Phone) GetLimitValue() interface{} { + return validation.Phone(p).GetLimitValue() +} + +// ZipCode check the zip struct +type ZipCode validation.ZipCode + +// DefaultMessage return the default Zip error message +func (z ZipCode) DefaultMessage() string { + return validation.ZipCode(z).DefaultMessage() +} + +// GetKey return the z.Key +func (z ZipCode) GetKey() string { + return validation.ZipCode(z).GetKey() +} + +// GetLimitValue return the limit value +func (z ZipCode) GetLimitValue() interface{} { + return validation.ZipCode(z).GetLimitValue() +} From 3530457ff9a51e721be139bec94de2299a027197 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Thu, 3 Sep 2020 21:34:46 +0800 Subject: [PATCH 143/207] Adapter: toolbox module --- pkg/adapter/toolbox/healthcheck.go | 52 +++++ pkg/adapter/toolbox/profile.go | 50 +++++ pkg/adapter/toolbox/profile_test.go | 28 +++ pkg/adapter/toolbox/statistics.go | 50 +++++ pkg/adapter/toolbox/statistics_test.go | 40 ++++ pkg/adapter/toolbox/task.go | 286 +++++++++++++++++++++++++ pkg/adapter/toolbox/task_test.go | 63 ++++++ pkg/task/task.go | 6 +- 8 files changed, 572 insertions(+), 3 deletions(-) create mode 100644 pkg/adapter/toolbox/healthcheck.go create mode 100644 pkg/adapter/toolbox/profile.go create mode 100644 pkg/adapter/toolbox/profile_test.go create mode 100644 pkg/adapter/toolbox/statistics.go create mode 100644 pkg/adapter/toolbox/statistics_test.go create mode 100644 pkg/adapter/toolbox/task.go create mode 100644 pkg/adapter/toolbox/task_test.go diff --git a/pkg/adapter/toolbox/healthcheck.go b/pkg/adapter/toolbox/healthcheck.go new file mode 100644 index 00000000..56be8089 --- /dev/null +++ b/pkg/adapter/toolbox/healthcheck.go @@ -0,0 +1,52 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package toolbox healthcheck +// +// type DatabaseCheck struct { +// } +// +// func (dc *DatabaseCheck) Check() error { +// if dc.isConnected() { +// return nil +// } else { +// return errors.New("can't connect database") +// } +// } +// +// AddHealthCheck("database",&DatabaseCheck{}) +// +// more docs: http://beego.me/docs/module/toolbox.md +package toolbox + +import ( + "github.com/astaxie/beego/pkg/infrastructure/governor" +) + +// AdminCheckList holds health checker map +// Deprecated using governor.AdminCheckList +var AdminCheckList map[string]HealthChecker + +// HealthChecker health checker interface +type HealthChecker governor.HealthChecker + +// AddHealthCheck add health checker with name string +func AddHealthCheck(name string, hc HealthChecker) { + governor.AddHealthCheck(name, hc) + AdminCheckList[name] = hc +} + +func init() { + AdminCheckList = make(map[string]HealthChecker) +} diff --git a/pkg/adapter/toolbox/profile.go b/pkg/adapter/toolbox/profile.go new file mode 100644 index 00000000..16cf80b1 --- /dev/null +++ b/pkg/adapter/toolbox/profile.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "io" + "os" + "time" + + "github.com/astaxie/beego/pkg/infrastructure/governor" +) + +var startTime = time.Now() +var pid int + +func init() { + pid = os.Getpid() +} + +// ProcessInput parse input command string +func ProcessInput(input string, w io.Writer) { + governor.ProcessInput(input, w) +} + +// MemProf record memory profile in pprof +func MemProf(w io.Writer) { + governor.MemProf(w) +} + +// GetCPUProfile start cpu profile monitor +func GetCPUProfile(w io.Writer) { + governor.GetCPUProfile(w) +} + +// PrintGCSummary print gc information to io.Writer +func PrintGCSummary(w io.Writer) { + governor.PrintGCSummary(w) +} diff --git a/pkg/adapter/toolbox/profile_test.go b/pkg/adapter/toolbox/profile_test.go new file mode 100644 index 00000000..07a20c4e --- /dev/null +++ b/pkg/adapter/toolbox/profile_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "os" + "testing" +) + +func TestProcessInput(t *testing.T) { + ProcessInput("lookup goroutine", os.Stdout) + ProcessInput("lookup heap", os.Stdout) + ProcessInput("lookup threadcreate", os.Stdout) + ProcessInput("lookup block", os.Stdout) + ProcessInput("gc summary", os.Stdout) +} diff --git a/pkg/adapter/toolbox/statistics.go b/pkg/adapter/toolbox/statistics.go new file mode 100644 index 00000000..b7d3bda9 --- /dev/null +++ b/pkg/adapter/toolbox/statistics.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "time" + + "github.com/astaxie/beego/pkg/server/web" +) + +// Statistics struct +type Statistics web.Statistics + +// URLMap contains several statistics struct to log different data +type URLMap web.URLMap + +// AddStatistics add statistics task. +// it needs request method, request url, request controller and statistics time duration +func (m *URLMap) AddStatistics(requestMethod, requestURL, requestController string, requesttime time.Duration) { + (*web.URLMap)(m).AddStatistics(requestMethod, requestURL, requestController, requesttime) +} + +// GetMap put url statistics result in io.Writer +func (m *URLMap) GetMap() map[string]interface{} { + return (*web.URLMap)(m).GetMap() +} + +// GetMapData return all mapdata +func (m *URLMap) GetMapData() []map[string]interface{} { + return (*web.URLMap)(m).GetMapData() +} + +// StatisticsMap hosld global statistics data map +var StatisticsMap *URLMap + +func init() { + StatisticsMap = (*URLMap)(web.StatisticsMap) +} diff --git a/pkg/adapter/toolbox/statistics_test.go b/pkg/adapter/toolbox/statistics_test.go new file mode 100644 index 00000000..ac29476c --- /dev/null +++ b/pkg/adapter/toolbox/statistics_test.go @@ -0,0 +1,40 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "encoding/json" + "testing" + "time" +) + +func TestStatics(t *testing.T) { + StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(2000)) + StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(120000)) + StatisticsMap.AddStatistics("GET", "/api/user", "&admin.user", time.Duration(13000)) + StatisticsMap.AddStatistics("POST", "/api/admin", "&admin.user", time.Duration(14000)) + StatisticsMap.AddStatistics("POST", "/api/user/astaxie", "&admin.user", time.Duration(12000)) + StatisticsMap.AddStatistics("POST", "/api/user/xiemengjun", "&admin.user", time.Duration(13000)) + StatisticsMap.AddStatistics("DELETE", "/api/user", "&admin.user", time.Duration(1400)) + t.Log(StatisticsMap.GetMap()) + + data := StatisticsMap.GetMapData() + b, err := json.Marshal(data) + if err != nil { + t.Errorf(err.Error()) + } + + t.Log(string(b)) +} diff --git a/pkg/adapter/toolbox/task.go b/pkg/adapter/toolbox/task.go new file mode 100644 index 00000000..2a6d9aa6 --- /dev/null +++ b/pkg/adapter/toolbox/task.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "context" + "sort" + "time" + + "github.com/astaxie/beego/pkg/task" +) + +// The bounds for each field. +var ( + AdminTaskList map[string]Tasker +) + +const ( + // Set the top bit if a star was included in the expression. + starBit = 1 << 63 +) + +// Schedule time taks schedule +type Schedule task.Schedule + +// TaskFunc task func type +type TaskFunc func() error + +// Tasker task interface +type Tasker interface { + GetSpec() string + GetStatus() string + Run() error + SetNext(time.Time) + GetNext() time.Time + SetPrev(time.Time) + GetPrev() time.Time +} + +// task error +type taskerr struct { + t time.Time + errinfo string +} + +// Task task struct +// Deprecated +type Task struct { + // Deprecated + Taskname string + // Deprecated + Spec *Schedule + // Deprecated + SpecStr string + // Deprecated + DoFunc TaskFunc + // Deprecated + Prev time.Time + // Deprecated + Next time.Time + // Deprecated + Errlist []*taskerr // like errtime:errinfo + // Deprecated + ErrLimit int // max length for the errlist, 0 stand for no limit + + delegate *task.Task +} + +// NewTask add new task with name, time and func +func NewTask(tname string, spec string, f TaskFunc) *Task { + + task := task.NewTask(tname, spec, func(ctx context.Context) error { + return f() + }) + return &Task{ + delegate: task, + } +} + +// GetSpec get spec string +func (t *Task) GetSpec() string { + t.initDelegate() + + return t.delegate.GetSpec(context.Background()) +} + +// GetStatus get current task status +func (t *Task) GetStatus() string { + + t.initDelegate() + + return t.delegate.GetStatus(context.Background()) +} + +// Run run all tasks +func (t *Task) Run() error { + t.initDelegate() + return t.delegate.Run(context.Background()) +} + +// SetNext set next time for this task +func (t *Task) SetNext(now time.Time) { + t.initDelegate() + t.delegate.SetNext(context.Background(), now) +} + +// GetNext get the next call time of this task +func (t *Task) GetNext() time.Time { + t.initDelegate() + return t.delegate.GetNext(context.Background()) +} + +// SetPrev set prev time of this task +func (t *Task) SetPrev(now time.Time) { + t.initDelegate() + t.delegate.SetPrev(context.Background(), now) +} + +// GetPrev get prev time of this task +func (t *Task) GetPrev() time.Time { + t.initDelegate() + return t.delegate.GetPrev(context.Background()) +} + +// six columns mean: +// second:0-59 +// minute:0-59 +// hour:1-23 +// day:1-31 +// month:1-12 +// week:0-6(0 means Sunday) + +// SetCron some signals: +// *: any time +// ,:  separate signal +//    -:duration +// /n : do as n times of time duration +// /////////////////////////////////////////////////////// +// 0/30 * * * * * every 30s +// 0 43 21 * * * 21:43 +// 0 15 05 * * *    05:15 +// 0 0 17 * * * 17:00 +// 0 0 17 * * 1 17:00 in every Monday +// 0 0,10 17 * * 0,2,3 17:00 and 17:10 in every Sunday, Tuesday and Wednesday +// 0 0-10 17 1 * * 17:00 to 17:10 in 1 min duration each time on the first day of month +// 0 0 0 1,15 * 1 0:00 on the 1st day and 15th day of month +// 0 42 4 1 * *     4:42 on the 1st day of month +// 0 0 21 * * 1-6   21:00 from Monday to Saturday +// 0 0,10,20,30,40,50 * * * *  every 10 min duration +// 0 */10 * * * *        every 10 min duration +// 0 * 1 * * *         1:00 to 1:59 in 1 min duration each time +// 0 0 1 * * *         1:00 +// 0 0 */1 * * *        0 min of hour in 1 hour duration +// 0 0 * * * *         0 min of hour in 1 hour duration +// 0 2 8-20/3 * * *       8:02, 11:02, 14:02, 17:02, 20:02 +// 0 30 5 1,15 * *       5:30 on the 1st day and 15th day of month +func (t *Task) SetCron(spec string) { + t.initDelegate() + t.delegate.SetCron(spec) +} + +func (t *Task) initDelegate() { + if t.delegate == nil { + t.delegate = &task.Task{ + Taskname: t.Taskname, + Spec: (*task.Schedule)(t.Spec), + SpecStr: t.SpecStr, + DoFunc: func(ctx context.Context) error { + return t.DoFunc() + }, + Prev: t.Prev, + Next: t.Next, + ErrLimit: t.ErrLimit, + } + } +} + +// Next set schedule to next time +func (s *Schedule) Next(t time.Time) time.Time { + return (*task.Schedule)(s).Next(t) +} + +// StartTask start all tasks +func StartTask() { + task.StartTask() +} + +// StopTask stop all tasks +func StopTask() { + task.StopTask() +} + +// AddTask add task with name +func AddTask(taskname string, t Tasker) { + task.AddTask(taskname, &oldToNewAdapter{delegate: t}) +} + +// DeleteTask delete task with name +func DeleteTask(taskname string) { + task.DeleteTask(taskname) +} + +// MapSorter sort map for tasker +type MapSorter task.MapSorter + +// NewMapSorter create new tasker map +func NewMapSorter(m map[string]Tasker) *MapSorter { + + newTaskerMap := make(map[string]task.Tasker, len(m)) + + for key, value := range m { + newTaskerMap[key] = &oldToNewAdapter{ + delegate: value, + } + } + + return (*MapSorter)(task.NewMapSorter(newTaskerMap)) +} + +// Sort sort tasker map +func (ms *MapSorter) Sort() { + sort.Sort(ms) +} + +func (ms *MapSorter) Len() int { return len(ms.Keys) } +func (ms *MapSorter) Less(i, j int) bool { + if ms.Vals[i].GetNext(context.Background()).IsZero() { + return false + } + if ms.Vals[j].GetNext(context.Background()).IsZero() { + return true + } + return ms.Vals[i].GetNext(context.Background()).Before(ms.Vals[j].GetNext(context.Background())) +} +func (ms *MapSorter) Swap(i, j int) { + ms.Vals[i], ms.Vals[j] = ms.Vals[j], ms.Vals[i] + ms.Keys[i], ms.Keys[j] = ms.Keys[j], ms.Keys[i] +} + +func init() { + AdminTaskList = make(map[string]Tasker) +} + +type oldToNewAdapter struct { + delegate Tasker +} + +func (o *oldToNewAdapter) GetSpec(ctx context.Context) string { + return o.delegate.GetSpec() +} + +func (o *oldToNewAdapter) GetStatus(ctx context.Context) string { + return o.delegate.GetStatus() +} + +func (o *oldToNewAdapter) Run(ctx context.Context) error { + return o.delegate.Run() +} + +func (o *oldToNewAdapter) SetNext(ctx context.Context, t time.Time) { + o.delegate.SetNext(t) +} + +func (o *oldToNewAdapter) GetNext(ctx context.Context) time.Time { + return o.delegate.GetNext() +} + +func (o *oldToNewAdapter) SetPrev(ctx context.Context, t time.Time) { + o.delegate.SetPrev(t) +} + +func (o *oldToNewAdapter) GetPrev(ctx context.Context) time.Time { + return o.delegate.GetPrev() +} diff --git a/pkg/adapter/toolbox/task_test.go b/pkg/adapter/toolbox/task_test.go new file mode 100644 index 00000000..596bc9c5 --- /dev/null +++ b/pkg/adapter/toolbox/task_test.go @@ -0,0 +1,63 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "fmt" + "sync" + "testing" + "time" +) + +func TestParse(t *testing.T) { + tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) + err := tk.Run() + if err != nil { + t.Fatal(err) + } + AddTask("taska", tk) + StartTask() + time.Sleep(6 * time.Second) + StopTask() +} + +func TestSpec(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + tk1 := NewTask("tk1", "0 12 * * * *", func() error { fmt.Println("tk1"); return nil }) + tk2 := NewTask("tk2", "0,10,20 * * * * *", func() error { fmt.Println("tk2"); wg.Done(); return nil }) + tk3 := NewTask("tk3", "0 10 * * * *", func() error { fmt.Println("tk3"); wg.Done(); return nil }) + + AddTask("tk1", tk1) + AddTask("tk2", tk2) + AddTask("tk3", tk3) + StartTask() + defer StopTask() + + select { + case <-time.After(200 * time.Second): + t.FailNow() + case <-wait(wg): + } +} + +func wait(wg *sync.WaitGroup) chan bool { + ch := make(chan bool) + go func() { + wg.Wait() + ch <- true + }() + return ch +} diff --git a/pkg/task/task.go b/pkg/task/task.go index e2962000..bcadb956 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -83,7 +83,7 @@ type Schedule struct { } // TaskFunc task func type -type TaskFunc func() error +type TaskFunc func(ctx context.Context) error // Tasker task interface type Tasker interface { @@ -148,8 +148,8 @@ func (t *Task) GetStatus(context.Context) string { } // Run run all tasks -func (t *Task) Run(context.Context) error { - err := t.DoFunc() +func (t *Task) Run(ctx context.Context) error { + err := t.DoFunc(ctx) if err != nil { index := t.errCnt % t.ErrLimit t.Errlist[index] = &taskerr{t: t.Next, errinfo: err.Error()} From 8ef9965eef3250a5578739258c9f12315ead1771 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Thu, 3 Sep 2020 23:36:09 +0800 Subject: [PATCH 144/207] Adapter: session module --- .../session/couchbase/sess_couchbase.go | 118 ++++++ pkg/adapter/session/ledis/ledis_session.go | 86 +++++ pkg/adapter/session/memcache/sess_memcache.go | 118 ++++++ pkg/adapter/session/mysql/sess_mysql.go | 135 +++++++ .../session/postgres/sess_postgresql.go | 139 ++++++++ pkg/adapter/session/provider_adapter.go | 104 ++++++ pkg/adapter/session/redis/sess_redis.go | 121 +++++++ .../session/redis_cluster/redis_cluster.go | 120 +++++++ .../redis_sentinel/sess_redis_sentinel.go | 121 +++++++ .../sess_redis_sentinel_test.go | 90 +++++ pkg/adapter/session/sess_cookie.go | 114 ++++++ pkg/adapter/session/sess_cookie_test.go | 105 ++++++ pkg/adapter/session/sess_file.go | 106 ++++++ pkg/adapter/session/sess_file_test.go | 336 ++++++++++++++++++ pkg/adapter/session/sess_mem.go | 106 ++++++ pkg/adapter/session/sess_mem_test.go | 58 +++ pkg/adapter/session/sess_test.go | 51 +++ pkg/adapter/session/sess_utils.go | 29 ++ pkg/adapter/session/session.go | 166 +++++++++ pkg/adapter/session/ssdb/sess_ssdb.go | 84 +++++ pkg/adapter/session/store_adapter.go | 84 +++++ pkg/infrastructure/session/sess_cookie.go | 2 +- pkg/infrastructure/session/sess_mem.go | 6 +- 23 files changed, 2395 insertions(+), 4 deletions(-) create mode 100644 pkg/adapter/session/couchbase/sess_couchbase.go create mode 100644 pkg/adapter/session/ledis/ledis_session.go create mode 100644 pkg/adapter/session/memcache/sess_memcache.go create mode 100644 pkg/adapter/session/mysql/sess_mysql.go create mode 100644 pkg/adapter/session/postgres/sess_postgresql.go create mode 100644 pkg/adapter/session/provider_adapter.go create mode 100644 pkg/adapter/session/redis/sess_redis.go create mode 100644 pkg/adapter/session/redis_cluster/redis_cluster.go create mode 100644 pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go create mode 100644 pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go create mode 100644 pkg/adapter/session/sess_cookie.go create mode 100644 pkg/adapter/session/sess_cookie_test.go create mode 100644 pkg/adapter/session/sess_file.go create mode 100644 pkg/adapter/session/sess_file_test.go create mode 100644 pkg/adapter/session/sess_mem.go create mode 100644 pkg/adapter/session/sess_mem_test.go create mode 100644 pkg/adapter/session/sess_test.go create mode 100644 pkg/adapter/session/sess_utils.go create mode 100644 pkg/adapter/session/session.go create mode 100644 pkg/adapter/session/ssdb/sess_ssdb.go create mode 100644 pkg/adapter/session/store_adapter.go diff --git a/pkg/adapter/session/couchbase/sess_couchbase.go b/pkg/adapter/session/couchbase/sess_couchbase.go new file mode 100644 index 00000000..bce09641 --- /dev/null +++ b/pkg/adapter/session/couchbase/sess_couchbase.go @@ -0,0 +1,118 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package couchbase for session provider +// +// depend on github.com/couchbaselabs/go-couchbasee +// +// go install github.com/couchbaselabs/go-couchbase +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/couchbase" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("couchbase", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"http://host:port/, Pool, Bucket"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package couchbase + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + beecb "github.com/astaxie/beego/pkg/infrastructure/session/couchbase" +) + +// SessionStore store each session +type SessionStore beecb.SessionStore + +// Provider couchabse provided +type Provider beecb.Provider + +// Set value to couchabse session +func (cs *SessionStore) Set(key, value interface{}) error { + return (*beecb.SessionStore)(cs).Set(context.Background(), key, value) +} + +// Get value from couchabse session +func (cs *SessionStore) Get(key interface{}) interface{} { + return (*beecb.SessionStore)(cs).Get(context.Background(), key) +} + +// Delete value in couchbase session by given key +func (cs *SessionStore) Delete(key interface{}) error { + return (*beecb.SessionStore)(cs).Delete(context.Background(), key) +} + +// Flush Clean all values in couchbase session +func (cs *SessionStore) Flush() error { + return (*beecb.SessionStore)(cs).Flush(context.Background()) +} + +// SessionID Get couchbase session store id +func (cs *SessionStore) SessionID() string { + return (*beecb.SessionStore)(cs).SessionID(context.Background()) +} + +// SessionRelease Write couchbase session with Gob string +func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beecb.SessionStore)(cs).SessionRelease(context.Background(), w) +} + +// SessionInit init couchbase session +// savepath like couchbase server REST/JSON URL +// e.g. http://host:port/, Pool, Bucket +func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*beecb.Provider)(cp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read couchbase session by sid +func (cp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beecb.Provider)(cp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist Check couchbase session exist. +// it checkes sid exist or not. +func (cp *Provider) SessionExist(sid string) bool { + res, _ := (*beecb.Provider)(cp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate remove oldsid and use sid to generate new session +func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beecb.Provider)(cp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy Remove bucket in this couchbase +func (cp *Provider) SessionDestroy(sid string) error { + return (*beecb.Provider)(cp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Recycle +func (cp *Provider) SessionGC() { + (*beecb.Provider)(cp).SessionGC(context.Background()) +} + +// SessionAll return all active session +func (cp *Provider) SessionAll() int { + return (*beecb.Provider)(cp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/ledis/ledis_session.go b/pkg/adapter/session/ledis/ledis_session.go new file mode 100644 index 00000000..96198837 --- /dev/null +++ b/pkg/adapter/session/ledis/ledis_session.go @@ -0,0 +1,86 @@ +// Package ledis provide session Provider +package ledis + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + beeLedis "github.com/astaxie/beego/pkg/infrastructure/session/ledis" +) + +// SessionStore ledis session store +type SessionStore beeLedis.SessionStore + +// Set value in ledis session +func (ls *SessionStore) Set(key, value interface{}) error { + return (*beeLedis.SessionStore)(ls).Set(context.Background(), key, value) +} + +// Get value in ledis session +func (ls *SessionStore) Get(key interface{}) interface{} { + return (*beeLedis.SessionStore)(ls).Get(context.Background(), key) +} + +// Delete value in ledis session +func (ls *SessionStore) Delete(key interface{}) error { + return (*beeLedis.SessionStore)(ls).Delete(context.Background(), key) +} + +// Flush clear all values in ledis session +func (ls *SessionStore) Flush() error { + return (*beeLedis.SessionStore)(ls).Flush(context.Background()) +} + +// SessionID get ledis session id +func (ls *SessionStore) SessionID() string { + return (*beeLedis.SessionStore)(ls).SessionID(context.Background()) +} + +// SessionRelease save session values to ledis +func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beeLedis.SessionStore)(ls).SessionRelease(context.Background(), w) +} + +// Provider ledis session provider +type Provider beeLedis.Provider + +// SessionInit init ledis session +// savepath like ledis server saveDataPath,pool size +// e.g. 127.0.0.1:6379,100,astaxie +func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*beeLedis.Provider)(lp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read ledis session by sid +func (lp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beeLedis.Provider)(lp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check ledis session exist by sid +func (lp *Provider) SessionExist(sid string) bool { + res, _ := (*beeLedis.Provider)(lp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for ledis session +func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beeLedis.Provider)(lp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete ledis session by id +func (lp *Provider) SessionDestroy(sid string) error { + return (*beeLedis.Provider)(lp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (lp *Provider) SessionGC() { + (*beeLedis.Provider)(lp).SessionGC(context.Background()) +} + +// SessionAll return all active session +func (lp *Provider) SessionAll() int { + return (*beeLedis.Provider)(lp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/memcache/sess_memcache.go b/pkg/adapter/session/memcache/sess_memcache.go new file mode 100644 index 00000000..8afa79aa --- /dev/null +++ b/pkg/adapter/session/memcache/sess_memcache.go @@ -0,0 +1,118 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memcache for session provider +// +// depend on github.com/bradfitz/gomemcache/memcache +// +// go install github.com/bradfitz/gomemcache/memcache +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/memcache" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("memcache", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:11211"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package memcache + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + beemem "github.com/astaxie/beego/pkg/infrastructure/session/memcache" +) + +// SessionStore memcache session store +type SessionStore beemem.SessionStore + +// Set value in memcache session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*beemem.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in memcache session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*beemem.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in memcache session +func (rs *SessionStore) Delete(key interface{}) error { + return (*beemem.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in memcache session +func (rs *SessionStore) Flush() error { + return (*beemem.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get memcache session id +func (rs *SessionStore) SessionID() string { + return (*beemem.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to memcache +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beemem.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// MemProvider memcache session provider +type MemProvider beemem.MemProvider + +// SessionInit init memcache session +// savepath like +// e.g. 127.0.0.1:9090 +func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + return (*beemem.MemProvider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read memcache session by sid +func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { + s, err := (*beemem.MemProvider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check memcache session exist by sid +func (rp *MemProvider) SessionExist(sid string) bool { + res, _ := (*beemem.MemProvider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for memcache session +func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beemem.MemProvider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete memcache session by id +func (rp *MemProvider) SessionDestroy(sid string) error { + return (*beemem.MemProvider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *MemProvider) SessionGC() { + (*beemem.MemProvider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *MemProvider) SessionAll() int { + return (*beemem.MemProvider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/mysql/sess_mysql.go b/pkg/adapter/session/mysql/sess_mysql.go new file mode 100644 index 00000000..1850a380 --- /dev/null +++ b/pkg/adapter/session/mysql/sess_mysql.go @@ -0,0 +1,135 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mysql for session provider +// +// depends on github.com/go-sql-driver/mysql: +// +// go install github.com/go-sql-driver/mysql +// +// mysql session support need create table as sql: +// CREATE TABLE `session` ( +// `session_key` char(64) NOT NULL, +// `session_data` blob, +// `session_expiry` int(11) unsigned NOT NULL, +// PRIMARY KEY (`session_key`) +// ) ENGINE=MyISAM DEFAULT CHARSET=utf8; +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/mysql" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("mysql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package mysql + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + "github.com/astaxie/beego/pkg/infrastructure/session/mysql" + + // import mysql driver + _ "github.com/go-sql-driver/mysql" +) + +var ( + // TableName store the session in MySQL + TableName = mysql.TableName + mysqlpder = &Provider{} +) + +// SessionStore mysql session store +type SessionStore mysql.SessionStore + +// Set value in mysql session. +// it is temp value in map. +func (st *SessionStore) Set(key, value interface{}) error { + return (*mysql.SessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from mysql session +func (st *SessionStore) Get(key interface{}) interface{} { + return (*mysql.SessionStore)(st).Get(context.Background(), key) +} + +// Delete value in mysql session +func (st *SessionStore) Delete(key interface{}) error { + return (*mysql.SessionStore)(st).Delete(context.Background(), key) +} + +// Flush clear all values in mysql session +func (st *SessionStore) Flush() error { + return (*mysql.SessionStore)(st).Flush(context.Background()) +} + +// SessionID get session id of this mysql session store +func (st *SessionStore) SessionID() string { + return (*mysql.SessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease save mysql session values to database. +// must call this method to save values to database. +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { + (*mysql.SessionStore)(st).SessionRelease(context.Background(), w) +} + +// Provider mysql session provider +type Provider mysql.Provider + +// SessionInit init mysql session. +// savepath is the connection string of mysql. +func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*mysql.Provider)(mp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead get mysql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*mysql.Provider)(mp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check mysql session exist +func (mp *Provider) SessionExist(sid string) bool { + res, _ := (*mysql.Provider)(mp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for mysql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*mysql.Provider)(mp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete mysql session by sid +func (mp *Provider) SessionDestroy(sid string) error { + return (*mysql.Provider)(mp).SessionDestroy(context.Background(), sid) +} + +// SessionGC delete expired values in mysql session +func (mp *Provider) SessionGC() { + (*mysql.Provider)(mp).SessionGC(context.Background()) +} + +// SessionAll count values in mysql session +func (mp *Provider) SessionAll() int { + return (*mysql.Provider)(mp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/postgres/sess_postgresql.go b/pkg/adapter/session/postgres/sess_postgresql.go new file mode 100644 index 00000000..de1adbc4 --- /dev/null +++ b/pkg/adapter/session/postgres/sess_postgresql.go @@ -0,0 +1,139 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres for session provider +// +// depends on github.com/lib/pq: +// +// go install github.com/lib/pq +// +// +// needs this table in your database: +// +// CREATE TABLE session ( +// session_key char(64) NOT NULL, +// session_data bytea, +// session_expiry timestamp NOT NULL, +// CONSTRAINT session_key PRIMARY KEY(session_key) +// ); +// +// will be activated with these settings in app.conf: +// +// SessionOn = true +// SessionProvider = postgresql +// SessionSavePath = "user=a password=b dbname=c sslmode=disable" +// SessionName = session +// +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/postgresql" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("postgresql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"user=pqgotest dbname=pqgotest sslmode=verify-full"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package postgres + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + // import postgresql Driver + _ "github.com/lib/pq" + + "github.com/astaxie/beego/pkg/infrastructure/session/postgres" +) + +// SessionStore postgresql session store +type SessionStore postgres.SessionStore + +// Set value in postgresql session. +// it is temp value in map. +func (st *SessionStore) Set(key, value interface{}) error { + return (*postgres.SessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from postgresql session +func (st *SessionStore) Get(key interface{}) interface{} { + return (*postgres.SessionStore)(st).Get(context.Background(), key) +} + +// Delete value in postgresql session +func (st *SessionStore) Delete(key interface{}) error { + return (*postgres.SessionStore)(st).Delete(context.Background(), key) +} + +// Flush clear all values in postgresql session +func (st *SessionStore) Flush() error { + return (*postgres.SessionStore)(st).Flush(context.Background()) +} + +// SessionID get session id of this postgresql session store +func (st *SessionStore) SessionID() string { + return (*postgres.SessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease save postgresql session values to database. +// must call this method to save values to database. +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { + (*postgres.SessionStore)(st).SessionRelease(context.Background(), w) +} + +// Provider postgresql session provider +type Provider postgres.Provider + +// SessionInit init postgresql session. +// savepath is the connection string of postgresql. +func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*postgres.Provider)(mp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead get postgresql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*postgres.Provider)(mp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check postgresql session exist +func (mp *Provider) SessionExist(sid string) bool { + res, _ := (*postgres.Provider)(mp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for postgresql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*postgres.Provider)(mp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete postgresql session by sid +func (mp *Provider) SessionDestroy(sid string) error { + return (*postgres.Provider)(mp).SessionDestroy(context.Background(), sid) +} + +// SessionGC delete expired values in postgresql session +func (mp *Provider) SessionGC() { + (*postgres.Provider)(mp).SessionGC(context.Background()) +} + +// SessionAll count values in postgresql session +func (mp *Provider) SessionAll() int { + return (*postgres.Provider)(mp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/provider_adapter.go b/pkg/adapter/session/provider_adapter.go new file mode 100644 index 00000000..11177a4d --- /dev/null +++ b/pkg/adapter/session/provider_adapter.go @@ -0,0 +1,104 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +type oldToNewProviderAdapter struct { + delegate Provider +} + +func (o *oldToNewProviderAdapter) SessionInit(ctx context.Context, gclifetime int64, config string) error { + return o.delegate.SessionInit(gclifetime, config) +} + +func (o *oldToNewProviderAdapter) SessionRead(ctx context.Context, sid string) (session.Store, error) { + store, err := o.delegate.SessionRead(sid) + return &oldToNewStoreAdapter{ + delegate: store, + }, err +} + +func (o *oldToNewProviderAdapter) SessionExist(ctx context.Context, sid string) (bool, error) { + return o.delegate.SessionExist(sid), nil +} + +func (o *oldToNewProviderAdapter) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { + s, err := o.delegate.SessionRegenerate(oldsid, sid) + return &oldToNewStoreAdapter{ + delegate: s, + }, err +} + +func (o *oldToNewProviderAdapter) SessionDestroy(ctx context.Context, sid string) error { + return o.delegate.SessionDestroy(sid) +} + +func (o *oldToNewProviderAdapter) SessionAll(ctx context.Context) int { + return o.delegate.SessionAll() +} + +func (o *oldToNewProviderAdapter) SessionGC(ctx context.Context) { + o.delegate.SessionGC() +} + +type newToOldProviderAdapter struct { + delegate session.Provider +} + +func (n *newToOldProviderAdapter) SessionInit(gclifetime int64, config string) error { + return n.delegate.SessionInit(context.Background(), gclifetime, config) +} + +func (n *newToOldProviderAdapter) SessionRead(sid string) (Store, error) { + s, err := n.delegate.SessionRead(context.Background(), sid) + if adt, ok := s.(*oldToNewStoreAdapter); err == nil && ok { + return adt.delegate, err + } + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +func (n *newToOldProviderAdapter) SessionExist(sid string) bool { + res, _ := n.delegate.SessionExist(context.Background(), sid) + return res +} + +func (n *newToOldProviderAdapter) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := n.delegate.SessionRegenerate(context.Background(), oldsid, sid) + if adt, ok := s.(*oldToNewStoreAdapter); err == nil && ok { + return adt.delegate, err + } + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +func (n *newToOldProviderAdapter) SessionDestroy(sid string) error { + return n.delegate.SessionDestroy(context.Background(), sid) +} + +func (n *newToOldProviderAdapter) SessionAll() int { + return n.delegate.SessionAll(context.Background()) +} + +func (n *newToOldProviderAdapter) SessionGC() { + n.delegate.SessionGC(context.Background()) +} diff --git a/pkg/adapter/session/redis/sess_redis.go b/pkg/adapter/session/redis/sess_redis.go new file mode 100644 index 00000000..6c521e50 --- /dev/null +++ b/pkg/adapter/session/redis/sess_redis.go @@ -0,0 +1,121 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/gomodule/redigo/redis +// +// go install github.com/gomodule/redigo/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + beeRedis "github.com/astaxie/beego/pkg/infrastructure/session/redis" +) + +// MaxPoolSize redis max pool size +var MaxPoolSize = beeRedis.MaxPoolSize + +// SessionStore redis session store +type SessionStore beeRedis.SessionStore + +// Set value in redis session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*beeRedis.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in redis session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*beeRedis.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in redis session +func (rs *SessionStore) Delete(key interface{}) error { + return (*beeRedis.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in redis session +func (rs *SessionStore) Flush() error { + return (*beeRedis.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get redis session id +func (rs *SessionStore) SessionID() string { + return (*beeRedis.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to redis +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beeRedis.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// Provider redis session provider +type Provider beeRedis.Provider + +// SessionInit init redis session +// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second +// e.g. 127.0.0.1:6379,100,astaxie,0,30 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*beeRedis.Provider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read redis session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beeRedis.Provider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check redis session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + res, _ := (*beeRedis.Provider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for redis session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beeRedis.Provider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + return (*beeRedis.Provider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { + (*beeRedis.Provider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return (*beeRedis.Provider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/redis_cluster/redis_cluster.go b/pkg/adapter/session/redis_cluster/redis_cluster.go new file mode 100644 index 00000000..03a805e4 --- /dev/null +++ b/pkg/adapter/session/redis_cluster/redis_cluster.go @@ -0,0 +1,120 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_cluster" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis_cluster + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + cluster "github.com/astaxie/beego/pkg/infrastructure/session/redis_cluster" +) + +// MaxPoolSize redis_cluster max pool size +var MaxPoolSize = cluster.MaxPoolSize + +// SessionStore redis_cluster session store +type SessionStore cluster.SessionStore + +// Set value in redis_cluster session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*cluster.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in redis_cluster session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*cluster.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in redis_cluster session +func (rs *SessionStore) Delete(key interface{}) error { + return (*cluster.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in redis_cluster session +func (rs *SessionStore) Flush() error { + return (*cluster.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get redis_cluster session id +func (rs *SessionStore) SessionID() string { + return (*cluster.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to redis_cluster +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*cluster.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// Provider redis_cluster session provider +type Provider cluster.Provider + +// SessionInit init redis_cluster session +// savepath like redis server addr,pool size,password,dbnum +// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*cluster.Provider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read redis_cluster session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*cluster.Provider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check redis_cluster session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + res, _ := (*cluster.Provider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for redis_cluster session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*cluster.Provider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + return (*cluster.Provider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { + (*cluster.Provider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return (*cluster.Provider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go new file mode 100644 index 00000000..f5eb8a4f --- /dev/null +++ b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go @@ -0,0 +1,121 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_sentinel" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``) +// go globalSessions.GC() +// } +// +// more detail about params: please check the notes on the function SessionInit in this package +package redis_sentinel + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + sentinel "github.com/astaxie/beego/pkg/infrastructure/session/redis_sentinel" +) + +// DefaultPoolSize redis_sentinel default pool size +var DefaultPoolSize = sentinel.DefaultPoolSize + +// SessionStore redis_sentinel session store +type SessionStore sentinel.SessionStore + +// Set value in redis_sentinel session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*sentinel.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in redis_sentinel session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*sentinel.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in redis_sentinel session +func (rs *SessionStore) Delete(key interface{}) error { + return (*sentinel.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in redis_sentinel session +func (rs *SessionStore) Flush() error { + return (*sentinel.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get redis_sentinel session id +func (rs *SessionStore) SessionID() string { + return (*sentinel.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to redis_sentinel +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*sentinel.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// Provider redis_sentinel session provider +type Provider sentinel.Provider + +// SessionInit init redis_sentinel session +// savepath like redis sentinel addr,pool size,password,dbnum,masterName +// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*sentinel.Provider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read redis_sentinel session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*sentinel.Provider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check redis_sentinel session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + res, _ := (*sentinel.Provider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for redis_sentinel session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*sentinel.Provider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + return (*sentinel.Provider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { + (*sentinel.Provider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return (*sentinel.Provider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go new file mode 100644 index 00000000..7c33985f --- /dev/null +++ b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go @@ -0,0 +1,90 @@ +package redis_sentinel + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/astaxie/beego/pkg/adapter/session" +) + +func TestRedisSentinel(t *testing.T) { + sessionConfig := &session.ManagerConfig{ + CookieName: "gosessionid", + EnableSetCookie: true, + Gclifetime: 3600, + Maxlifetime: 3600, + Secure: false, + CookieLifeTime: 3600, + ProviderConfig: "127.0.0.1:6379,100,,0,master", + } + globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) + if e != nil { + t.Log(e) + return + } + // todo test if e==nil + go globalSessions.GC() + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start failed:", err) + } + defer sess.SessionRelease(w) + + // SET AND GET + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set username failed:", err) + } + username := sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + + // DELETE + err = sess.Delete("username") + if err != nil { + t.Fatal("delete username failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("delete username failed") + } + + // FLUSH + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set failed:", err) + } + err = sess.Set("password", "1qaz2wsx") + if err != nil { + t.Fatal("set failed:", err) + } + username = sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + password := sess.Get("password") + if password != "1qaz2wsx" { + t.Fatal("get password failed") + } + err = sess.Flush() + if err != nil { + t.Fatal("flush failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("flush failed") + } + password = sess.Get("password") + if password != nil { + t.Fatal("flush failed") + } + + sess.SessionRelease(w) + +} diff --git a/pkg/adapter/session/sess_cookie.go b/pkg/adapter/session/sess_cookie.go new file mode 100644 index 00000000..32216040 --- /dev/null +++ b/pkg/adapter/session/sess_cookie.go @@ -0,0 +1,114 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// CookieSessionStore Cookie SessionStore +type CookieSessionStore session.CookieSessionStore + +// Set value to cookie session. +// the value are encoded as gob with hash block string. +func (st *CookieSessionStore) Set(key, value interface{}) error { + return (*session.CookieSessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from cookie session +func (st *CookieSessionStore) Get(key interface{}) interface{} { + return (*session.CookieSessionStore)(st).Get(context.Background(), key) +} + +// Delete value in cookie session +func (st *CookieSessionStore) Delete(key interface{}) error { + return (*session.CookieSessionStore)(st).Delete(context.Background(), key) +} + +// Flush Clean all values in cookie session +func (st *CookieSessionStore) Flush() error { + return (*session.CookieSessionStore)(st).Flush(context.Background()) +} + +// SessionID Return id of this cookie session +func (st *CookieSessionStore) SessionID() string { + return (*session.CookieSessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease Write cookie session to http response cookie +func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { + (*session.CookieSessionStore)(st).SessionRelease(context.Background(), w) +} + +// CookieProvider Cookie session provider +type CookieProvider session.CookieProvider + +// SessionInit Init cookie session provider with max lifetime and config json. +// maxlifetime is ignored. +// json config: +// securityKey - hash string +// blockKey - gob encode hash string. it's saved as aes crypto. +// securityName - recognized name in encoded cookie string +// cookieName - cookie name +// maxage - cookie max life time. +func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { + return (*session.CookieProvider)(pder).SessionInit(context.Background(), maxlifetime, config) +} + +// SessionRead Get SessionStore in cooke. +// decode cooke string to map and put into SessionStore with sid. +func (pder *CookieProvider) SessionRead(sid string) (Store, error) { + s, err := (*session.CookieProvider)(pder).SessionRead(context.Background(), sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionExist Cookie session is always existed +func (pder *CookieProvider) SessionExist(sid string) bool { + res, _ := (*session.CookieProvider)(pder).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate Implement method, no used. +func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := (*session.CookieProvider)(pder).SessionRegenerate(context.Background(), oldsid, sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionDestroy Implement method, no used. +func (pder *CookieProvider) SessionDestroy(sid string) error { + return (*session.CookieProvider)(pder).SessionDestroy(context.Background(), sid) +} + +// SessionGC Implement method, no used. +func (pder *CookieProvider) SessionGC() { + (*session.CookieProvider)(pder).SessionGC(context.Background()) +} + +// SessionAll Implement method, return 0. +func (pder *CookieProvider) SessionAll() int { + return (*session.CookieProvider)(pder).SessionAll(context.Background()) +} + +// SessionUpdate Implement method, no used. +func (pder *CookieProvider) SessionUpdate(sid string) error { + return (*session.CookieProvider)(pder).SessionUpdate(context.Background(), sid) +} diff --git a/pkg/adapter/session/sess_cookie_test.go b/pkg/adapter/session/sess_cookie_test.go new file mode 100644 index 00000000..b6726005 --- /dev/null +++ b/pkg/adapter/session/sess_cookie_test.go @@ -0,0 +1,105 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, err := NewManager("cookie", conf) + if err != nil { + t.Fatal("init cookie session err", err) + } + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + sess.SessionRelease(w) + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} + +func TestDestorySessionCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, err := NewManager("cookie", conf) + if err != nil { + t.Fatal("init cookie session err", err) + } + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + session, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start err,", err) + } + + // request again ,will get same sesssion id . + r1, _ := http.NewRequest("GET", "/", nil) + r1.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + w = httptest.NewRecorder() + newSession, err := globalSessions.SessionStart(w, r1) + if err != nil { + t.Fatal("session start err,", err) + } + if newSession.SessionID() != session.SessionID() { + t.Fatal("get cookie session id is not the same again.") + } + + // After destroy session , will get a new session id . + globalSessions.SessionDestroy(w, r1) + r2, _ := http.NewRequest("GET", "/", nil) + r2.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + + w = httptest.NewRecorder() + newSession, err = globalSessions.SessionStart(w, r2) + if err != nil { + t.Fatal("session start error") + } + if newSession.SessionID() == session.SessionID() { + t.Fatal("after destroy session and reqeust again ,get cookie session id is same.") + } +} diff --git a/pkg/adapter/session/sess_file.go b/pkg/adapter/session/sess_file.go new file mode 100644 index 00000000..b9648998 --- /dev/null +++ b/pkg/adapter/session/sess_file.go @@ -0,0 +1,106 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// FileSessionStore File session store +type FileSessionStore session.FileSessionStore + +// Set value to file session +func (fs *FileSessionStore) Set(key, value interface{}) error { + return (*session.FileSessionStore)(fs).Set(context.Background(), key, value) +} + +// Get value from file session +func (fs *FileSessionStore) Get(key interface{}) interface{} { + return (*session.FileSessionStore)(fs).Get(context.Background(), key) +} + +// Delete value in file session by given key +func (fs *FileSessionStore) Delete(key interface{}) error { + return (*session.FileSessionStore)(fs).Delete(context.Background(), key) +} + +// Flush Clean all values in file session +func (fs *FileSessionStore) Flush() error { + return (*session.FileSessionStore)(fs).Flush(context.Background()) +} + +// SessionID Get file session store id +func (fs *FileSessionStore) SessionID() string { + return (*session.FileSessionStore)(fs).SessionID(context.Background()) +} + +// SessionRelease Write file session to local file with Gob string +func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { + (*session.FileSessionStore)(fs).SessionRelease(context.Background(), w) +} + +// FileProvider File session provider +type FileProvider session.FileProvider + +// SessionInit Init file session provider. +// savePath sets the session files path. +func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { + return (*session.FileProvider)(fp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead Read file session by sid. +// if file is not exist, create it. +// the file path is generated from sid string. +func (fp *FileProvider) SessionRead(sid string) (Store, error) { + s, err := (*session.FileProvider)(fp).SessionRead(context.Background(), sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionExist Check file session exist. +// it checks the file named from sid exist or not. +func (fp *FileProvider) SessionExist(sid string) bool { + res, _ := (*session.FileProvider)(fp).SessionExist(context.Background(), sid) + return res +} + +// SessionDestroy Remove all files in this save path +func (fp *FileProvider) SessionDestroy(sid string) error { + return (*session.FileProvider)(fp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Recycle files in save path +func (fp *FileProvider) SessionGC() { + (*session.FileProvider)(fp).SessionGC(context.Background()) +} + +// SessionAll Get active file session number. +// it walks save path to count files. +func (fp *FileProvider) SessionAll() int { + return (*session.FileProvider)(fp).SessionAll(context.Background()) +} + +// SessionRegenerate Generate new sid for file session. +// it delete old file and create new file named from new sid. +func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := (*session.FileProvider)(fp).SessionRegenerate(context.Background(), oldsid, sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} diff --git a/pkg/adapter/session/sess_file_test.go b/pkg/adapter/session/sess_file_test.go new file mode 100644 index 00000000..4c90a3ac --- /dev/null +++ b/pkg/adapter/session/sess_file_test.go @@ -0,0 +1,336 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "fmt" + "os" + "sync" + "testing" + "time" +) + +const sid = "Session_id" +const sidNew = "Session_id_new" +const sessionPath = "./_session_runtime" + +var ( + mutex sync.Mutex +) + +func TestFileProvider_SessionExist(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionExist2(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + if fp.SessionExist("") { + t.Error() + } + + if fp.SessionExist("1") { + t.Error() + } +} + +func TestFileProvider_SessionRead(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + _ = s.Set("sessionValue", 18975) + v := s.Get("sessionValue") + + if v.(int) != 18975 { + t.Error() + } +} + +func TestFileProvider_SessionRead1(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead("") + if err == nil { + t.Error(err) + } + + _, err = fp.SessionRead("1") + if err == nil { + t.Error(err) + } +} + +func TestFileProvider_SessionAll(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 546 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + if fp.SessionAll() != sessionCount { + t.Error() + } +} + +func TestFileProvider_SessionRegenerate(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + _, err = fp.SessionRegenerate(sid, sidNew) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } + + if !fp.SessionExist(sidNew) { + t.Error() + } +} + +func TestFileProvider_SessionDestroy(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + err = fp.SessionDestroy(sid) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionGC(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(1, sessionPath) + + sessionCount := 412 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + time.Sleep(2 * time.Second) + + fp.SessionGC() + if fp.SessionAll() != 0 { + t.Error() + } +} + +func TestFileSessionStore_Set(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + err := s.Set(i, i) + if err != nil { + t.Error(err) + } + } +} + +func TestFileSessionStore_Get(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + + v := s.Get(i) + if v.(int) != i { + t.Error() + } + } +} + +func TestFileSessionStore_Delete(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, _ := fp.SessionRead(sid) + s.Set("1", 1) + + if s.Get("1") == nil { + t.Error() + } + + s.Delete("1") + + if s.Get("1") != nil { + t.Error() + } +} + +func TestFileSessionStore_Flush(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + } + + _ = s.Flush() + + for i := 1; i <= sessionCount; i++ { + if s.Get(i) != nil { + t.Error() + } + } +} + +func TestFileSessionStore_SessionID(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) { + t.Error(err) + } + } +} diff --git a/pkg/adapter/session/sess_mem.go b/pkg/adapter/session/sess_mem.go new file mode 100644 index 00000000..818c8329 --- /dev/null +++ b/pkg/adapter/session/sess_mem.go @@ -0,0 +1,106 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// MemSessionStore memory session store. +// it saved sessions in a map in memory. +type MemSessionStore session.MemSessionStore + +// Set value to memory session +func (st *MemSessionStore) Set(key, value interface{}) error { + return (*session.MemSessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from memory session by key +func (st *MemSessionStore) Get(key interface{}) interface{} { + return (*session.MemSessionStore)(st).Get(context.Background(), key) +} + +// Delete in memory session by key +func (st *MemSessionStore) Delete(key interface{}) error { + return (*session.MemSessionStore)(st).Delete(context.Background(), key) +} + +// Flush clear all values in memory session +func (st *MemSessionStore) Flush() error { + return (*session.MemSessionStore)(st).Flush(context.Background()) +} + +// SessionID get this id of memory session store +func (st *MemSessionStore) SessionID() string { + return (*session.MemSessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease Implement method, no used. +func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { + (*session.MemSessionStore)(st).SessionRelease(context.Background(), w) +} + +// MemProvider Implement the provider interface +type MemProvider session.MemProvider + +// SessionInit init memory session +func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + return (*session.MemProvider)(pder).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead get memory session store by sid +func (pder *MemProvider) SessionRead(sid string) (Store, error) { + s, err := (*session.MemProvider)(pder).SessionRead(context.Background(), sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionExist check session store exist in memory session by sid +func (pder *MemProvider) SessionExist(sid string) bool { + res, _ := (*session.MemProvider)(pder).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for session store in memory session +func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := (*session.MemProvider)(pder).SessionRegenerate(context.Background(), oldsid, sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionDestroy delete session store in memory session by id +func (pder *MemProvider) SessionDestroy(sid string) error { + return (*session.MemProvider)(pder).SessionDestroy(context.Background(), sid) +} + +// SessionGC clean expired session stores in memory session +func (pder *MemProvider) SessionGC() { + (*session.MemProvider)(pder).SessionGC(context.Background()) +} + +// SessionAll get count number of memory session +func (pder *MemProvider) SessionAll() int { + return (*session.MemProvider)(pder).SessionAll(context.Background()) +} + +// SessionUpdate expand time of session store by id in memory session +func (pder *MemProvider) SessionUpdate(sid string) error { + return (*session.MemProvider)(pder).SessionUpdate(context.Background(), sid) +} diff --git a/pkg/adapter/session/sess_mem_test.go b/pkg/adapter/session/sess_mem_test.go new file mode 100644 index 00000000..2e8934b8 --- /dev/null +++ b/pkg/adapter/session/sess_mem_test.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestMem(t *testing.T) { + config := `{"cookieName":"gosessionid","gclifetime":10, "enableSetCookie":true}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, _ := NewManager("memory", conf) + go globalSessions.GC() + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } + defer sess.SessionRelease(w) + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} diff --git a/pkg/adapter/session/sess_test.go b/pkg/adapter/session/sess_test.go new file mode 100644 index 00000000..aba702ca --- /dev/null +++ b/pkg/adapter/session/sess_test.go @@ -0,0 +1,51 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "testing" +) + +func Test_gob(t *testing.T) { + a := make(map[interface{}]interface{}) + a["username"] = "astaxie" + a[12] = 234 + a["user"] = User{"asta", "xie"} + b, err := EncodeGob(a) + if err != nil { + t.Error(err) + } + c, err := DecodeGob(b) + if err != nil { + t.Error(err) + } + if len(c) == 0 { + t.Error("decodeGob empty") + } + if c["username"] != "astaxie" { + t.Error("decode string error") + } + if c[12] != 234 { + t.Error("decode int error") + } + if c["user"].(User).Username != "asta" { + t.Error("decode struct error") + } +} + +type User struct { + Username string + NickName string +} diff --git a/pkg/adapter/session/sess_utils.go b/pkg/adapter/session/sess_utils.go new file mode 100644 index 00000000..3d107198 --- /dev/null +++ b/pkg/adapter/session/sess_utils.go @@ -0,0 +1,29 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// EncodeGob encode the obj to gob +func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) { + return session.EncodeGob(obj) +} + +// DecodeGob decode data to map +func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) { + return session.DecodeGob(encoded) +} diff --git a/pkg/adapter/session/session.go b/pkg/adapter/session/session.go new file mode 100644 index 00000000..eea2f90e --- /dev/null +++ b/pkg/adapter/session/session.go @@ -0,0 +1,166 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package session provider +// +// Usage: +// import( +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package session + +import ( + "io" + "net/http" + "os" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// Store contains all data for one session process with specific id. +type Store interface { + Set(key, value interface{}) error // set session value + Get(key interface{}) interface{} // get session value + Delete(key interface{}) error // delete session value + SessionID() string // back current sessionID + SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush() error // delete all data +} + +// Provider contains global session methods and saved SessionStores. +// it can operate a SessionStore by its id. +type Provider interface { + SessionInit(gclifetime int64, config string) error + SessionRead(sid string) (Store, error) + SessionExist(sid string) bool + SessionRegenerate(oldsid, sid string) (Store, error) + SessionDestroy(sid string) error + SessionAll() int // get all active session + SessionGC() +} + +// SLogger a helpful variable to log information about session +var SLogger = NewSessionLog(os.Stderr) + +// Register makes a session provide available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, provide Provider) { + session.Register(name, &oldToNewProviderAdapter{ + delegate: provide, + }) +} + +// GetProvider +func GetProvider(name string) (Provider, error) { + res, err := session.GetProvider(name) + if adt, ok := res.(*oldToNewProviderAdapter); err == nil && ok { + return adt.delegate, err + } + + return &newToOldProviderAdapter{ + delegate: res, + }, err +} + +// ManagerConfig define the session config +type ManagerConfig session.ManagerConfig + +// Manager contains Provider and its configuration. +type Manager session.Manager + +// NewManager Create new Manager with provider name and json config string. +// provider name: +// 1. cookie +// 2. file +// 3. memory +// 4. redis +// 5. mysql +// json config: +// 1. is https default false +// 2. hashfunc default sha1 +// 3. hashkey default beegosessionkey +// 4. maxage default is none +func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { + m, err := session.NewManager(provideName, (*session.ManagerConfig)(cf)) + return (*Manager)(m), err +} + +// GetProvider return current manager's provider +func (manager *Manager) GetProvider() Provider { + return &newToOldProviderAdapter{ + delegate: (*session.Manager)(manager).GetProvider(), + } +} + +// SessionStart generate or read the session id from http request. +// if session id exists, return SessionStore with this id. +func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (Store, error) { + s, err := (*session.Manager)(manager).SessionStart(w, r) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionDestroy Destroy session by its id in http request cookie. +func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { + (*session.Manager)(manager).SessionDestroy(w, r) +} + +// GetSessionStore Get SessionStore by its id. +func (manager *Manager) GetSessionStore(sid string) (Store, error) { + s, err := (*session.Manager)(manager).GetSessionStore(sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// GC Start session gc process. +// it can do gc in times after gc lifetime. +func (manager *Manager) GC() { + (*session.Manager)(manager).GC() +} + +// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. +func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) Store { + s := (*session.Manager)(manager).SessionRegenerateID(w, r) + return &NewToOldStoreAdapter{ + delegate: s, + } +} + +// GetActiveSession Get all active sessions count number. +func (manager *Manager) GetActiveSession() int { + return (*session.Manager)(manager).GetActiveSession() +} + +// SetSecure Set cookie with https. +func (manager *Manager) SetSecure(secure bool) { + (*session.Manager)(manager).SetSecure(secure) +} + +// Log implement the log.Logger +type Log session.Log + +// NewSessionLog set io.Writer to create a Logger for session. +func NewSessionLog(out io.Writer) *Log { + return (*Log)(session.NewSessionLog(out)) +} diff --git a/pkg/adapter/session/ssdb/sess_ssdb.go b/pkg/adapter/session/ssdb/sess_ssdb.go new file mode 100644 index 00000000..aee3a364 --- /dev/null +++ b/pkg/adapter/session/ssdb/sess_ssdb.go @@ -0,0 +1,84 @@ +package ssdb + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + beeSsdb "github.com/astaxie/beego/pkg/infrastructure/session/ssdb" +) + +// Provider holds ssdb client and configs +type Provider beeSsdb.Provider + +// SessionInit init the ssdb with the config +func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { + return (*beeSsdb.Provider)(p).SessionInit(context.Background(), maxLifetime, savePath) +} + +// SessionRead return a ssdb client session Store +func (p *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beeSsdb.Provider)(p).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist judged whether sid is exist in session +func (p *Provider) SessionExist(sid string) bool { + res, _ := (*beeSsdb.Provider)(p).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate regenerate session with new sid and delete oldsid +func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beeSsdb.Provider)(p).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy destroy the sid +func (p *Provider) SessionDestroy(sid string) error { + return (*beeSsdb.Provider)(p).SessionDestroy(context.Background(), sid) +} + +// SessionGC not implemented +func (p *Provider) SessionGC() { + (*beeSsdb.Provider)(p).SessionGC(context.Background()) +} + +// SessionAll not implemented +func (p *Provider) SessionAll() int { + return (*beeSsdb.Provider)(p).SessionAll(context.Background()) +} + +// SessionStore holds the session information which stored in ssdb +type SessionStore beeSsdb.SessionStore + +// Set the key and value +func (s *SessionStore) Set(key, value interface{}) error { + return (*beeSsdb.SessionStore)(s).Set(context.Background(), key, value) +} + +// Get return the value by the key +func (s *SessionStore) Get(key interface{}) interface{} { + return (*beeSsdb.SessionStore)(s).Get(context.Background(), key) +} + +// Delete the key in session store +func (s *SessionStore) Delete(key interface{}) error { + return (*beeSsdb.SessionStore)(s).Delete(context.Background(), key) +} + +// Flush delete all keys and values +func (s *SessionStore) Flush() error { + return (*beeSsdb.SessionStore)(s).Flush(context.Background()) +} + +// SessionID return the sessionID +func (s *SessionStore) SessionID() string { + return (*beeSsdb.SessionStore)(s).SessionID(context.Background()) +} + +// SessionRelease Store the keyvalues into ssdb +func (s *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beeSsdb.SessionStore)(s).SessionRelease(context.Background(), w) +} diff --git a/pkg/adapter/session/store_adapter.go b/pkg/adapter/session/store_adapter.go new file mode 100644 index 00000000..c1a03c38 --- /dev/null +++ b/pkg/adapter/session/store_adapter.go @@ -0,0 +1,84 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +type NewToOldStoreAdapter struct { + delegate session.Store +} + +func CreateNewToOldStoreAdapter(s session.Store) Store { + return &NewToOldStoreAdapter{ + delegate: s, + } +} + +func (n *NewToOldStoreAdapter) Set(key, value interface{}) error { + return n.delegate.Set(context.Background(), key, value) +} + +func (n *NewToOldStoreAdapter) Get(key interface{}) interface{} { + return n.delegate.Get(context.Background(), key) +} + +func (n *NewToOldStoreAdapter) Delete(key interface{}) error { + return n.delegate.Delete(context.Background(), key) +} + +func (n *NewToOldStoreAdapter) SessionID() string { + return n.delegate.SessionID(context.Background()) +} + +func (n *NewToOldStoreAdapter) SessionRelease(w http.ResponseWriter) { + n.delegate.SessionRelease(context.Background(), w) +} + +func (n *NewToOldStoreAdapter) Flush() error { + return n.delegate.Flush(context.Background()) +} + +type oldToNewStoreAdapter struct { + delegate Store +} + +func (o *oldToNewStoreAdapter) Set(ctx context.Context, key, value interface{}) error { + return o.delegate.Set(key, value) +} + +func (o *oldToNewStoreAdapter) Get(ctx context.Context, key interface{}) interface{} { + return o.delegate.Get(key) +} + +func (o *oldToNewStoreAdapter) Delete(ctx context.Context, key interface{}) error { + return o.delegate.Delete(key) +} + +func (o *oldToNewStoreAdapter) SessionID(ctx context.Context) string { + return o.delegate.SessionID() +} + +func (o *oldToNewStoreAdapter) SessionRelease(ctx context.Context, w http.ResponseWriter) { + o.delegate.SessionRelease(w) +} + +func (o *oldToNewStoreAdapter) Flush(ctx context.Context) error { + return o.delegate.Flush() +} diff --git a/pkg/infrastructure/session/sess_cookie.go b/pkg/infrastructure/session/sess_cookie.go index ffb19fb7..649f6510 100644 --- a/pkg/infrastructure/session/sess_cookie.go +++ b/pkg/infrastructure/session/sess_cookie.go @@ -172,7 +172,7 @@ func (pder *CookieProvider) SessionAll(context.Context) int { } // SessionUpdate Implement method, no used. -func (pder *CookieProvider) SessionUpdate(sid string) error { +func (pder *CookieProvider) SessionUpdate(ctx context.Context, sid string) error { return nil } diff --git a/pkg/infrastructure/session/sess_mem.go b/pkg/infrastructure/session/sess_mem.go index 9a27c331..27e24c73 100644 --- a/pkg/infrastructure/session/sess_mem.go +++ b/pkg/infrastructure/session/sess_mem.go @@ -96,7 +96,7 @@ func (pder *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, sav func (pder *MemProvider) SessionRead(ctx context.Context, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[sid]; ok { - go pder.SessionUpdate(sid) + go pder.SessionUpdate(nil, sid) pder.lock.RUnlock() return element.Value.(*MemSessionStore), nil } @@ -123,7 +123,7 @@ func (pder *MemProvider) SessionExist(ctx context.Context, sid string) (bool, er func (pder *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[oldsid]; ok { - go pder.SessionUpdate(oldsid) + go pder.SessionUpdate(nil, oldsid) pder.lock.RUnlock() pder.lock.Lock() element.Value.(*MemSessionStore).sid = sid @@ -181,7 +181,7 @@ func (pder *MemProvider) SessionAll(context.Context) int { } // SessionUpdate expand time of session store by id in memory session -func (pder *MemProvider) SessionUpdate(sid string) error { +func (pder *MemProvider) SessionUpdate(ctx context.Context, sid string) error { pder.lock.Lock() defer pder.lock.Unlock() if element, ok := pder.sessions[sid]; ok { From 1dae2c9eb3fbe06c21639409190ec526d66e6e5e Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 2 Sep 2020 22:44:31 +0800 Subject: [PATCH 145/207] Adapter: web module --- pkg/adapter/admin.go | 48 ++++ pkg/adapter/app.go | 261 ++++++++++++++++++++ pkg/adapter/beego.go | 75 ++++++ pkg/adapter/build_info.go | 27 +++ pkg/adapter/config.go | 179 ++++++++++++++ pkg/adapter/controller.go | 401 +++++++++++++++++++++++++++++++ pkg/adapter/error.go | 202 ++++++++++++++++ pkg/adapter/filter.go | 36 +++ pkg/adapter/flash.go | 63 +++++ pkg/adapter/fs.go | 35 +++ pkg/adapter/log.go | 129 ++++++++++ pkg/adapter/namespace.go | 378 +++++++++++++++++++++++++++++ pkg/adapter/policy.go | 57 +++++ pkg/adapter/router.go | 279 +++++++++++++++++++++ pkg/adapter/template.go | 108 +++++++++ pkg/adapter/templatefunc.go | 151 ++++++++++++ pkg/adapter/templatefunc_test.go | 304 +++++++++++++++++++++++ pkg/adapter/tree.go | 49 ++++ pkg/adapter/tree_test.go | 249 +++++++++++++++++++ pkg/server/web/app.go | 2 +- pkg/server/web/config.go | 3 + pkg/server/web/filter.go | 46 +++- pkg/server/web/router.go | 14 +- 23 files changed, 3080 insertions(+), 16 deletions(-) create mode 100644 pkg/adapter/admin.go create mode 100644 pkg/adapter/app.go create mode 100644 pkg/adapter/beego.go create mode 100644 pkg/adapter/build_info.go create mode 100644 pkg/adapter/config.go create mode 100644 pkg/adapter/controller.go create mode 100644 pkg/adapter/error.go create mode 100644 pkg/adapter/filter.go create mode 100644 pkg/adapter/flash.go create mode 100644 pkg/adapter/fs.go create mode 100644 pkg/adapter/log.go create mode 100644 pkg/adapter/namespace.go create mode 100644 pkg/adapter/policy.go create mode 100644 pkg/adapter/router.go create mode 100644 pkg/adapter/template.go create mode 100644 pkg/adapter/templatefunc.go create mode 100644 pkg/adapter/templatefunc_test.go create mode 100644 pkg/adapter/tree.go create mode 100644 pkg/adapter/tree_test.go diff --git a/pkg/adapter/admin.go b/pkg/adapter/admin.go new file mode 100644 index 00000000..87e7259b --- /dev/null +++ b/pkg/adapter/admin.go @@ -0,0 +1,48 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "time" + + "github.com/astaxie/beego/pkg/server/web" +) + +// FilterMonitorFunc is default monitor filter when admin module is enable. +// if this func returns, admin module records qps for this request by condition of this function logic. +// usage: +// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool { +// if method == "POST" { +// return false +// } +// if t.Nanoseconds() < 100 { +// return false +// } +// if strings.HasPrefix(requestPath, "/astaxie") { +// return false +// } +// return true +// } +// beego.FilterMonitorFunc = MyFilterMonitor. +var FilterMonitorFunc func(string, string, time.Duration, string, int) bool + +func init() { + FilterMonitorFunc = web.FilterMonitorFunc +} + +// PrintTree prints all registered routers. +func PrintTree() M { + return (M)(web.PrintTree()) +} diff --git a/pkg/adapter/app.go b/pkg/adapter/app.go new file mode 100644 index 00000000..64280a7b --- /dev/null +++ b/pkg/adapter/app.go @@ -0,0 +1,261 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + + context2 "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/pkg/server/web/context" +) + +var ( + // BeeApp is an application instance + BeeApp *App +) + +func init() { + // create beego application + BeeApp = (*App)(web.BeeApp) +} + +// App defines beego application with a new PatternServeMux. +type App web.App + +// NewApp returns a new beego application. +func NewApp() *App { + return (*App)(web.NewApp()) +} + +// MiddleWare function for http.Handler +type MiddleWare web.MiddleWare + +// Run beego application. +func (app *App) Run(mws ...MiddleWare) { + newMws := oldMiddlewareToNew(mws) + (*web.App)(app).Run(newMws...) +} + +func oldMiddlewareToNew(mws []MiddleWare) []web.MiddleWare { + newMws := make([]web.MiddleWare, 0, len(mws)) + for _, old := range mws { + newMws = append(newMws, (web.MiddleWare)(old)) + } + return newMws +} + +// Router adds a patterned controller handler to BeeApp. +// it's an alias method of App.Router. +// usage: +// simple router +// beego.Router("/admin", &admin.UserController{}) +// beego.Router("/admin/index", &admin.ArticleController{}) +// +// regex router +// +// beego.Router("/api/:id([0-9]+)", &controllers.RController{}) +// +// custom rules +// beego.Router("/api/list",&RestController{},"*:ListFood") +// beego.Router("/api/create",&RestController{},"post:CreateFood") +// beego.Router("/api/update",&RestController{},"put:UpdateFood") +// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") +func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { + return (*App)(web.Router(rootpath, c, mappingMethods...)) +} + +// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful +// in web applications that inherit most routes from a base webapp via the underscore +// import, and aim to overwrite only certain paths. +// The method parameter can be empty or "*" for all HTTP methods, or a particular +// method type (e.g. "GET" or "POST") for selective removal. +// +// Usage (replace "GET" with "*" for all methods): +// beego.UnregisterFixedRoute("/yourpreviouspath", "GET") +// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage") +func UnregisterFixedRoute(fixedRoute string, method string) *App { + return (*App)(web.UnregisterFixedRoute(fixedRoute, method)) +} + +// Include will generate router file in the router/xxx.go from the controller's comments +// usage: +// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +// type BankAccount struct{ +// beego.Controller +// } +// +// register the function +// func (b *BankAccount)Mapping(){ +// b.Mapping("ShowAccount" , b.ShowAccount) +// b.Mapping("ModifyAccount", b.ModifyAccount) +// } +// +// //@router /account/:id [get] +// func (b *BankAccount) ShowAccount(){ +// //logic +// } +// +// +// //@router /account/:id [post] +// func (b *BankAccount) ModifyAccount(){ +// //logic +// } +// +// the comments @router url methodlist +// url support all the function Router's pattern +// methodlist [get post head put delete options *] +func Include(cList ...ControllerInterface) *App { + newList := oldToNewCtrlIntfs(cList) + return (*App)(web.Include(newList...)) +} + +func oldToNewCtrlIntfs(cList []ControllerInterface) []web.ControllerInterface { + newList := make([]web.ControllerInterface, 0, len(cList)) + for _, c := range cList { + newList = append(newList, c) + } + return newList +} + +// RESTRouter adds a restful controller handler to BeeApp. +// its' controller implements beego.ControllerInterface and +// defines a param "pattern/:objectId" to visit each resource. +func RESTRouter(rootpath string, c ControllerInterface) *App { + return (*App)(web.RESTRouter(rootpath, c)) +} + +// AutoRouter adds defined controller handler to BeeApp. +// it's same to App.AutoRouter. +// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, +// visit the url /main/list to exec List function or /main/page to exec Page function. +func AutoRouter(c ControllerInterface) *App { + return (*App)(web.AutoRouter(c)) +} + +// AutoPrefix adds controller handler to BeeApp with prefix. +// it's same to App.AutoRouterWithPrefix. +// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, +// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. +func AutoPrefix(prefix string, c ControllerInterface) *App { + return (*App)(web.AutoPrefix(prefix, c)) +} + +// Get used to register router for Get method +// usage: +// beego.Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Get(rootpath string, f FilterFunc) *App { + return (*App)(web.Get(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Post used to register router for Post method +// usage: +// beego.Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Post(rootpath string, f FilterFunc) *App { + return (*App)(web.Post(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Delete used to register router for Delete method +// usage: +// beego.Delete("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Delete(rootpath string, f FilterFunc) *App { + return (*App)(web.Delete(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Put used to register router for Put method +// usage: +// beego.Put("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Put(rootpath string, f FilterFunc) *App { + return (*App)(web.Put(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Head used to register router for Head method +// usage: +// beego.Head("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Head(rootpath string, f FilterFunc) *App { + return (*App)(web.Head(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Options used to register router for Options method +// usage: +// beego.Options("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Options(rootpath string, f FilterFunc) *App { + return (*App)(web.Options(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Patch used to register router for Patch method +// usage: +// beego.Patch("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Patch(rootpath string, f FilterFunc) *App { + return (*App)(web.Patch(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Any used to register router for all methods +// usage: +// beego.Any("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Any(rootpath string, f FilterFunc) *App { + return (*App)(web.Any(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Handler used to register a Handler router +// usage: +// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { +// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) +// })) +func Handler(rootpath string, h http.Handler, options ...interface{}) *App { + return (*App)(web.Handler(rootpath, h, options)) +} + +// InsertFilter adds a FilterFunc with pattern condition and action constant. +// The pos means action constant including +// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. +// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) +func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { + return (*App)(web.InsertFilter(pattern, pos, func(ctx *context.Context) { + filter((*context2.Context)(ctx)) + }, params...)) +} diff --git a/pkg/adapter/beego.go b/pkg/adapter/beego.go new file mode 100644 index 00000000..efd2d4ea --- /dev/null +++ b/pkg/adapter/beego.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/server/web" +) + +const ( + // VERSION represent beego web framework version. + VERSION = web.VERSION + + // DEV is for develop + DEV = web.DEV + // PROD is for production + PROD = web.PROD +) + +// M is Map shortcut +type M web.M + +// Hook function to run +type hookfunc func() error + +var ( + hooks = make([]hookfunc, 0) // hook function slice to store the hookfunc +) + +// AddAPPStartHook is used to register the hookfunc +// The hookfuncs will run in beego.Run() +// such as initiating session , starting middleware , building template, starting admin control and so on. +func AddAPPStartHook(hf ...hookfunc) { + for _, f := range hf { + web.AddAPPStartHook(func() error { + return f() + }) + } +} + +// Run beego application. +// beego.Run() default run on HttpPort +// beego.Run("localhost") +// beego.Run(":8089") +// beego.Run("127.0.0.1:8089") +func Run(params ...string) { + web.Run(params...) +} + +// RunWithMiddleWares Run beego application with middlewares. +func RunWithMiddleWares(addr string, mws ...MiddleWare) { + newMws := oldMiddlewareToNew(mws) + web.RunWithMiddleWares(addr, newMws...) +} + +// TestBeegoInit is for test package init +func TestBeegoInit(ap string) { + web.TestBeegoInit(ap) +} + +// InitBeegoBeforeTest is for test package init +func InitBeegoBeforeTest(appConfigPath string) { + web.InitBeegoBeforeTest(appConfigPath) +} diff --git a/pkg/adapter/build_info.go b/pkg/adapter/build_info.go new file mode 100644 index 00000000..1e8dacf0 --- /dev/null +++ b/pkg/adapter/build_info.go @@ -0,0 +1,27 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +var ( + BuildVersion string + BuildGitRevision string + BuildStatus string + BuildTag string + BuildTime string + + GoVersion string + + GitBranch string +) diff --git a/pkg/adapter/config.go b/pkg/adapter/config.go new file mode 100644 index 00000000..1491722c --- /dev/null +++ b/pkg/adapter/config.go @@ -0,0 +1,179 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + context2 "context" + + "github.com/astaxie/beego/pkg/adapter/session" + newCfg "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/server/web" +) + +// Config is the main struct for BConfig +type Config web.Config + +// Listen holds for http and https related config +type Listen web.Listen + +// WebConfig holds web related config +type WebConfig web.WebConfig + +// SessionConfig holds session related config +type SessionConfig web.SessionConfig + +// LogConfig holds Log related config +type LogConfig web.LogConfig + +var ( + // BConfig is the default config for Application + BConfig *Config + // AppConfig is the instance of Config, store the config information from file + AppConfig *beegoAppConfig + // AppPath is the absolute path to the app + AppPath string + // GlobalSessions is the instance for the session manager + GlobalSessions *session.Manager + + // appConfigPath is the path to the config files + appConfigPath string + // appConfigProvider is the provider for the config, default is ini + appConfigProvider = "ini" + // WorkPath is the absolute path to project root directory + WorkPath string +) + +func init() { + BConfig = (*Config)(web.BConfig) + AppPath = web.AppPath + + WorkPath = web.WorkPath + + AppConfig = &beegoAppConfig{innerConfig: (newCfg.Configer)(web.AppConfig)} +} + +// LoadAppConfig allow developer to apply a config file +func LoadAppConfig(adapterName, configPath string) error { + return web.LoadAppConfig(adapterName, configPath) +} + +type beegoAppConfig struct { + innerConfig newCfg.Configer +} + +func (b *beegoAppConfig) Set(key, val string) error { + if err := b.innerConfig.Set(context2.Background(), BConfig.RunMode+"::"+key, val); err != nil { + return b.innerConfig.Set(context2.Background(), key, val) + } + return nil +} + +func (b *beegoAppConfig) String(key string) string { + if v, err := b.innerConfig.String(context2.Background(), BConfig.RunMode+"::"+key); v != "" && err != nil { + return v + } + res, _ := b.innerConfig.String(context2.Background(), key) + return res +} + +func (b *beegoAppConfig) Strings(key string) []string { + if v, err := b.innerConfig.Strings(context2.Background(), BConfig.RunMode+"::"+key); len(v) > 0 && err != nil { + return v + } + res, _ := b.innerConfig.Strings(context2.Background(), key) + return res +} + +func (b *beegoAppConfig) Int(key string) (int, error) { + if v, err := b.innerConfig.Int(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Int(context2.Background(), key) +} + +func (b *beegoAppConfig) Int64(key string) (int64, error) { + if v, err := b.innerConfig.Int64(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Int64(context2.Background(), key) +} + +func (b *beegoAppConfig) Bool(key string) (bool, error) { + if v, err := b.innerConfig.Bool(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Bool(context2.Background(), key) +} + +func (b *beegoAppConfig) Float(key string) (float64, error) { + if v, err := b.innerConfig.Float(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Float(context2.Background(), key) +} + +func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { + if v := b.String(key); v != "" { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { + if v := b.Strings(key); len(v) != 0 { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { + if v, err := b.Int(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { + if v, err := b.Int64(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { + if v, err := b.Bool(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { + if v, err := b.Float(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DIY(key string) (interface{}, error) { + return b.innerConfig.DIY(context2.Background(), key) +} + +func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { + return b.innerConfig.GetSection(context2.Background(), section) +} + +func (b *beegoAppConfig) SaveConfigFile(filename string) error { + return b.innerConfig.SaveConfigFile(context2.Background(), filename) +} diff --git a/pkg/adapter/controller.go b/pkg/adapter/controller.go new file mode 100644 index 00000000..010add64 --- /dev/null +++ b/pkg/adapter/controller.go @@ -0,0 +1,401 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "mime/multipart" + "net/url" + + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/adapter/session" + webContext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +var ( + // ErrAbort custom error when user stop request handler manually. + ErrAbort = web.ErrAbort + // GlobalControllerRouter store comments with controller. pkgpath+controller:comments + GlobalControllerRouter = web.GlobalControllerRouter +) + +// ControllerFilter store the filter for controller +type ControllerFilter web.ControllerFilter + +// ControllerFilterComments store the comment for controller level filter +type ControllerFilterComments web.ControllerFilterComments + +// ControllerImportComments store the import comment for controller needed +type ControllerImportComments web.ControllerImportComments + +// ControllerComments store the comment for the controller method +type ControllerComments web.ControllerComments + +// ControllerCommentsSlice implements the sort interface +type ControllerCommentsSlice web.ControllerCommentsSlice + +func (p ControllerCommentsSlice) Len() int { + return (web.ControllerCommentsSlice)(p).Len() +} +func (p ControllerCommentsSlice) Less(i, j int) bool { + return (web.ControllerCommentsSlice)(p).Less(i, j) +} +func (p ControllerCommentsSlice) Swap(i, j int) { + (web.ControllerCommentsSlice)(p).Swap(i, j) +} + +// Controller defines some basic http request handler operations, such as +// http context, template and view, session and xsrf. +type Controller web.Controller + +// ControllerInterface is an interface to uniform all controller handler. +type ControllerInterface web.ControllerInterface + +// Init generates default values of controller operations. +func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { + (*web.Controller)(c).Init((*webContext.Context)(ctx), controllerName, actionName, app) +} + +// Prepare runs after Init before request function execution. +func (c *Controller) Prepare() { + (*web.Controller)(c).Prepare() +} + +// Finish runs after request function execution. +func (c *Controller) Finish() { + (*web.Controller)(c).Finish() +} + +// Get adds a request function to handle GET request. +func (c *Controller) Get() { + (*web.Controller)(c).Get() +} + +// Post adds a request function to handle POST request. +func (c *Controller) Post() { + (*web.Controller)(c).Post() +} + +// Delete adds a request function to handle DELETE request. +func (c *Controller) Delete() { + (*web.Controller)(c).Delete() +} + +// Put adds a request function to handle PUT request. +func (c *Controller) Put() { + (*web.Controller)(c).Put() +} + +// Head adds a request function to handle HEAD request. +func (c *Controller) Head() { + (*web.Controller)(c).Head() +} + +// Patch adds a request function to handle PATCH request. +func (c *Controller) Patch() { + (*web.Controller)(c).Patch() +} + +// Options adds a request function to handle OPTIONS request. +func (c *Controller) Options() { + (*web.Controller)(c).Options() +} + +// Trace adds a request function to handle Trace request. +// this method SHOULD NOT be overridden. +// https://tools.ietf.org/html/rfc7231#section-4.3.8 +// The TRACE method requests a remote, application-level loop-back of +// the request message. The final recipient of the request SHOULD +// reflect the message received, excluding some fields described below, +// back to the client as the message body of a 200 (OK) response with a +// Content-Type of "message/http" (Section 8.3.1 of [RFC7230]). +func (c *Controller) Trace() { + (*web.Controller)(c).Trace() +} + +// HandlerFunc call function with the name +func (c *Controller) HandlerFunc(fnname string) bool { + return (*web.Controller)(c).HandlerFunc(fnname) +} + +// URLMapping register the internal Controller router. +func (c *Controller) URLMapping() { + (*web.Controller)(c).URLMapping() +} + +// Mapping the method to function +func (c *Controller) Mapping(method string, fn func()) { + (*web.Controller)(c).Mapping(method, fn) +} + +// Render sends the response with rendered template bytes as text/html type. +func (c *Controller) Render() error { + return (*web.Controller)(c).Render() +} + +// RenderString returns the rendered template string. Do not send out response. +func (c *Controller) RenderString() (string, error) { + return (*web.Controller)(c).RenderString() +} + +// RenderBytes returns the bytes of rendered template string. Do not send out response. +func (c *Controller) RenderBytes() ([]byte, error) { + return (*web.Controller)(c).RenderBytes() +} + +// Redirect sends the redirection response to url with status code. +func (c *Controller) Redirect(url string, code int) { + (*web.Controller)(c).Redirect(url, code) +} + +// SetData set the data depending on the accepted +func (c *Controller) SetData(data interface{}) { + (*web.Controller)(c).SetData(data) +} + +// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string. +func (c *Controller) Abort(code string) { + (*web.Controller)(c).Abort(code) +} + +// CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. +func (c *Controller) CustomAbort(status int, body string) { + (*web.Controller)(c).CustomAbort(status, body) +} + +// StopRun makes panic of USERSTOPRUN error and go to recover function if defined. +func (c *Controller) StopRun() { + (*web.Controller)(c).StopRun() +} + +// URLFor does another controller handler in this request function. +// it goes to this controller method if endpoint is not clear. +func (c *Controller) URLFor(endpoint string, values ...interface{}) string { + return (*web.Controller)(c).URLFor(endpoint, values...) +} + +// ServeJSON sends a json response with encoding charset. +func (c *Controller) ServeJSON(encoding ...bool) { + (*web.Controller)(c).ServeJSON(encoding...) +} + +// ServeJSONP sends a jsonp response. +func (c *Controller) ServeJSONP() { + (*web.Controller)(c).ServeJSONP() +} + +// ServeXML sends xml response. +func (c *Controller) ServeXML() { + (*web.Controller)(c).ServeXML() +} + +// ServeYAML sends yaml response. +func (c *Controller) ServeYAML() { + (*web.Controller)(c).ServeYAML() +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (c *Controller) ServeFormatted(encoding ...bool) { + (*web.Controller)(c).ServeFormatted(encoding...) +} + +// Input returns the input data map from POST or PUT request body and query string. +func (c *Controller) Input() url.Values { + return (*web.Controller)(c).Input() +} + +// ParseForm maps input data map to obj struct. +func (c *Controller) ParseForm(obj interface{}) error { + return (*web.Controller)(c).ParseForm(obj) +} + +// GetString returns the input value by key string or the default value while it's present and input is blank +func (c *Controller) GetString(key string, def ...string) string { + return (*web.Controller)(c).GetString(key, def...) +} + +// GetStrings returns the input string slice by key string or the default value while it's present and input is blank +// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection. +func (c *Controller) GetStrings(key string, def ...[]string) []string { + return (*web.Controller)(c).GetStrings(key, def...) +} + +// GetInt returns input as an int or the default value while it's present and input is blank +func (c *Controller) GetInt(key string, def ...int) (int, error) { + return (*web.Controller)(c).GetInt(key, def...) +} + +// GetInt8 return input as an int8 or the default value while it's present and input is blank +func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { + return (*web.Controller)(c).GetInt8(key, def...) +} + +// GetUint8 return input as an uint8 or the default value while it's present and input is blank +func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) { + return (*web.Controller)(c).GetUint8(key, def...) +} + +// GetInt16 returns input as an int16 or the default value while it's present and input is blank +func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { + return (*web.Controller)(c).GetInt16(key, def...) +} + +// GetUint16 returns input as an uint16 or the default value while it's present and input is blank +func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) { + return (*web.Controller)(c).GetUint16(key, def...) +} + +// GetInt32 returns input as an int32 or the default value while it's present and input is blank +func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { + return (*web.Controller)(c).GetInt32(key, def...) +} + +// GetUint32 returns input as an uint32 or the default value while it's present and input is blank +func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) { + return (*web.Controller)(c).GetUint32(key, def...) +} + +// GetInt64 returns input value as int64 or the default value while it's present and input is blank. +func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { + return (*web.Controller)(c).GetInt64(key, def...) +} + +// GetUint64 returns input value as uint64 or the default value while it's present and input is blank. +func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) { + return (*web.Controller)(c).GetUint64(key, def...) +} + +// GetBool returns input value as bool or the default value while it's present and input is blank. +func (c *Controller) GetBool(key string, def ...bool) (bool, error) { + return (*web.Controller)(c).GetBool(key, def...) +} + +// GetFloat returns input value as float64 or the default value while it's present and input is blank. +func (c *Controller) GetFloat(key string, def ...float64) (float64, error) { + return (*web.Controller)(c).GetFloat(key, def...) +} + +// GetFile returns the file data in file upload field named as key. +// it returns the first one of multi-uploaded files. +func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, error) { + return (*web.Controller)(c).GetFile(key) +} + +// GetFiles return multi-upload files +// files, err:=c.GetFiles("myfiles") +// if err != nil { +// http.Error(w, err.Error(), http.StatusNoContent) +// return +// } +// for i, _ := range files { +// //for each fileheader, get a handle to the actual file +// file, err := files[i].Open() +// defer file.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //create destination file making sure the path is writeable. +// dst, err := os.Create("upload/" + files[i].Filename) +// defer dst.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //copy the uploaded file to the destination file +// if _, err := io.Copy(dst, file); err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// } +func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { + return (*web.Controller)(c).GetFiles(key) +} + +// SaveToFile saves uploaded file to new path. +// it only operates the first one of mutil-upload form file field. +func (c *Controller) SaveToFile(fromfile, tofile string) error { + return (*web.Controller)(c).SaveToFile(fromfile, tofile) +} + +// StartSession starts session and load old session data info this controller. +func (c *Controller) StartSession() session.Store { + s := (*web.Controller)(c).StartSession() + return session.CreateNewToOldStoreAdapter(s) +} + +// SetSession puts value into session. +func (c *Controller) SetSession(name interface{}, value interface{}) { + (*web.Controller)(c).SetSession(name, value) +} + +// GetSession gets value from session. +func (c *Controller) GetSession(name interface{}) interface{} { + return (*web.Controller)(c).GetSession(name) +} + +// DelSession removes value from session. +func (c *Controller) DelSession(name interface{}) { + (*web.Controller)(c).DelSession(name) +} + +// SessionRegenerateID regenerates session id for this session. +// the session data have no changes. +func (c *Controller) SessionRegenerateID() { + (*web.Controller)(c).SessionRegenerateID() +} + +// DestroySession cleans session data and session cookie. +func (c *Controller) DestroySession() { + (*web.Controller)(c).DestroySession() +} + +// IsAjax returns this request is ajax or not. +func (c *Controller) IsAjax() bool { + return (*web.Controller)(c).IsAjax() +} + +// GetSecureCookie returns decoded cookie value from encoded browser cookie values. +func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) { + return (*web.Controller)(c).GetSecureCookie(Secret, key) +} + +// SetSecureCookie puts value into cookie after encoded the value. +func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) { + (*web.Controller)(c).SetSecureCookie(Secret, name, value, others...) +} + +// XSRFToken creates a CSRF token string and returns. +func (c *Controller) XSRFToken() string { + return (*web.Controller)(c).XSRFToken() +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (c *Controller) CheckXSRFCookie() bool { + return (*web.Controller)(c).CheckXSRFCookie() +} + +// XSRFFormHTML writes an input field contains xsrf token value. +func (c *Controller) XSRFFormHTML() string { + return (*web.Controller)(c).XSRFFormHTML() +} + +// GetControllerAndAction gets the executing controller name and action name. +func (c *Controller) GetControllerAndAction() (string, string) { + return (*web.Controller)(c).GetControllerAndAction() +} diff --git a/pkg/adapter/error.go b/pkg/adapter/error.go new file mode 100644 index 00000000..4f08aa8c --- /dev/null +++ b/pkg/adapter/error.go @@ -0,0 +1,202 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +const ( + errorTypeHandler = iota + errorTypeController +) + +var tpl = ` + + + + + beego application error + + + + + +
+ + + + + + + + + + +
Request Method: {{.RequestMethod}}
Request URL: {{.RequestURL}}
RemoteAddr: {{.RemoteAddr }}
+
+ Stack +
{{.Stack}}
+
+
+ + + +` + +var errtpl = ` + + + + + {{.Title}} + + + +
+
+ +
+ {{.Content}} + Go Home
+ +
Powered by beego {{.BeegoVersion}} +
+
+
+ + +` + +// ErrorMaps holds map of http handlers for each error string. +// there is 10 kinds default error(40x and 50x) +var ErrorMaps = web.ErrorMaps + +// ErrorHandler registers http.HandlerFunc to each http err code string. +// usage: +// beego.ErrorHandler("404",NotFound) +// beego.ErrorHandler("500",InternalServerError) +func ErrorHandler(code string, h http.HandlerFunc) *App { + return (*App)(web.ErrorHandler(code, h)) +} + +// ErrorController registers ControllerInterface to each http err code string. +// usage: +// beego.ErrorController(&controllers.ErrorController{}) +func ErrorController(c ControllerInterface) *App { + return (*App)(web.ErrorController(c)) +} + +// Exception Write HttpStatus with errCode and Exec error handler if exist. +func Exception(errCode uint64, ctx *context.Context) { + web.Exception(errCode, (*beecontext.Context)(ctx)) +} diff --git a/pkg/adapter/filter.go b/pkg/adapter/filter.go new file mode 100644 index 00000000..cafed773 --- /dev/null +++ b/pkg/adapter/filter.go @@ -0,0 +1,36 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web" + beecontext "github.com/astaxie/beego/pkg/server/web/context" +) + +// FilterFunc defines a filter function which is invoked before the controller handler is executed. +type FilterFunc func(*context.Context) + +// FilterRouter defines a filter operation which is invoked before the controller handler is executed. +// It can match the URL against a pattern, and execute a filter function +// when a request with a matching URL arrives. +type FilterRouter web.FilterRouter + +// ValidRouter checks if the current request is matched by this filter. +// If the request is matched, the values of the URL parameters defined +// by the filter pattern are also returned. +func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { + return (*web.FilterRouter)(f).ValidRouter(url, (*beecontext.Context)(ctx)) +} diff --git a/pkg/adapter/flash.go b/pkg/adapter/flash.go new file mode 100644 index 00000000..e5e1c187 --- /dev/null +++ b/pkg/adapter/flash.go @@ -0,0 +1,63 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/server/web" +) + +// FlashData is a tools to maintain data when using across request. +type FlashData web.FlashData + +// NewFlash return a new empty FlashData struct. +func NewFlash() *FlashData { + return (*FlashData)(web.NewFlash()) +} + +// Set message to flash +func (fd *FlashData) Set(key string, msg string, args ...interface{}) { + (*web.FlashData)(fd).Set(key, msg, args) +} + +// Success writes success message to flash. +func (fd *FlashData) Success(msg string, args ...interface{}) { + (*web.FlashData)(fd).Success(msg, args...) +} + +// Notice writes notice message to flash. +func (fd *FlashData) Notice(msg string, args ...interface{}) { + (*web.FlashData)(fd).Notice(msg, args...) +} + +// Warning writes warning message to flash. +func (fd *FlashData) Warning(msg string, args ...interface{}) { + (*web.FlashData)(fd).Warning(msg, args...) +} + +// Error writes error message to flash. +func (fd *FlashData) Error(msg string, args ...interface{}) { + (*web.FlashData)(fd).Error(msg, args...) +} + +// Store does the saving operation of flash data. +// the data are encoded and saved in cookie. +func (fd *FlashData) Store(c *Controller) { + (*web.FlashData)(fd).Store((*web.Controller)(c)) +} + +// ReadFromRequest parsed flash data from encoded values in cookie. +func ReadFromRequest(c *Controller) *FlashData { + return (*FlashData)(web.ReadFromRequest((*web.Controller)(c))) +} diff --git a/pkg/adapter/fs.go b/pkg/adapter/fs.go new file mode 100644 index 00000000..07054ca3 --- /dev/null +++ b/pkg/adapter/fs.go @@ -0,0 +1,35 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + "path/filepath" + + "github.com/astaxie/beego/pkg/server/web" +) + +type FileSystem web.FileSystem + +func (d FileSystem) Open(name string) (http.File, error) { + return (web.FileSystem)(d).Open(name) +} + +// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or +// directory in the tree, including root. All errors that arise visiting files +// and directories are filtered by walkFn. +func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error { + return web.Walk(fs, root, walkFn) +} diff --git a/pkg/adapter/log.go b/pkg/adapter/log.go new file mode 100644 index 00000000..d9ff6e0c --- /dev/null +++ b/pkg/adapter/log.go @@ -0,0 +1,129 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "strings" + + "github.com/astaxie/beego/pkg/infrastructure/logs" + + webLog "github.com/astaxie/beego/pkg/infrastructure/logs" +) + +// Log levels to control the logging output. +// Deprecated: use github.com/astaxie/beego/logs instead. +const ( + LevelEmergency = webLog.LevelEmergency + LevelAlert = webLog.LevelAlert + LevelCritical = webLog.LevelCritical + LevelError = webLog.LevelError + LevelWarning = webLog.LevelWarning + LevelNotice = webLog.LevelNotice + LevelInformational = webLog.LevelInformational + LevelDebug = webLog.LevelDebug +) + +// BeeLogger references the used application logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +var BeeLogger = logs.GetBeeLogger() + +// SetLevel sets the global log level used by the simple logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLevel(l int) { + logs.SetLevel(l) +} + +// SetLogFuncCall set the CallDepth, default is 3 +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLogFuncCall(b bool) { + logs.SetLogFuncCall(b) +} + +// SetLogger sets a new logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLogger(adaptername string, config string) error { + return logs.SetLogger(adaptername, config) +} + +// Emergency logs a message at emergency level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Emergency(v ...interface{}) { + logs.Emergency(generateFmtStr(len(v)), v...) +} + +// Alert logs a message at alert level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Alert(v ...interface{}) { + logs.Alert(generateFmtStr(len(v)), v...) +} + +// Critical logs a message at critical level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Critical(v ...interface{}) { + logs.Critical(generateFmtStr(len(v)), v...) +} + +// Error logs a message at error level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Error(v ...interface{}) { + logs.Error(generateFmtStr(len(v)), v...) +} + +// Warning logs a message at warning level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Warning(v ...interface{}) { + logs.Warning(generateFmtStr(len(v)), v...) +} + +// Warn compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Warn(v ...interface{}) { + logs.Warn(generateFmtStr(len(v)), v...) +} + +// Notice logs a message at notice level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Notice(v ...interface{}) { + logs.Notice(generateFmtStr(len(v)), v...) +} + +// Informational logs a message at info level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Informational(v ...interface{}) { + logs.Informational(generateFmtStr(len(v)), v...) +} + +// Info compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Info(v ...interface{}) { + logs.Info(generateFmtStr(len(v)), v...) +} + +// Debug logs a message at debug level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Debug(v ...interface{}) { + logs.Debug(generateFmtStr(len(v)), v...) +} + +// Trace logs a message at trace level. +// compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Trace(v ...interface{}) { + logs.Trace(generateFmtStr(len(v)), v...) +} + +func generateFmtStr(n int) string { + return strings.Repeat("%v ", n) +} diff --git a/pkg/adapter/namespace.go b/pkg/adapter/namespace.go new file mode 100644 index 00000000..609402cf --- /dev/null +++ b/pkg/adapter/namespace.go @@ -0,0 +1,378 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + + adtContext "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +type namespaceCond func(*adtContext.Context) bool + +// LinkNamespace used as link action +type LinkNamespace func(*Namespace) + +// Namespace is store all the info +type Namespace web.Namespace + +// NewNamespace get new Namespace +func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { + nps := oldToNewLinkNs(params) + return (*Namespace)(web.NewNamespace(prefix, nps...)) +} + +func oldToNewLinkNs(params []LinkNamespace) []web.LinkNamespace { + nps := make([]web.LinkNamespace, 0, len(params)) + for _, p := range params { + nps = append(nps, func(namespace *web.Namespace) { + p((*Namespace)(namespace)) + }) + } + return nps +} + +// Cond set condition function +// if cond return true can run this namespace, else can't +// usage: +// ns.Cond(func (ctx *context.Context) bool{ +// if ctx.Input.Domain() == "api.beego.me" { +// return true +// } +// return false +// }) +// Cond as the first filter +func (n *Namespace) Cond(cond namespaceCond) *Namespace { + (*web.Namespace)(n).Cond(func(context *context.Context) bool { + return cond((*adtContext.Context)(context)) + }) + return n +} + +// Filter add filter in the Namespace +// action has before & after +// FilterFunc +// usage: +// Filter("before", func (ctx *context.Context){ +// _, ok := ctx.Input.Session("uid").(int) +// if !ok && ctx.Request.RequestURI != "/login" { +// ctx.Redirect(302, "/login") +// } +// }) +func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { + nfs := oldToNewFilter(filter) + (*web.Namespace)(n).Filter(action, nfs...) + return n +} + +func oldToNewFilter(filter []FilterFunc) []web.FilterFunc { + nfs := make([]web.FilterFunc, 0, len(filter)) + for _, f := range filter { + nfs = append(nfs, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } + return nfs +} + +// Router same as beego.Rourer +// refer: https://godoc.org/github.com/astaxie/beego#Router +func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { + (*web.Namespace)(n).Router(rootpath, c, mappingMethods...) + return n +} + +// AutoRouter same as beego.AutoRouter +// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter +func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { + (*web.Namespace)(n).AutoRouter(c) + return n +} + +// AutoPrefix same as beego.AutoPrefix +// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix +func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { + (*web.Namespace)(n).AutoPrefix(prefix, c) + return n +} + +// Get same as beego.Get +// refer: https://godoc.org/github.com/astaxie/beego#Get +func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Get(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Post same as beego.Post +// refer: https://godoc.org/github.com/astaxie/beego#Post +func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Post(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Delete same as beego.Delete +// refer: https://godoc.org/github.com/astaxie/beego#Delete +func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Delete(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Put same as beego.Put +// refer: https://godoc.org/github.com/astaxie/beego#Put +func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Put(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Head same as beego.Head +// refer: https://godoc.org/github.com/astaxie/beego#Head +func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Head(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Options same as beego.Options +// refer: https://godoc.org/github.com/astaxie/beego#Options +func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Options(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Patch same as beego.Patch +// refer: https://godoc.org/github.com/astaxie/beego#Patch +func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Patch(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Any same as beego.Any +// refer: https://godoc.org/github.com/astaxie/beego#Any +func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Any(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Handler same as beego.Handler +// refer: https://godoc.org/github.com/astaxie/beego#Handler +func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { + (*web.Namespace)(n).Handler(rootpath, h) + return n +} + +// Include add include class +// refer: https://godoc.org/github.com/astaxie/beego#Include +func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { + nL := oldToNewCtrlIntfs(cList) + (*web.Namespace)(n).Include(nL...) + return n +} + +// Namespace add nest Namespace +// usage: +// ns := beego.NewNamespace(“/v1”). +// Namespace( +// beego.NewNamespace("/shop"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("shopinfo")) +// }), +// beego.NewNamespace("/order"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("orderinfo")) +// }), +// beego.NewNamespace("/crm"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("crminfo")) +// }), +// ) +func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { + nns := oldToNewNs(ns) + (*web.Namespace)(n).Namespace(nns...) + return n +} + +func oldToNewNs(ns []*Namespace) []*web.Namespace { + nns := make([]*web.Namespace, 0, len(ns)) + for _, n := range ns { + nns = append(nns, (*web.Namespace)(n)) + } + return nns +} + +// AddNamespace register Namespace into beego.Handler +// support multi Namespace +func AddNamespace(nl ...*Namespace) { + nnl := oldToNewNs(nl) + web.AddNamespace(nnl...) +} + +// NSCond is Namespace Condition +func NSCond(cond namespaceCond) LinkNamespace { + return func(namespace *Namespace) { + web.NSCond(func(b *context.Context) bool { + return cond((*adtContext.Context)(b)) + }) + } +} + +// NSBefore Namespace BeforeRouter filter +func NSBefore(filterList ...FilterFunc) LinkNamespace { + return func(namespace *Namespace) { + nfs := oldToNewFilter(filterList) + web.NSBefore(nfs...) + } +} + +// NSAfter add Namespace FinishRouter filter +func NSAfter(filterList ...FilterFunc) LinkNamespace { + return func(namespace *Namespace) { + nfs := oldToNewFilter(filterList) + web.NSAfter(nfs...) + } +} + +// NSInclude Namespace Include ControllerInterface +func NSInclude(cList ...ControllerInterface) LinkNamespace { + return func(namespace *Namespace) { + nfs := oldToNewCtrlIntfs(cList) + web.NSInclude(nfs...) + } +} + +// NSRouter call Namespace Router +func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { + return func(namespace *Namespace) { + web.Router(rootpath, c, mappingMethods...) + } +} + +// NSGet call Namespace Get +func NSGet(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSGet(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSPost call Namespace Post +func NSPost(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.Post(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSHead call Namespace Head +func NSHead(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSHead(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSPut call Namespace Put +func NSPut(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSPut(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSDelete call Namespace Delete +func NSDelete(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSDelete(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSAny call Namespace Any +func NSAny(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSAny(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSOptions call Namespace Options +func NSOptions(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSOptions(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSPatch call Namespace Patch +func NSPatch(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSPatch(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSAutoRouter call Namespace AutoRouter +func NSAutoRouter(c ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + web.NSAutoRouter(c) + } +} + +// NSAutoPrefix call Namespace AutoPrefix +func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + web.NSAutoPrefix(prefix, c) + } +} + +// NSNamespace add sub Namespace +func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { + return func(ns *Namespace) { + nps := oldToNewLinkNs(params) + web.NSNamespace(prefix, nps...) + } +} + +// NSHandler add handler +func NSHandler(rootpath string, h http.Handler) LinkNamespace { + return func(ns *Namespace) { + web.NSHandler(rootpath, h) + } +} diff --git a/pkg/adapter/policy.go b/pkg/adapter/policy.go new file mode 100644 index 00000000..f3759c76 --- /dev/null +++ b/pkg/adapter/policy.go @@ -0,0 +1,57 @@ +// Copyright 2016 beego authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web" + beecontext "github.com/astaxie/beego/pkg/server/web/context" +) + +// PolicyFunc defines a policy function which is invoked before the controller handler is executed. +type PolicyFunc func(*context.Context) + +// FindPolicy Find Router info for URL +func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { + pf := (*web.ControllerRegister)(p).FindPolicy((*beecontext.Context)(cont)) + npf := newToOldPolicyFunc(pf) + return npf +} + +func newToOldPolicyFunc(pf []web.PolicyFunc) []PolicyFunc { + npf := make([]PolicyFunc, 0, len(pf)) + for _, f := range pf { + npf = append(npf, func(c *context.Context) { + f((*beecontext.Context)(c)) + }) + } + return npf +} + +func oldToNewPolicyFunc(pf []PolicyFunc) []web.PolicyFunc { + npf := make([]web.PolicyFunc, 0, len(pf)) + for _, f := range pf { + npf = append(npf, func(c *beecontext.Context) { + f((*context.Context)(c)) + }) + } + return npf +} + +// Policy Register new policy in beego +func Policy(pattern, method string, policy ...PolicyFunc) { + pf := oldToNewPolicyFunc(policy) + web.Policy(pattern, method, pf...) +} diff --git a/pkg/adapter/router.go b/pkg/adapter/router.go new file mode 100644 index 00000000..5a36fbee --- /dev/null +++ b/pkg/adapter/router.go @@ -0,0 +1,279 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + "time" + + beecontext "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +// default filter execution points +const ( + BeforeStatic = web.BeforeStatic + BeforeRouter = web.BeforeRouter + BeforeExec = web.BeforeExec + AfterExec = web.AfterExec + FinishRouter = web.FinishRouter +) + +var ( + // HTTPMETHOD list the supported http methods. + HTTPMETHOD = web.HTTPMETHOD + + // DefaultAccessLogFilter will skip the accesslog if return true + DefaultAccessLogFilter FilterHandler = &newToOldFtHdlAdapter{ + delegate: web.DefaultAccessLogFilter, + } +) + +// FilterHandler is an interface for +type FilterHandler interface { + Filter(*beecontext.Context) bool +} + +type newToOldFtHdlAdapter struct { + delegate web.FilterHandler +} + +func (n *newToOldFtHdlAdapter) Filter(ctx *beecontext.Context) bool { + return n.delegate.Filter((*context.Context)(ctx)) +} + +// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter +func ExceptMethodAppend(action string) { + web.ExceptMethodAppend(action) +} + +// ControllerInfo holds information about the controller. +type ControllerInfo web.ControllerInfo + +func (c *ControllerInfo) GetPattern() string { + return (*web.ControllerInfo)(c).GetPattern() +} + +// ControllerRegister containers registered router rules, controller handlers and filters. +type ControllerRegister web.ControllerRegister + +// NewControllerRegister returns a new ControllerRegister. +func NewControllerRegister() *ControllerRegister { + return (*ControllerRegister)(web.NewControllerRegister()) +} + +// Add controller handler and pattern rules to ControllerRegister. +// usage: +// default methods is the same name as method +// Add("/user",&UserController{}) +// Add("/api/list",&RestController{},"*:ListFood") +// Add("/api/create",&RestController{},"post:CreateFood") +// Add("/api/update",&RestController{},"put:UpdateFood") +// Add("/api/delete",&RestController{},"delete:DeleteFood") +// Add("/api",&RestController{},"get,post:ApiFunc" +// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { + (*web.ControllerRegister)(p).Add(pattern, c, mappingMethods...) +} + +// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller +// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +func (p *ControllerRegister) Include(cList ...ControllerInterface) { + nls := oldToNewCtrlIntfs(cList) + (*web.ControllerRegister)(p).Include(nls...) +} + +// GetContext returns a context from pool, so usually you should remember to call Reset function to clean the context +// And don't forget to give back context to pool +// example: +// ctx := p.GetContext() +// ctx.Reset(w, q) +// defer p.GiveBackContext(ctx) +func (p *ControllerRegister) GetContext() *beecontext.Context { + return (*beecontext.Context)((*web.ControllerRegister)(p).GetContext()) +} + +// GiveBackContext put the ctx into pool so that it could be reuse +func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { + (*web.ControllerRegister)(p).GiveBackContext((*context.Context)(ctx)) +} + +// Get add get method +// usage: +// Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Get(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Get(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Post add post method +// usage: +// Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Post(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Post(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Put add put method +// usage: +// Put("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Put(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Put(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Delete add delete method +// usage: +// Delete("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Delete(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Head add head method +// usage: +// Head("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Head(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Head(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Patch add patch method +// usage: +// Patch("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Patch(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Options add options method +// usage: +// Options("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Options(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Options(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Any add all method +// usage: +// Any("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Any(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Any(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// AddMethod add http method router +// usage: +// AddMethod("get","/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).AddMethod(method, pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Handler add user defined Handler +func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { + (*web.ControllerRegister)(p).Handler(pattern, h, options) +} + +// AddAuto router to ControllerRegister. +// example beego.AddAuto(&MainContorlller{}), +// MainController has method List and Page. +// visit the url /main/list to execute List function +// /main/page to execute Page function. +func (p *ControllerRegister) AddAuto(c ControllerInterface) { + (*web.ControllerRegister)(p).AddAuto(c) +} + +// AddAutoPrefix Add auto router to ControllerRegister with prefix. +// example beego.AddAutoPrefix("/admin",&MainContorlller{}), +// MainController has method List and Page. +// visit the url /admin/main/list to execute List function +// /admin/main/page to execute Page function. +func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) { + (*web.ControllerRegister)(p).AddAutoPrefix(prefix, c) +} + +// InsertFilter Add a FilterFunc with pattern rule and action constant. +// params is for: +// 1. setting the returnOnOutput value (false allows multiple filters to execute) +// 2. determining whether or not params need to be reset. +func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { + opts := oldToNewFilterOpts(params) + return (*web.ControllerRegister)(p).InsertFilter(pattern, pos, func(ctx *context.Context) { + filter((*beecontext.Context)(ctx)) + }, opts...) +} + +func oldToNewFilterOpts(params []bool) []web.FilterOpt { + opts := make([]web.FilterOpt, 0, 4) + if len(params) > 0 { + opts = append(opts, web.WithReturnOnOutput(params[0])) + } + if len(params) > 1 { + opts = append(opts, web.WithResetParams(params[1])) + } + return opts +} + +// URLFor does another controller handler in this request function. +// it can access any controller method. +func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { + return (*web.ControllerRegister)(p).URLFor(endpoint, values...) +} + +// Implement http.Handler interface. +func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + (*web.ControllerRegister)(p).ServeHTTP(rw, r) +} + +// FindRouter Find Router info for URL +func (p *ControllerRegister) FindRouter(ctx *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { + r, ok := (*web.ControllerRegister)(p).FindRouter((*context.Context)(ctx)) + return (*ControllerInfo)(r), ok +} + +// LogAccess logging info HTTP Access +func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { + web.LogAccess((*context.Context)(ctx), startTime, statusCode) +} diff --git a/pkg/adapter/template.go b/pkg/adapter/template.go new file mode 100644 index 00000000..1f943caf --- /dev/null +++ b/pkg/adapter/template.go @@ -0,0 +1,108 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "html/template" + "io" + "net/http" + + "github.com/astaxie/beego/pkg/server/web" +) + +// ExecuteTemplate applies the template with name to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { + return web.ExecuteTemplate(wr, name, data) +} + +// ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error { + return web.ExecuteViewPathTemplate(wr, name, viewPath, data) +} + +// AddFuncMap let user to register a func in the template. +func AddFuncMap(key string, fn interface{}) error { + return web.AddFuncMap(key, fn) +} + +type templatePreProcessor func(root, path string, funcs template.FuncMap) (*template.Template, error) + +type templateFile struct { + root string + files map[string][]string +} + +// HasTemplateExt return this path contains supported template extension of beego or not. +func HasTemplateExt(paths string) bool { + return web.HasTemplateExt(paths) +} + +// AddTemplateExt add new extension for template. +func AddTemplateExt(ext string) { + web.AddTemplateExt(ext) +} + +// AddViewPath adds a new path to the supported view paths. +// Can later be used by setting a controller ViewPath to this folder +// will panic if called after beego.Run() +func AddViewPath(viewPath string) error { + return web.AddViewPath(viewPath) +} + +// BuildTemplate will build all template files in a directory. +// it makes beego can render any template file in view directory. +func BuildTemplate(dir string, files ...string) error { + return web.BuildTemplate(dir, files...) +} + +type templateFSFunc func() http.FileSystem + +func defaultFSFunc() http.FileSystem { + return FileSystem{} +} + +// SetTemplateFSFunc set default filesystem function +func SetTemplateFSFunc(fnt templateFSFunc) { + web.SetTemplateFSFunc(func() http.FileSystem { + return fnt() + }) +} + +// SetViewsPath sets view directory path in beego application. +func SetViewsPath(path string) *App { + return (*App)(web.SetViewsPath(path)) +} + +// SetStaticPath sets static directory path and proper url pattern in beego application. +// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public". +func SetStaticPath(url string, path string) *App { + return (*App)(web.SetStaticPath(url, path)) +} + +// DelStaticPath removes the static folder setting in this url pattern in beego application. +func DelStaticPath(url string) *App { + return (*App)(web.DelStaticPath(url)) +} + +// AddTemplateEngine add a new templatePreProcessor which support extension +func AddTemplateEngine(extension string, fn templatePreProcessor) *App { + return (*App)(web.AddTemplateEngine(extension, func(root, path string, funcs template.FuncMap) (*template.Template, error) { + return fn(root, path, funcs) + })) +} diff --git a/pkg/adapter/templatefunc.go b/pkg/adapter/templatefunc.go new file mode 100644 index 00000000..5130d590 --- /dev/null +++ b/pkg/adapter/templatefunc.go @@ -0,0 +1,151 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "html/template" + "net/url" + "time" + + "github.com/astaxie/beego/pkg/server/web" +) + +const ( + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" + formatDateTimeT = "2006-01-02T15:04:05" +) + +// Substr returns the substr from start to length. +func Substr(s string, start, length int) string { + return web.Substr(s, start, length) +} + +// HTML2str returns escaping text convert from html. +func HTML2str(html string) string { + return web.HTML2str(html) +} + +// DateFormat takes a time and a layout string and returns a string with the formatted date. Used by the template parser as "dateformat" +func DateFormat(t time.Time, layout string) (datestring string) { + return web.DateFormat(t, layout) +} + +// DateParse Parse Date use PHP time format. +func DateParse(dateString, format string) (time.Time, error) { + return web.DateParse(dateString, format) +} + +// Date takes a PHP like date func to Go's time format. +func Date(t time.Time, format string) string { + return web.Date(t, format) +} + +// Compare is a quick and dirty comparison function. It will convert whatever you give it to strings and see if the two values are equal. +// Whitespace is trimmed. Used by the template parser as "eq". +func Compare(a, b interface{}) (equal bool) { + return web.Compare(a, b) +} + +// CompareNot !Compare +func CompareNot(a, b interface{}) (equal bool) { + return web.CompareNot(a, b) +} + +// NotNil the same as CompareNot +func NotNil(a interface{}) (isNil bool) { + return web.NotNil(a) +} + +// GetConfig get the Appconfig +func GetConfig(returnType, key string, defaultVal interface{}) (interface{}, error) { + return web.GetConfig(returnType, key, defaultVal) +} + +// Str2html Convert string to template.HTML type. +func Str2html(raw string) template.HTML { + return web.Str2html(raw) +} + +// Htmlquote returns quoted html string. +func Htmlquote(text string) string { + return web.Htmlquote(text) +} + +// Htmlunquote returns unquoted html string. +func Htmlunquote(text string) string { + return web.Htmlunquote(text) +} + +// URLFor returns url string with another registered controller handler with params. +// usage: +// +// URLFor(".index") +// print URLFor("index") +// router /login +// print URLFor("login") +// print URLFor("login", "next","/"") +// router /profile/:username +// print UrlFor("profile", ":username","John Doe") +// result: +// / +// /login +// /login?next=/ +// /user/John%20Doe +// +// more detail http://beego.me/docs/mvc/controller/urlbuilding.md +func URLFor(endpoint string, values ...interface{}) string { + return web.URLFor(endpoint, values...) +} + +// AssetsJs returns script tag with src string. +func AssetsJs(text string) template.HTML { + return web.AssetsJs(text) +} + +// AssetsCSS returns stylesheet link tag with src string. +func AssetsCSS(text string) template.HTML { + + text = "" + + return template.HTML(text) +} + +// ParseForm will parse form values to struct via tag. +func ParseForm(form url.Values, obj interface{}) error { + return web.ParseForm(form, obj) +} + +// RenderForm will render object to form html. +// obj must be a struct pointer. +func RenderForm(obj interface{}) template.HTML { + return web.RenderForm(obj) +} + +// MapGet getting value from map by keys +// usage: +// Data["m"] = M{ +// "a": 1, +// "1": map[string]float64{ +// "c": 4, +// }, +// } +// +// {{ map_get m "a" }} // return 1 +// {{ map_get m 1 "c" }} // return 4 +func MapGet(arg1 interface{}, arg2 ...interface{}) (interface{}, error) { + return web.MapGet(arg1, arg2...) +} diff --git a/pkg/adapter/templatefunc_test.go b/pkg/adapter/templatefunc_test.go new file mode 100644 index 00000000..f5113606 --- /dev/null +++ b/pkg/adapter/templatefunc_test.go @@ -0,0 +1,304 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "html/template" + "net/url" + "testing" + "time" +) + +func TestSubstr(t *testing.T) { + s := `012345` + if Substr(s, 0, 2) != "01" { + t.Error("should be equal") + } + if Substr(s, 0, 100) != "012345" { + t.Error("should be equal") + } + if Substr(s, 12, 100) != "012345" { + t.Error("should be equal") + } +} + +func TestHtml2str(t *testing.T) { + h := `<123> 123\n + + + \n` + if HTML2str(h) != "123\\n\n\\n" { + t.Error("should be equal") + } +} + +func TestDateFormat(t *testing.T) { + ts := "Mon, 01 Jul 2013 13:27:42 CST" + tt, _ := time.Parse(time.RFC1123, ts) + + if ss := DateFormat(tt, "2006-01-02 15:04:05"); ss != "2013-07-01 13:27:42" { + t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) + } +} + +func TestDate(t *testing.T) { + ts := "Mon, 01 Jul 2013 13:27:42 CST" + tt, _ := time.Parse(time.RFC1123, ts) + + if ss := Date(tt, "Y-m-d H:i:s"); ss != "2013-07-01 13:27:42" { + t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) + } + if ss := Date(tt, "y-n-j h:i:s A"); ss != "13-7-1 01:27:42 PM" { + t.Errorf("13-7-1 01:27:42 PM does not equal %v", ss) + } + if ss := Date(tt, "D, d M Y g:i:s a"); ss != "Mon, 01 Jul 2013 1:27:42 pm" { + t.Errorf("Mon, 01 Jul 2013 1:27:42 pm does not equal %v", ss) + } + if ss := Date(tt, "l, d F Y G:i:s"); ss != "Monday, 01 July 2013 13:27:42" { + t.Errorf("Monday, 01 July 2013 13:27:42 does not equal %v", ss) + } +} + +func TestCompareRelated(t *testing.T) { + if !Compare("abc", "abc") { + t.Error("should be equal") + } + if Compare("abc", "aBc") { + t.Error("should be not equal") + } + if !Compare("1", 1) { + t.Error("should be equal") + } + if CompareNot("abc", "abc") { + t.Error("should be equal") + } + if !CompareNot("abc", "aBc") { + t.Error("should be not equal") + } + if !NotNil("a string") { + t.Error("should not be nil") + } +} + +func TestHtmlquote(t *testing.T) { + h := `<' ”“&">` + s := `<' ”“&">` + if Htmlquote(s) != h { + t.Error("should be equal") + } +} + +func TestHtmlunquote(t *testing.T) { + h := `<' ”“&">` + s := `<' ”“&">` + if Htmlunquote(h) != s { + t.Error("should be equal") + } +} + +func TestParseForm(t *testing.T) { + type ExtendInfo struct { + Hobby []string `form:"hobby"` + Memo string + } + + type OtherInfo struct { + Organization string `form:"organization"` + Title string `form:"title"` + ExtendInfo + } + + type user struct { + ID int `form:"-"` + tag string `form:"tag"` + Name interface{} `form:"username"` + Age int `form:"age,text"` + Email string + Intro string `form:",textarea"` + StrBool bool `form:"strbool"` + Date time.Time `form:"date,2006-01-02"` + OtherInfo + } + + u := user{} + form := url.Values{ + "ID": []string{"1"}, + "-": []string{"1"}, + "tag": []string{"no"}, + "username": []string{"test"}, + "age": []string{"40"}, + "Email": []string{"test@gmail.com"}, + "Intro": []string{"I am an engineer!"}, + "strbool": []string{"yes"}, + "date": []string{"2014-11-12"}, + "organization": []string{"beego"}, + "title": []string{"CXO"}, + "hobby": []string{"", "Basketball", "Football"}, + "memo": []string{"nothing"}, + } + if err := ParseForm(form, u); err == nil { + t.Fatal("nothing will be changed") + } + if err := ParseForm(form, &u); err != nil { + t.Fatal(err) + } + if u.ID != 0 { + t.Errorf("ID should equal 0 but got %v", u.ID) + } + if len(u.tag) != 0 { + t.Errorf("tag's length should equal 0 but got %v", len(u.tag)) + } + if u.Name.(string) != "test" { + t.Errorf("Name should equal `test` but got `%v`", u.Name.(string)) + } + if u.Age != 40 { + t.Errorf("Age should equal 40 but got %v", u.Age) + } + if u.Email != "test@gmail.com" { + t.Errorf("Email should equal `test@gmail.com` but got `%v`", u.Email) + } + if u.Intro != "I am an engineer!" { + t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro) + } + if !u.StrBool { + t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool) + } + y, m, d := u.Date.Date() + if y != 2014 || m.String() != "November" || d != 12 { + t.Errorf("Date should equal `2014-11-12`, but got `%v`", u.Date.String()) + } + if u.Organization != "beego" { + t.Errorf("Organization should equal `beego`, but got `%v`", u.Organization) + } + if u.Title != "CXO" { + t.Errorf("Title should equal `CXO`, but got `%v`", u.Title) + } + if u.Hobby[0] != "" { + t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0]) + } + if u.Hobby[1] != "Basketball" { + t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1]) + } + if u.Hobby[2] != "Football" { + t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2]) + } + if len(u.Memo) != 0 { + t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo)) + } +} + +func TestRenderForm(t *testing.T) { + type user struct { + ID int `form:"-"` + Name interface{} `form:"username"` + Age int `form:"age,text,年龄:"` + Sex string + Email []string + Intro string `form:",textarea"` + Ignored string `form:"-"` + } + + u := user{Name: "test", Intro: "Some Text"} + output := RenderForm(u) + if output != template.HTML("") { + t.Errorf("output should be empty but got %v", output) + } + output = RenderForm(&u) + result := template.HTML( + `Name:
` + + `年龄:
` + + `Sex:
` + + `Intro: `) + if output != result { + t.Errorf("output should equal `%v` but got `%v`", result, output) + } +} + +func TestMapGet(t *testing.T) { + // test one level map + m1 := map[string]int64{ + "a": 1, + "1": 2, + } + + if res, err := MapGet(m1, "a"); err == nil { + if res.(int64) != 1 { + t.Errorf("Should return 1, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + if res, err := MapGet(m1, "1"); err == nil { + if res.(int64) != 2 { + t.Errorf("Should return 2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + if res, err := MapGet(m1, 1); err == nil { + if res.(int64) != 2 { + t.Errorf("Should return 2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // test 2 level map + m2 := M{ + "1": map[string]float64{ + "2": 3.5, + }, + } + + if res, err := MapGet(m2, 1, 2); err == nil { + if res.(float64) != 3.5 { + t.Errorf("Should return 3.5, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // test 5 level map + m5 := M{ + "1": M{ + "2": M{ + "3": M{ + "4": M{ + "5": 1.2, + }, + }, + }, + }, + } + + if res, err := MapGet(m5, 1, 2, 3, 4, 5); err == nil { + if res.(float64) != 1.2 { + t.Errorf("Should return 1.2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // check whether element not exists in map + if res, err := MapGet(m5, 5, 4, 3, 2, 1); err == nil { + if res != nil { + t.Errorf("Should return nil, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } +} diff --git a/pkg/adapter/tree.go b/pkg/adapter/tree.go new file mode 100644 index 00000000..2e3cd0d0 --- /dev/null +++ b/pkg/adapter/tree.go @@ -0,0 +1,49 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +// Tree has three elements: FixRouter/wildcard/leaves +// fixRouter stores Fixed Router +// wildcard stores params +// leaves store the endpoint information +type Tree web.Tree + +// NewTree return a new Tree +func NewTree() *Tree { + return (*Tree)(web.NewTree()) +} + +// AddTree will add tree to the exist Tree +// prefix should has no params +func (t *Tree) AddTree(prefix string, tree *Tree) { + (*web.Tree)(t).AddTree(prefix, (*web.Tree)(tree)) +} + +// AddRouter call addseg function +func (t *Tree) AddRouter(pattern string, runObject interface{}) { + (*web.Tree)(t).AddRouter(pattern, runObject) +} + +// Match router to runObject & params +func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) { + return (*web.Tree)(t).Match(pattern, (*beecontext.Context)(ctx)) +} diff --git a/pkg/adapter/tree_test.go b/pkg/adapter/tree_test.go new file mode 100644 index 00000000..309ed072 --- /dev/null +++ b/pkg/adapter/tree_test.go @@ -0,0 +1,249 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "testing" + + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" +) + +type testinfo struct { + url string + requesturl string + params map[string]string +} + +var routers []testinfo + +func init() { + routers = make([]testinfo, 0) + routers = append(routers, testinfo{"/topic/?:auth:int", "/topic", nil}) + routers = append(routers, testinfo{"/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1", map[string]string{":id": "1"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1/2", map[string]string{":id": "1", ":auth": "2"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1", map[string]string{":id": "1"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1/123", map[string]string{":id": "1", ":auth": "123"}}) + routers = append(routers, testinfo{"/:id", "/123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/hello/?:id", "/hello", map[string]string{":id": ""}}) + routers = append(routers, testinfo{"/", "/", nil}) + routers = append(routers, testinfo{"/customer/login", "/customer/login", nil}) + routers = append(routers, testinfo{"/customer/login", "/customer/login.json", map[string]string{":ext": "json"}}) + routers = append(routers, testinfo{"/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"}}) + routers = append(routers, testinfo{"/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}}) + routers = append(routers, testinfo{"/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}}) + routers = append(routers, testinfo{"/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}}) + routers = append(routers, testinfo{"/cc/:id/*", "/cc/2009/11/dd", map[string]string{":id": "2009", ":splat": "11/dd"}}) + routers = append(routers, testinfo{"/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"}}) + routers = append(routers, testinfo{"/thumbnail/:size/uploads/*", + "/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg", + map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"}}) + routers = append(routers, testinfo{"/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"}}) + routers = append(routers, testinfo{"/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) + routers = append(routers, testinfo{"/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) + routers = append(routers, testinfo{"/dl/:width:int/:height:int/*.*", + "/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg", + map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"}}) + routers = append(routers, testinfo{"/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(a)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(b)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(c)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"}}) + routers = append(routers, testinfo{"/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}}) + routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}}) +} + +func TestTreeRouters(t *testing.T) { + for _, r := range routers { + tr := NewTree() + tr.AddRouter(r.url, "astaxie") + ctx := context.NewContext() + obj := tr.Match(r.requesturl, ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal(r.url+" can't get obj, Expect ", r.requesturl) + } + if r.params != nil { + for k, v := range r.params { + if vv := ctx.Input.Param(k); vv != v { + t.Fatal("The Rule: " + r.url + "\nThe RequestURL:" + r.requesturl + "\nThe Key is " + k + ", The Value should be: " + v + ", but get: " + vv) + } else if vv == "" && v != "" { + t.Fatal(r.url + " " + r.requesturl + " get param empty:" + k) + } + } + } + } +} + +func TestStaticPath(t *testing.T) { + tr := NewTree() + tr.AddRouter("/topic/:id", "wildcard") + tr.AddRouter("/topic", "static") + ctx := context.NewContext() + obj := tr.Match("/topic", ctx) + if obj == nil || obj.(string) != "static" { + t.Fatal("/topic is a static route") + } + obj = tr.Match("/topic/1", ctx) + if obj == nil || obj.(string) != "wildcard" { + t.Fatal("/topic/1 is a wildcard route") + } +} + +func TestAddTree(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t1 := NewTree() + t1.AddTree("/v1/zl", tr) + ctx := context.NewContext() + obj := t1.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" { + t.Fatal("get :id param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t1.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl//shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" { + t.Fatal("get :sd :id :page param error") + } + + t2 := NewTree() + t2.AddTree("/v1/:shopid", tr) + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t2.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :id :shopid param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t2.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get :shopid param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :sd :id :page :shopid param error") + } +} + +func TestAddTree2(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t3 := NewTree() + t3.AddTree("/:version(v1|v2)/:prefix", tr) + ctx := context.NewContext() + obj := t3.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:version(v1|v2)/:prefix/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":prefix") != "zl" || ctx.Input.Param(":version") != "v1" { + t.Fatal("get :id :prefix :version param error") + } +} + +func TestAddTree3(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/account", "astaxie") + t3 := NewTree() + t3.AddTree("/table/:num", tr) + ctx := context.NewContext() + obj := t3.Match("/table/123/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/shop/:sd/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":num") != "123" || ctx.Input.Param(":sd") != "123" { + t.Fatal("get :num :sd param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t3.Match("/table/123/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/create can't get obj ") + } +} + +func TestAddTree4(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/:account", "astaxie") + t4 := NewTree() + t4.AddTree("/:info:int/:num/:id", tr) + ctx := context.NewContext() + obj := t4.Match("/12/123/456/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/shop/:sd/:account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":info") != "12" || ctx.Input.Param(":num") != "123" || + ctx.Input.Param(":id") != "456" || ctx.Input.Param(":sd") != "123" || + ctx.Input.Param(":account") != "account" { + t.Fatal("get :info :num :id :sd :account param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t4.Match("/12/123/456/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/create can't get obj ") + } +} + +// Test for issue #1595 +func TestAddTree5(t *testing.T) { + tr := NewTree() + tr.AddRouter("/v1/shop/:id", "shopdetail") + tr.AddRouter("/v1/shop/", "shophome") + ctx := context.NewContext() + obj := tr.Match("/v1/shop/", ctx) + if obj == nil || obj.(string) != "shophome" { + t.Fatal("url /v1/shop/ need match router /v1/shop/ ") + } +} diff --git a/pkg/server/web/app.go b/pkg/server/web/app.go index e61084a5..ad3ff663 100644 --- a/pkg/server/web/app.go +++ b/pkg/server/web/app.go @@ -199,7 +199,7 @@ func (app *App) Run(mws ...MiddleWare) { pool.AppendCertsFromPEM(data) app.Server.TLSConfig = &tls.Config{ ClientCAs: pool, - ClientAuth: tls.RequireAndVerifyClientCert, + ClientAuth: tls.ClientAuthType(BConfig.Listen.ClientAuth), } } if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { diff --git a/pkg/server/web/config.go b/pkg/server/web/config.go index bf8db30e..6e69a2fb 100644 --- a/pkg/server/web/config.go +++ b/pkg/server/web/config.go @@ -16,6 +16,7 @@ package web import ( context2 "context" + "crypto/tls" "fmt" "os" "path/filepath" @@ -72,6 +73,7 @@ type Listen struct { AdminPort int EnableFcgi bool EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O + ClientAuth int } // WebConfig holds web related config @@ -234,6 +236,7 @@ func newBConfig() *Config { AdminPort: 8088, EnableFcgi: false, EnableStdIo: false, + ClientAuth: int(tls.RequireAndVerifyClientCert), }, WebConfig: WebConfig{ AutoRender: true, diff --git a/pkg/server/web/filter.go b/pkg/server/web/filter.go index 8d3acb24..e10faafc 100644 --- a/pkg/server/web/filter.go +++ b/pkg/server/web/filter.go @@ -43,24 +43,26 @@ type FilterRouter struct { // params is for: // 1. setting the returnOnOutput value (false allows multiple filters to execute) // 2. determining whether or not params need to be reset. -func newFilterRouter(pattern string, routerCaseSensitive bool, filter FilterFunc, params ...bool) *FilterRouter { +func newFilterRouter(pattern string, filter FilterFunc, opts ...FilterOpt) *FilterRouter { mr := &FilterRouter{ tree: NewTree(), pattern: pattern, filterFunc: filter, returnOnOutput: true, } - if !routerCaseSensitive { + + fos := &filterOpts{} + + for _, o := range opts { + o(fos) + } + + if !fos.routerCaseSensitive { mr.pattern = strings.ToLower(pattern) } - paramsLen := len(params) - if paramsLen > 0 { - mr.returnOnOutput = params[0] - } - if paramsLen > 1 { - mr.resetParams = params[1] - } + mr.returnOnOutput = fos.returnOnOutput + mr.resetParams = fos.resetParams mr.tree.AddRouter(pattern, true) return mr } @@ -103,3 +105,29 @@ func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { } return false } + +type filterOpts struct { + returnOnOutput bool + resetParams bool + routerCaseSensitive bool +} + +type FilterOpt func(opts *filterOpts) + +func WithReturnOnOutput(ret bool) FilterOpt { + return func(opts *filterOpts) { + opts.returnOnOutput = ret + } +} + +func WithResetParams(reset bool) FilterOpt { + return func(opts *filterOpts) { + opts.resetParams = reset + } +} + +func WithCaseSensitive(sensitive bool) FilterOpt { + return func(opts *filterOpts) { + opts.routerCaseSensitive = sensitive + } +} diff --git a/pkg/server/web/router.go b/pkg/server/web/router.go index c3eddd29..3dd19a6f 100644 --- a/pkg/server/web/router.go +++ b/pkg/server/web/router.go @@ -148,7 +148,7 @@ func NewControllerRegister() *ControllerRegister { }, }, } - res.chainRoot = newFilterRouter("/*", false, res.serveHttp) + res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) return res } @@ -262,7 +262,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { if comm, ok := GlobalControllerRouter[key]; ok { for _, a := range comm { for _, f := range a.Filters { - p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams) + p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams)) } p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) @@ -452,8 +452,9 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) // params is for: // 1. setting the returnOnOutput value (false allows multiple filters to execute) // 2. determining whether or not params need to be reset. -func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { - mr := newFilterRouter(pattern, BConfig.RouterCaseSensitive, filter, params...) +func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) error { + opts = append(opts, WithCaseSensitive(BConfig.RouterCaseSensitive)) + mr := newFilterRouter(pattern, filter, opts...) return p.insertFilterRouter(pos, mr) } @@ -468,10 +469,11 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter // // do something // } // } -func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, params ...bool) { +func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) { root := p.chainRoot filterFunc := chain(root.filterFunc) - p.chainRoot = newFilterRouter(pattern, BConfig.RouterCaseSensitive, filterFunc, params...) + opts = append(opts, WithCaseSensitive(BConfig.RouterCaseSensitive)) + p.chainRoot = newFilterRouter(pattern, filterFunc, opts...) p.chainRoot.next = root } From f1950482c2c0ee8e6e90ad320245f7130ab9cf4e Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 5 Sep 2020 16:54:05 +0800 Subject: [PATCH 146/207] Adapter: plugin --- pkg/adapter/plugins/apiauth/apiauth.go | 94 ++++++++ pkg/adapter/plugins/apiauth/apiauth_test.go | 20 ++ pkg/adapter/plugins/auth/basic.go | 81 +++++++ pkg/adapter/plugins/authz/authz.go | 80 +++++++ pkg/adapter/plugins/authz/authz_model.conf | 14 ++ pkg/adapter/plugins/authz/authz_policy.csv | 7 + pkg/adapter/plugins/authz/authz_test.go | 108 +++++++++ pkg/adapter/plugins/cors/cors.go | 71 ++++++ pkg/adapter/plugins/cors/cors_test.go | 253 ++++++++++++++++++++ pkg/server/web/filter/apiauth/apiauth.go | 5 - 10 files changed, 728 insertions(+), 5 deletions(-) create mode 100644 pkg/adapter/plugins/apiauth/apiauth.go create mode 100644 pkg/adapter/plugins/apiauth/apiauth_test.go create mode 100644 pkg/adapter/plugins/auth/basic.go create mode 100644 pkg/adapter/plugins/authz/authz.go create mode 100644 pkg/adapter/plugins/authz/authz_model.conf create mode 100644 pkg/adapter/plugins/authz/authz_policy.csv create mode 100644 pkg/adapter/plugins/authz/authz_test.go create mode 100644 pkg/adapter/plugins/cors/cors.go create mode 100644 pkg/adapter/plugins/cors/cors_test.go diff --git a/pkg/adapter/plugins/apiauth/apiauth.go b/pkg/adapter/plugins/apiauth/apiauth.go new file mode 100644 index 00000000..ed43f8a0 --- /dev/null +++ b/pkg/adapter/plugins/apiauth/apiauth.go @@ -0,0 +1,94 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package apiauth provides handlers to enable apiauth support. +// +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/apiauth" +// ) +// +// func main(){ +// // apiauth every request +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBaiscAuth("appid","appkey")) +// beego.Run() +// } +// +// Advanced Usage: +// +// func getAppSecret(appid string) string { +// // get appsecret by appid +// // maybe store in configure, maybe in database +// } +// +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360)) +// +// Information: +// +// In the request user should include these params in the query +// +// 1. appid +// +// appid is assigned to the application +// +// 2. signature +// +// get the signature use apiauth.Signature() +// +// when you send to server remember use url.QueryEscape() +// +// 3. timestamp: +// +// send the request time, the format is yyyy-mm-dd HH:ii:ss +// +package apiauth + +import ( + "net/url" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/apiauth" +) + +// AppIDToAppSecret is used to get appsecret throw appid +type AppIDToAppSecret apiauth.AppIDToAppSecret + +// APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret +func APIBasicAuth(appid, appkey string) beego.FilterFunc { + f := apiauth.APIBasicAuth(appid, appkey) + return func(c *context.Context) { + f((*beecontext.Context)(c)) + } +} + +// APIBaiscAuth calls APIBasicAuth for previous callers +func APIBaiscAuth(appid, appkey string) beego.FilterFunc { + return APIBasicAuth(appid, appkey) +} + +// APISecretAuth use AppIdToAppSecret verify and +func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { + ft := apiauth.APISecretAuth(apiauth.AppIDToAppSecret(f), timeout) + return func(ctx *context.Context) { + ft((*beecontext.Context)(ctx)) + } +} + +// Signature used to generate signature with the appsecret/method/params/RequestURI +func Signature(appsecret, method string, params url.Values, requestURL string) string { + return apiauth.Signature(appsecret, method, params, requestURL) +} diff --git a/pkg/adapter/plugins/apiauth/apiauth_test.go b/pkg/adapter/plugins/apiauth/apiauth_test.go new file mode 100644 index 00000000..1f56cb0f --- /dev/null +++ b/pkg/adapter/plugins/apiauth/apiauth_test.go @@ -0,0 +1,20 @@ +package apiauth + +import ( + "net/url" + "testing" +) + +func TestSignature(t *testing.T) { + appsecret := "beego secret" + method := "GET" + RequestURL := "http://localhost/test/url" + params := make(url.Values) + params.Add("arg1", "hello") + params.Add("arg2", "beego") + + signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58=" + if Signature(appsecret, method, params, RequestURL) != signature { + t.Error("Signature error") + } +} diff --git a/pkg/adapter/plugins/auth/basic.go b/pkg/adapter/plugins/auth/basic.go new file mode 100644 index 00000000..7a9cd326 --- /dev/null +++ b/pkg/adapter/plugins/auth/basic.go @@ -0,0 +1,81 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package auth provides handlers to enable basic auth support. +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/auth" +// ) +// +// func main(){ +// // authenticate every request +// beego.InsertFilter("*", beego.BeforeRouter,auth.Basic("username","secretpassword")) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func SecretAuth(username, password string) bool { +// return username == "astaxie" && password == "helloBeego" +// } +// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "Authorization Required") +// beego.InsertFilter("*", beego.BeforeRouter,authPlugin) +package auth + +import ( + "net/http" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/auth" +) + +// Basic is the http basic auth +func Basic(username string, password string) beego.FilterFunc { + return func(c *context.Context) { + f := auth.Basic(username, password) + f((*beecontext.Context)(c)) + } +} + +// NewBasicAuthenticator return the BasicAuth +func NewBasicAuthenticator(secrets SecretProvider, realm string) beego.FilterFunc { + f := auth.NewBasicAuthenticator(auth.SecretProvider(secrets), realm) + return func(c *context.Context) { + f((*beecontext.Context)(c)) + } +} + +// SecretProvider is the SecretProvider function +type SecretProvider auth.SecretProvider + +// BasicAuth store the SecretProvider and Realm +type BasicAuth auth.BasicAuth + +// CheckAuth Checks the username/password combination from the request. Returns +// either an empty string (authentication failed) or the name of the +// authenticated user. +// Supports MD5 and SHA1 password entries +func (a *BasicAuth) CheckAuth(r *http.Request) string { + return (*auth.BasicAuth)(a).CheckAuth(r) +} + +// RequireAuth http.Handler for BasicAuth which initiates the authentication process +// (or requires reauthentication). +func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) { + (*auth.BasicAuth)(a).RequireAuth(w, r) +} diff --git a/pkg/adapter/plugins/authz/authz.go b/pkg/adapter/plugins/authz/authz.go new file mode 100644 index 00000000..c38be9cb --- /dev/null +++ b/pkg/adapter/plugins/authz/authz.go @@ -0,0 +1,80 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support. +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/authz" +// "github.com/casbin/casbin" +// ) +// +// func main(){ +// // mediate the access for every request +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func main(){ +// e := casbin.NewEnforcer("authz_model.conf", "") +// e.AddRoleForUser("alice", "admin") +// e.AddPolicy(...) +// +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e)) +// beego.Run() +// } +package authz + +import ( + "net/http" + + "github.com/casbin/casbin" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/authz" +) + +// NewAuthorizer returns the authorizer. +// Use a casbin enforcer as input +func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc { + f := authz.NewAuthorizer(e) + return func(context *context.Context) { + f((*beecontext.Context)(context)) + } +} + +// BasicAuthorizer stores the casbin handler +type BasicAuthorizer authz.BasicAuthorizer + +// GetUserName gets the user name from the request. +// Currently, only HTTP basic authentication is supported +func (a *BasicAuthorizer) GetUserName(r *http.Request) string { + return (*authz.BasicAuthorizer)(a).GetUserName(r) +} + +// CheckPermission checks the user/method/path combination from the request. +// Returns true (permission granted) or false (permission forbidden) +func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool { + return (*authz.BasicAuthorizer)(a).CheckPermission(r) +} + +// RequirePermission returns the 403 Forbidden to the client +func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) { + (*authz.BasicAuthorizer)(a).RequirePermission(w) +} diff --git a/pkg/adapter/plugins/authz/authz_model.conf b/pkg/adapter/plugins/authz/authz_model.conf new file mode 100644 index 00000000..d1b3dbd7 --- /dev/null +++ b/pkg/adapter/plugins/authz/authz_model.conf @@ -0,0 +1,14 @@ +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*") \ No newline at end of file diff --git a/pkg/adapter/plugins/authz/authz_policy.csv b/pkg/adapter/plugins/authz/authz_policy.csv new file mode 100644 index 00000000..c062dd3e --- /dev/null +++ b/pkg/adapter/plugins/authz/authz_policy.csv @@ -0,0 +1,7 @@ +p, alice, /dataset1/*, GET +p, alice, /dataset1/resource1, POST +p, bob, /dataset2/resource1, * +p, bob, /dataset2/resource2, GET +p, bob, /dataset2/folder1/*, POST +p, dataset1_admin, /dataset1/*, * +g, cathy, dataset1_admin \ No newline at end of file diff --git a/pkg/adapter/plugins/authz/authz_test.go b/pkg/adapter/plugins/authz/authz_test.go new file mode 100644 index 00000000..ddbda5f4 --- /dev/null +++ b/pkg/adapter/plugins/authz/authz_test.go @@ -0,0 +1,108 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authz + +import ( + "net/http" + "net/http/httptest" + "testing" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/adapter/plugins/auth" + "github.com/casbin/casbin" +) + +func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + r.SetBasicAuth(user, "123") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code) + } +} + +func TestBasic(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403) +} + +func TestPathWildcard(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403) + testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403) + + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403) +} + +func TestRBAC(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) + e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role. + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) + + // delete all roles on user cathy, so cathy cannot access any resources now. + e.DeleteRolesForUser("cathy") + + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) +} diff --git a/pkg/adapter/plugins/cors/cors.go b/pkg/adapter/plugins/cors/cors.go new file mode 100644 index 00000000..65af8b8f --- /dev/null +++ b/pkg/adapter/plugins/cors/cors.go @@ -0,0 +1,71 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cors provides handlers to enable CORS support. +// Usage +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/cors" +// ) +// +// func main() { +// // CORS for https://foo.* origins, allowing: +// // - PUT and PATCH methods +// // - Origin header +// // - Credentials share +// beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{ +// AllowOrigins: []string{"https://*.foo.com"}, +// AllowMethods: []string{"PUT", "PATCH"}, +// AllowHeaders: []string{"Origin"}, +// ExposeHeaders: []string{"Content-Length"}, +// AllowCredentials: true, +// })) +// beego.Run() +// } +package cors + +import ( + beego "github.com/astaxie/beego/pkg/adapter" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/cors" + + "github.com/astaxie/beego/pkg/adapter/context" +) + +// Options represents Access Control options. +type Options cors.Options + +// Header converts options into CORS headers. +func (o *Options) Header(origin string) (headers map[string]string) { + return (*cors.Options)(o).Header(origin) +} + +// PreflightHeader converts options into CORS headers for a preflight response. +func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) { + return (*cors.Options)(o).PreflightHeader(origin, rMethod, rHeaders) +} + +// IsOriginAllowed looks up if the origin matches one of the patterns +// generated from Options.AllowOrigins patterns. +func (o *Options) IsOriginAllowed(origin string) bool { + return (*cors.Options)(o).IsOriginAllowed(origin) +} + +// Allow enables CORS for requests those match the provided options. +func Allow(opts *Options) beego.FilterFunc { + f := cors.Allow((*cors.Options)(opts)) + return func(c *context.Context) { + f((*beecontext.Context)(c)) + } +} diff --git a/pkg/adapter/plugins/cors/cors_test.go b/pkg/adapter/plugins/cors/cors_test.go new file mode 100644 index 00000000..34039143 --- /dev/null +++ b/pkg/adapter/plugins/cors/cors_test.go @@ -0,0 +1,253 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cors + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" +) + +// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header +type HTTPHeaderGuardRecorder struct { + *httptest.ResponseRecorder + savedHeaderMap http.Header +} + +// NewRecorder return HttpHeaderGuardRecorder +func NewRecorder() *HTTPHeaderGuardRecorder { + return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil} +} + +func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) { + gr.ResponseRecorder.WriteHeader(code) + gr.savedHeaderMap = gr.ResponseRecorder.Header() +} + +func (gr *HTTPHeaderGuardRecorder) Header() http.Header { + if gr.savedHeaderMap != nil { + // headers were written. clone so we don't get updates + clone := make(http.Header) + for k, v := range gr.savedHeaderMap { + clone[k] = v + } + return clone + } + return gr.ResponseRecorder.Header() +} + +func Test_AllowAll(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + if recorder.HeaderMap.Get(headerAllowOrigin) != "*" { + t.Errorf("Allow-Origin header should be *") + } +} + +func Test_AllowRegexMatch(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"}, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + origin := "https://bar.foo.com" + r, _ := http.NewRequest("PUT", "/foo", nil) + r.Header.Add("Origin", origin) + handler.ServeHTTP(recorder, r) + + headerValue := recorder.HeaderMap.Get(headerAllowOrigin) + if headerValue != origin { + t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue) + } +} + +func Test_AllowRegexNoMatch(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowOrigins: []string{"https://*.foo.com"}, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + origin := "https://ww.foo.com.evil.com" + r, _ := http.NewRequest("PUT", "/foo", nil) + r.Header.Add("Origin", origin) + handler.ServeHTTP(recorder, r) + + headerValue := recorder.HeaderMap.Get(headerAllowOrigin) + if headerValue != "" { + t.Errorf("Allow-Origin header should not exist, found %v", headerValue) + } +} + +func Test_OtherHeaders(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowCredentials: true, + AllowMethods: []string{"PATCH", "GET"}, + AllowHeaders: []string{"Origin", "X-whatever"}, + ExposeHeaders: []string{"Content-Length", "Hello"}, + MaxAge: 5 * time.Minute, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials) + methodsVal := recorder.HeaderMap.Get(headerAllowMethods) + headersVal := recorder.HeaderMap.Get(headerAllowHeaders) + exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders) + maxAgeVal := recorder.HeaderMap.Get(headerMaxAge) + + if credentialsVal != "true" { + t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal) + } + + if methodsVal != "PATCH,GET" { + t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal) + } + + if headersVal != "Origin,X-whatever" { + t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal) + } + + if exposedHeadersVal != "Content-Length,Hello" { + t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal) + } + + if maxAgeVal != "300" { + t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal) + } +} + +func Test_DefaultAllowHeaders(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + headersVal := recorder.HeaderMap.Get(headerAllowHeaders) + if headersVal != "Origin,Accept,Content-Type,Authorization" { + t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal) + } +} + +func Test_Preflight(t *testing.T) { + recorder := NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowMethods: []string{"PUT", "PATCH"}, + AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"}, + })) + + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + r, _ := http.NewRequest("OPTIONS", "/foo", nil) + r.Header.Add(headerRequestMethod, "PUT") + r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive") + handler.ServeHTTP(recorder, r) + + headers := recorder.Header() + methodsVal := headers.Get(headerAllowMethods) + headersVal := headers.Get(headerAllowHeaders) + originVal := headers.Get(headerAllowOrigin) + + if methodsVal != "PUT,PATCH" { + t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal) + } + + if !strings.Contains(headersVal, "X-whatever") { + t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal) + } + + if !strings.Contains(headersVal, "x-casesensitive") { + t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal) + } + + if originVal != "*" { + t.Errorf("Allow-Origin is expected to be *, found %v", originVal) + } + + if recorder.Code != http.StatusOK { + t.Errorf("Status code is expected to be 200, found %d", recorder.Code) + } +} + +func Benchmark_WithoutCORS(b *testing.B) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + beego.BConfig.RunMode = beego.PROD + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + for i := 0; i < b.N; i++ { + handler.ServeHTTP(recorder, r) + } +} + +func Benchmark_WithCORS(b *testing.B) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + beego.BConfig.RunMode = beego.PROD + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowCredentials: true, + AllowMethods: []string{"PATCH", "GET"}, + AllowHeaders: []string{"Origin", "X-whatever"}, + MaxAge: 5 * time.Minute, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + for i := 0; i < b.N; i++ { + handler.ServeHTTP(recorder, r) + } +} diff --git a/pkg/server/web/filter/apiauth/apiauth.go b/pkg/server/web/filter/apiauth/apiauth.go index ba56030b..8944db63 100644 --- a/pkg/server/web/filter/apiauth/apiauth.go +++ b/pkg/server/web/filter/apiauth/apiauth.go @@ -83,11 +83,6 @@ func APIBasicAuth(appid, appkey string) web.FilterFunc { return APISecretAuth(ft, 300) } -// APIBasicAuth calls APIBasicAuth for previous callers -func APIBaiscAuth(appid, appkey string) web.FilterFunc { - return APIBasicAuth(appid, appkey) -} - // APISecretAuth uses AppIdToAppSecret verify and func APISecretAuth(f AppIDToAppSecret, timeout int) web.FilterFunc { return func(ctx *context.Context) { From f6c95ad5346e77ebf0ade03489e3080d62a76e0f Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 5 Sep 2020 16:56:56 +0800 Subject: [PATCH 147/207] Adapter: swagger module --- pkg/adapter/swagger/swagger.go | 68 ++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 pkg/adapter/swagger/swagger.go diff --git a/pkg/adapter/swagger/swagger.go b/pkg/adapter/swagger/swagger.go new file mode 100644 index 00000000..214959d9 --- /dev/null +++ b/pkg/adapter/swagger/swagger.go @@ -0,0 +1,68 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Swagger™ is a project used to describe and document RESTful APIs. +// +// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools. +// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software. + +// Package swagger struct definition +package swagger + +import ( + "github.com/astaxie/beego/pkg/server/web/swagger" +) + +// Swagger list the resource +type Swagger swagger.Swagger + +// Information Provides metadata about the API. The metadata can be used by the clients if needed. +type Information swagger.Information + +// Contact information for the exposed API. +type Contact swagger.Contact + +// License information for the exposed API. +type License swagger.License + +// Item Describes the operations available on a single path. +type Item swagger.Item + +// Operation Describes a single API operation on a path. +type Operation swagger.Operation + +// Parameter Describes a single operation parameter. +type Parameter swagger.Parameter + +// ParameterItems A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body". +// http://swagger.io/specification/#itemsObject +type ParameterItems swagger.ParameterItems + +// Schema Object allows the definition of input and output data types. +type Schema swagger.Schema + +// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification +type Propertie swagger.Propertie + +// Response as they are returned from executing this operation. +type Response swagger.Response + +// Security Allows the definition of a security scheme that can be used by the operations +type Security swagger.Security + +// Tag Allows adding meta data to a single tag that is used by the Operation Object +type Tag swagger.Tag + +// ExternalDocs include Additional external documentation +type ExternalDocs swagger.ExternalDocs From 35f1bd211929cb32e9dccdc82420782d25a2804f Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 5 Sep 2020 16:58:49 +0800 Subject: [PATCH 148/207] Adapter: testing --- pkg/adapter/testing/client.go | 50 +++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 pkg/adapter/testing/client.go diff --git a/pkg/adapter/testing/client.go b/pkg/adapter/testing/client.go new file mode 100644 index 00000000..688aa6f3 --- /dev/null +++ b/pkg/adapter/testing/client.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testing + +import ( + "github.com/astaxie/beego/pkg/client/httplib/testing" +) + +var port = "" +var baseURL = "http://localhost:" + +// TestHTTPRequest beego test request client +type TestHTTPRequest testing.TestHTTPRequest + +// Get returns test client in GET method +func Get(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Get(path)) +} + +// Post returns test client in POST method +func Post(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Post(path)) +} + +// Put returns test client in PUT method +func Put(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Put(path)) +} + +// Delete returns test client in DELETE method +func Delete(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Delete(path)) +} + +// Head returns test client in HEAD method +func Head(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Head(path)) +} From f4a43814bec6e005d92de5a91aaaa513482e0f9d Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 5 Sep 2020 18:07:42 +0800 Subject: [PATCH 149/207] Adapter: utils --- pkg/adapter/cache/cache.go | 85 +++++++++ pkg/adapter/cache/cache_test.go | 191 +++++++++++++++++++++ pkg/adapter/utils/caller.go | 24 +++ pkg/adapter/utils/caller_test.go | 28 +++ pkg/adapter/utils/captcha/LICENSE | 19 ++ pkg/adapter/utils/captcha/README.md | 45 +++++ pkg/adapter/utils/captcha/captcha.go | 124 +++++++++++++ pkg/adapter/utils/captcha/image.go | 35 ++++ pkg/adapter/utils/captcha/image_test.go | 58 +++++++ pkg/adapter/utils/debug.go | 34 ++++ pkg/adapter/utils/debug_test.go | 46 +++++ pkg/adapter/utils/file.go | 47 +++++ pkg/adapter/utils/file_test.go | 75 ++++++++ pkg/adapter/utils/mail.go | 63 +++++++ pkg/adapter/utils/mail_test.go | 41 +++++ pkg/adapter/utils/pagination/controller.go | 26 +++ pkg/adapter/utils/pagination/doc.go | 58 +++++++ pkg/adapter/utils/pagination/paginator.go | 112 ++++++++++++ pkg/adapter/utils/rand.go | 24 +++ pkg/adapter/utils/rand_test.go | 33 ++++ pkg/adapter/utils/safemap.go | 58 +++++++ pkg/adapter/utils/safemap_test.go | 89 ++++++++++ pkg/adapter/utils/slice.go | 101 +++++++++++ pkg/adapter/utils/slice_test.go | 29 ++++ pkg/adapter/utils/utils.go | 10 ++ 25 files changed, 1455 insertions(+) create mode 100644 pkg/adapter/cache/cache.go create mode 100644 pkg/adapter/cache/cache_test.go create mode 100644 pkg/adapter/utils/caller.go create mode 100644 pkg/adapter/utils/caller_test.go create mode 100644 pkg/adapter/utils/captcha/LICENSE create mode 100644 pkg/adapter/utils/captcha/README.md create mode 100644 pkg/adapter/utils/captcha/captcha.go create mode 100644 pkg/adapter/utils/captcha/image.go create mode 100644 pkg/adapter/utils/captcha/image_test.go create mode 100644 pkg/adapter/utils/debug.go create mode 100644 pkg/adapter/utils/debug_test.go create mode 100644 pkg/adapter/utils/file.go create mode 100644 pkg/adapter/utils/file_test.go create mode 100644 pkg/adapter/utils/mail.go create mode 100644 pkg/adapter/utils/mail_test.go create mode 100644 pkg/adapter/utils/pagination/controller.go create mode 100644 pkg/adapter/utils/pagination/doc.go create mode 100644 pkg/adapter/utils/pagination/paginator.go create mode 100644 pkg/adapter/utils/rand.go create mode 100644 pkg/adapter/utils/rand_test.go create mode 100644 pkg/adapter/utils/safemap.go create mode 100644 pkg/adapter/utils/safemap_test.go create mode 100644 pkg/adapter/utils/slice.go create mode 100644 pkg/adapter/utils/slice_test.go create mode 100644 pkg/adapter/utils/utils.go diff --git a/pkg/adapter/cache/cache.go b/pkg/adapter/cache/cache.go new file mode 100644 index 00000000..21bb9141 --- /dev/null +++ b/pkg/adapter/cache/cache.go @@ -0,0 +1,85 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cache provide a Cache interface and some implement engine +// Usage: +// +// import( +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("memory", `{"interval":60}`) +// +// Use it like this: +// +// bm.Put("astaxie", 1, 10 * time.Second) +// bm.Get("astaxie") +// bm.IsExist("astaxie") +// bm.Delete("astaxie") +// +// more docs http://beego.me/docs/module/cache.md +package cache + +import ( + "fmt" + + "github.com/astaxie/beego/pkg/client/cache" +) + +// Cache interface contains all behaviors for cache adapter. +// usage: +// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go. +// c,err := cache.NewCache("file","{....}") +// c.Put("key",value, 3600 * time.Second) +// v := c.Get("key") +// +// c.Incr("counter") // now is 1 +// c.Incr("counter") // now is 2 +// count := c.Get("counter").(int) +type Cache cache.Cache + +// Instance is a function create a new Cache Instance +type Instance func() Cache + +var adapters = make(map[string]Instance) + +// Register makes a cache adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Instance) { + if adapter == nil { + panic("cache: Register adapter is nil") + } + if _, ok := adapters[name]; ok { + panic("cache: Register called twice for adapter " + name) + } + adapters[name] = adapter +} + +// NewCache Create a new cache driver by adapter name and config string. +// config need to be correct JSON as string: {"interval":360}. +// it will start gc automatically. +func NewCache(adapterName, config string) (adapter Cache, err error) { + instanceFunc, ok := adapters[adapterName] + if !ok { + err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + return + } + adapter = instanceFunc() + err = adapter.StartAndGC(config) + if err != nil { + adapter = nil + } + return +} diff --git a/pkg/adapter/cache/cache_test.go b/pkg/adapter/cache/cache_test.go new file mode 100644 index 00000000..470c0a43 --- /dev/null +++ b/pkg/adapter/cache/cache_test.go @@ -0,0 +1,191 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "os" + "sync" + "testing" + "time" +) + +func TestCacheIncr(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + //timeoutDuration := 10 * time.Second + + bm.Put("edwardhey", 0, time.Second*20) + wg := sync.WaitGroup{} + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + bm.Incr("edwardhey") + }() + } + wg.Wait() + if bm.Get("edwardhey").(int) != 10 { + t.Error("Incr err") + } +} + +func TestCache(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + time.Sleep(30 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test GetMulti + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } + + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if vv[0].(string) != "author" { + t.Error("GetMulti ERROR") + } + if vv[1].(string) != "author1" { + t.Error("GetMulti ERROR") + } +} + +func TestFileCache(t *testing.T) { + bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if vv[0].(string) != "author" { + t.Error("GetMulti ERROR") + } + if vv[1].(string) != "author1" { + t.Error("GetMulti ERROR") + } + + os.RemoveAll("cache") +} diff --git a/pkg/adapter/utils/caller.go b/pkg/adapter/utils/caller.go new file mode 100644 index 00000000..d4fcc456 --- /dev/null +++ b/pkg/adapter/utils/caller.go @@ -0,0 +1,24 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// GetFuncName get function name +func GetFuncName(i interface{}) string { + return utils.GetFuncName(i) +} diff --git a/pkg/adapter/utils/caller_test.go b/pkg/adapter/utils/caller_test.go new file mode 100644 index 00000000..0675f0aa --- /dev/null +++ b/pkg/adapter/utils/caller_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "strings" + "testing" +) + +func TestGetFuncName(t *testing.T) { + name := GetFuncName(TestGetFuncName) + t.Log(name) + if !strings.HasSuffix(name, ".TestGetFuncName") { + t.Error("get func name error") + } +} diff --git a/pkg/adapter/utils/captcha/LICENSE b/pkg/adapter/utils/captcha/LICENSE new file mode 100644 index 00000000..0ad73ae0 --- /dev/null +++ b/pkg/adapter/utils/captcha/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2011-2014 Dmitry Chestnykh + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/pkg/adapter/utils/captcha/README.md b/pkg/adapter/utils/captcha/README.md new file mode 100644 index 00000000..dbc2026b --- /dev/null +++ b/pkg/adapter/utils/captcha/README.md @@ -0,0 +1,45 @@ +# Captcha + +an example for use captcha + +``` +package controllers + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/utils/captcha" +) + +var cpt *captcha.Captcha + +func init() { + // use beego cache system store the captcha data + store := cache.NewMemoryCache() + cpt = captcha.NewWithFilter("/captcha/", store) +} + +type MainController struct { + beego.Controller +} + +func (this *MainController) Get() { + this.TplName = "index.tpl" +} + +func (this *MainController) Post() { + this.TplName = "index.tpl" + + this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +} +``` + +template usage + +``` +{{.Success}} +
+ {{create_captcha}} + +
+``` diff --git a/pkg/adapter/utils/captcha/captcha.go b/pkg/adapter/utils/captcha/captcha.go new file mode 100644 index 00000000..faadc8bf --- /dev/null +++ b/pkg/adapter/utils/captcha/captcha.go @@ -0,0 +1,124 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package captcha implements generation and verification of image CAPTCHAs. +// an example for use captcha +// +// ``` +// package controllers +// +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/cache" +// "github.com/astaxie/beego/utils/captcha" +// ) +// +// var cpt *captcha.Captcha +// +// func init() { +// // use beego cache system store the captcha data +// store := cache.NewMemoryCache() +// cpt = captcha.NewWithFilter("/captcha/", store) +// } +// +// type MainController struct { +// beego.Controller +// } +// +// func (this *MainController) Get() { +// this.TplName = "index.tpl" +// } +// +// func (this *MainController) Post() { +// this.TplName = "index.tpl" +// +// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +// } +// ``` +// +// template usage +// +// ``` +// {{.Success}} +//
+// {{create_captcha}} +// +//
+// ``` +package captcha + +import ( + "html/template" + "net/http" + "time" + + "github.com/astaxie/beego/pkg/server/web/captcha" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/adapter/cache" + "github.com/astaxie/beego/pkg/adapter/context" +) + +var ( + defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +) + +const ( + // default captcha attributes + challengeNums = 6 + expiration = 600 * time.Second + fieldIDName = "captcha_id" + fieldCaptchaName = "captcha" + cachePrefix = "captcha_" + defaultURLPrefix = "/captcha/" +) + +// Captcha struct +type Captcha captcha.Captcha + +// Handler beego filter handler for serve captcha image +func (c *Captcha) Handler(ctx *context.Context) { + (*captcha.Captcha)(c).Handler((*beecontext.Context)(ctx)) +} + +// CreateCaptchaHTML template func for output html +func (c *Captcha) CreateCaptchaHTML() template.HTML { + return (*captcha.Captcha)(c).CreateCaptchaHTML() +} + +// CreateCaptcha create a new captcha id +func (c *Captcha) CreateCaptcha() (string, error) { + return (*captcha.Captcha)(c).CreateCaptcha() +} + +// VerifyReq verify from a request +func (c *Captcha) VerifyReq(req *http.Request) bool { + return (*captcha.Captcha)(c).VerifyReq(req) +} + +// Verify direct verify id and challenge string +func (c *Captcha) Verify(id string, challenge string) (success bool) { + return (*captcha.Captcha)(c).Verify(id, challenge) +} + +// NewCaptcha create a new captcha.Captcha +func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { + return (*Captcha)(captcha.NewCaptcha(urlPrefix, store)) +} + +// NewWithFilter create a new captcha.Captcha and auto AddFilter for serve captacha image +// and add a template func for output html +func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { + return (*Captcha)(captcha.NewWithFilter(urlPrefix, store)) +} diff --git a/pkg/adapter/utils/captcha/image.go b/pkg/adapter/utils/captcha/image.go new file mode 100644 index 00000000..9979db84 --- /dev/null +++ b/pkg/adapter/utils/captcha/image.go @@ -0,0 +1,35 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "io" + + "github.com/astaxie/beego/pkg/server/web/captcha" +) + +// Image struct +type Image captcha.Image + +// NewImage returns a new captcha image of the given width and height with the +// given digits, where each digit must be in range 0-9. +func NewImage(digits []byte, width, height int) *Image { + return (*Image)(captcha.NewImage(digits, width, height)) +} + +// WriteTo writes captcha image in PNG format into the given writer. +func (m *Image) WriteTo(w io.Writer) (int64, error) { + return (*captcha.Image)(m).WriteTo(w) +} diff --git a/pkg/adapter/utils/captcha/image_test.go b/pkg/adapter/utils/captcha/image_test.go new file mode 100644 index 00000000..bce2134a --- /dev/null +++ b/pkg/adapter/utils/captcha/image_test.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "testing" + + "github.com/astaxie/beego/pkg/adapter/utils" +) + +const ( + // Standard width and height of a captcha image. + stdWidth = 240 + stdHeight = 80 +) + +type byteCounter struct { + n int64 +} + +func (bc *byteCounter) Write(b []byte) (int, error) { + bc.n += int64(len(b)) + return len(b), nil +} + +func BenchmarkNewImage(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + for i := 0; i < b.N; i++ { + NewImage(d, stdWidth, stdHeight) + } +} + +func BenchmarkImageWriteTo(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + counter := &byteCounter{} + for i := 0; i < b.N; i++ { + img := NewImage(d, stdWidth, stdHeight) + img.WriteTo(counter) + b.SetBytes(counter.n) + counter.n = 0 + } +} diff --git a/pkg/adapter/utils/debug.go b/pkg/adapter/utils/debug.go new file mode 100644 index 00000000..d39f3d3e --- /dev/null +++ b/pkg/adapter/utils/debug.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// Display print the data in console +func Display(data ...interface{}) { + utils.Display(data...) +} + +// GetDisplayString return data print string +func GetDisplayString(data ...interface{}) string { + return utils.GetDisplayString(data...) +} + +// Stack get stack bytes +func Stack(skip int, indent string) []byte { + return utils.Stack(skip, indent) +} diff --git a/pkg/adapter/utils/debug_test.go b/pkg/adapter/utils/debug_test.go new file mode 100644 index 00000000..efb8924e --- /dev/null +++ b/pkg/adapter/utils/debug_test.go @@ -0,0 +1,46 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +type mytype struct { + next *mytype + prev *mytype +} + +func TestPrint(t *testing.T) { + Display("v1", 1, "v2", 2, "v3", 3) +} + +func TestPrintPoint(t *testing.T) { + var v1 = new(mytype) + var v2 = new(mytype) + + v1.prev = nil + v1.next = v2 + + v2.prev = v1 + v2.next = nil + + Display("v1", v1, "v2", v2) +} + +func TestPrintString(t *testing.T) { + str := GetDisplayString("v1", 1, "v2", 2) + println(str) +} diff --git a/pkg/adapter/utils/file.go b/pkg/adapter/utils/file.go new file mode 100644 index 00000000..8979389e --- /dev/null +++ b/pkg/adapter/utils/file.go @@ -0,0 +1,47 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// SelfPath gets compiled executable file absolute path +func SelfPath() string { + return utils.SelfPath() +} + +// SelfDir gets compiled executable file directory +func SelfDir() string { + return utils.SelfDir() +} + +// FileExists reports whether the named file or directory exists. +func FileExists(name string) bool { + return utils.FileExists(name) +} + +// SearchFile Search a file in paths. +// this is often used in search config file in /etc ~/ +func SearchFile(filename string, paths ...string) (fullpath string, err error) { + return utils.SearchFile(filename, paths...) +} + +// GrepFile like command grep -E +// for example: GrepFile(`^hello`, "hello.txt") +// \n is striped while read +func GrepFile(patten string, filename string) (lines []string, err error) { + return utils.GrepFile(patten, filename) +} diff --git a/pkg/adapter/utils/file_test.go b/pkg/adapter/utils/file_test.go new file mode 100644 index 00000000..b2644157 --- /dev/null +++ b/pkg/adapter/utils/file_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "path/filepath" + "reflect" + "testing" +) + +var noExistedFile = "/tmp/not_existed_file" + +func TestSelfPath(t *testing.T) { + path := SelfPath() + if path == "" { + t.Error("path cannot be empty") + } + t.Logf("SelfPath: %s", path) +} + +func TestSelfDir(t *testing.T) { + dir := SelfDir() + t.Logf("SelfDir: %s", dir) +} + +func TestFileExists(t *testing.T) { + if !FileExists("./file.go") { + t.Errorf("./file.go should exists, but it didn't") + } + + if FileExists(noExistedFile) { + t.Errorf("Weird, how could this file exists: %s", noExistedFile) + } +} + +func TestSearchFile(t *testing.T) { + path, err := SearchFile(filepath.Base(SelfPath()), SelfDir()) + if err != nil { + t.Error(err) + } + t.Log(path) + + _, err = SearchFile(noExistedFile, ".") + if err == nil { + t.Errorf("err shouldnt be nil, got path: %s", SelfDir()) + } +} + +func TestGrepFile(t *testing.T) { + _, err := GrepFile("", noExistedFile) + if err == nil { + t.Error("expect file-not-existed error, but got nothing") + } + + path := filepath.Join(".", "testdata", "grepe.test") + lines, err := GrepFile(`^\s*[^#]+`, path) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(lines, []string{"hello", "world"}) { + t.Errorf("expect [hello world], but receive %v", lines) + } +} diff --git a/pkg/adapter/utils/mail.go b/pkg/adapter/utils/mail.go new file mode 100644 index 00000000..35a58756 --- /dev/null +++ b/pkg/adapter/utils/mail.go @@ -0,0 +1,63 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "io" + + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// Email is the type used for email messages +type Email utils.Email + +// Attachment is a struct representing an email attachment. +// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question +type Attachment utils.Attachment + +// NewEMail create new Email struct with config json. +// config json is followed from Email struct fields. +func NewEMail(config string) *Email { + return (*Email)(utils.NewEMail(config)) +} + +// Bytes Make all send information to byte +func (e *Email) Bytes() ([]byte, error) { + return (*utils.Email)(e).Bytes() +} + +// AttachFile Add attach file to the send mail +func (e *Email) AttachFile(args ...string) (*Attachment, error) { + a, err := (*utils.Email)(e).AttachFile(args...) + if err != nil { + return nil, err + } + return (*Attachment)(a), err +} + +// Attach is used to attach content from an io.Reader to the email. +// Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type. +func (e *Email) Attach(r io.Reader, filename string, args ...string) (*Attachment, error) { + a, err := (*utils.Email)(e).Attach(r, filename, args...) + if err != nil { + return nil, err + } + return (*Attachment)(a), err +} + +// Send will send out the mail +func (e *Email) Send() error { + return (*utils.Email)(e).Send() +} diff --git a/pkg/adapter/utils/mail_test.go b/pkg/adapter/utils/mail_test.go new file mode 100644 index 00000000..c38356a2 --- /dev/null +++ b/pkg/adapter/utils/mail_test.go @@ -0,0 +1,41 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestMail(t *testing.T) { + config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}` + mail := NewEMail(config) + if mail.Username != "astaxie@gmail.com" { + t.Fatal("email parse get username error") + } + if mail.Password != "astaxie" { + t.Fatal("email parse get password error") + } + if mail.Host != "smtp.gmail.com" { + t.Fatal("email parse get host error") + } + if mail.Port != 587 { + t.Fatal("email parse get port error") + } + mail.To = []string{"xiemengjun@gmail.com"} + mail.From = "astaxie@gmail.com" + mail.Subject = "hi, just from beego!" + mail.Text = "Text Body is, of course, supported!" + mail.HTML = "

Fancy Html is supported, too!

" + mail.AttachFile("/Users/astaxie/github/beego/beego.go") + mail.Send() +} diff --git a/pkg/adapter/utils/pagination/controller.go b/pkg/adapter/utils/pagination/controller.go new file mode 100644 index 00000000..a908d8b0 --- /dev/null +++ b/pkg/adapter/utils/pagination/controller.go @@ -0,0 +1,26 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/pagination" +) + +// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). +func SetPaginator(ctx *context.Context, per int, nums int64) (paginator *Paginator) { + return (*Paginator)(pagination.SetPaginator((*beecontext.Context)(ctx), per, nums)) +} diff --git a/pkg/adapter/utils/pagination/doc.go b/pkg/adapter/utils/pagination/doc.go new file mode 100644 index 00000000..9abc6d78 --- /dev/null +++ b/pkg/adapter/utils/pagination/doc.go @@ -0,0 +1,58 @@ +/* +Package pagination provides utilities to setup a paginator within the +context of a http request. + +Usage + +In your beego.Controller: + + package controllers + + import "github.com/astaxie/beego/utils/pagination" + + type PostsController struct { + beego.Controller + } + + func (this *PostsController) ListAllPosts() { + // sets this.Data["paginator"] with the current offset (from the url query param) + postsPerPage := 20 + paginator := pagination.SetPaginator(this.Ctx, postsPerPage, CountPosts()) + + // fetch the next 20 posts + this.Data["posts"] = ListPostsByOffsetAndLimit(paginator.Offset(), postsPerPage) + } + + +In your view templates: + + {{if .paginator.HasPages}} + + {{end}} + +See also + +http://beego.me/docs/mvc/view/page.md + +*/ +package pagination diff --git a/pkg/adapter/utils/pagination/paginator.go b/pkg/adapter/utils/pagination/paginator.go new file mode 100644 index 00000000..4bd4a1b0 --- /dev/null +++ b/pkg/adapter/utils/pagination/paginator.go @@ -0,0 +1,112 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/utils/pagination" +) + +// Paginator within the state of a http request. +type Paginator pagination.Paginator + +// PageNums Returns the total number of pages. +func (p *Paginator) PageNums() int { + return (*pagination.Paginator)(p).PageNums() +} + +// Nums Returns the total number of items (e.g. from doing SQL count). +func (p *Paginator) Nums() int64 { + return (*pagination.Paginator)(p).Nums() +} + +// SetNums Sets the total number of items. +func (p *Paginator) SetNums(nums interface{}) { + (*pagination.Paginator)(p).SetNums(nums) +} + +// Page Returns the current page. +func (p *Paginator) Page() int { + return (*pagination.Paginator)(p).Page() +} + +// Pages Returns a list of all pages. +// +// Usage (in a view template): +// +// {{range $index, $page := .paginator.Pages}} +// +// {{$page}} +// +// {{end}} +func (p *Paginator) Pages() []int { + return (*pagination.Paginator)(p).Pages() +} + +// PageLink Returns URL for a given page index. +func (p *Paginator) PageLink(page int) string { + return (*pagination.Paginator)(p).PageLink(page) +} + +// PageLinkPrev Returns URL to the previous page. +func (p *Paginator) PageLinkPrev() (link string) { + return (*pagination.Paginator)(p).PageLinkPrev() +} + +// PageLinkNext Returns URL to the next page. +func (p *Paginator) PageLinkNext() (link string) { + return (*pagination.Paginator)(p).PageLinkNext() +} + +// PageLinkFirst Returns URL to the first page. +func (p *Paginator) PageLinkFirst() (link string) { + return (*pagination.Paginator)(p).PageLinkFirst() +} + +// PageLinkLast Returns URL to the last page. +func (p *Paginator) PageLinkLast() (link string) { + return (*pagination.Paginator)(p).PageLinkLast() +} + +// HasPrev Returns true if the current page has a predecessor. +func (p *Paginator) HasPrev() bool { + return (*pagination.Paginator)(p).HasPrev() +} + +// HasNext Returns true if the current page has a successor. +func (p *Paginator) HasNext() bool { + return (*pagination.Paginator)(p).HasNext() +} + +// IsActive Returns true if the given page index points to the current page. +func (p *Paginator) IsActive(page int) bool { + return (*pagination.Paginator)(p).IsActive(page) +} + +// Offset Returns the current offset. +func (p *Paginator) Offset() int { + return (*pagination.Paginator)(p).Offset() +} + +// HasPages Returns true if there is more than one page. +func (p *Paginator) HasPages() bool { + return (*pagination.Paginator)(p).HasPages() +} + +// NewPaginator Instantiates a paginator struct for the current http request. +func NewPaginator(req *http.Request, per int, nums interface{}) *Paginator { + return (*Paginator)(pagination.NewPaginator(req, per, nums)) +} diff --git a/pkg/adapter/utils/rand.go b/pkg/adapter/utils/rand.go new file mode 100644 index 00000000..ae415cf3 --- /dev/null +++ b/pkg/adapter/utils/rand.go @@ -0,0 +1,24 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// RandomCreateBytes generate random []byte by specify chars. +func RandomCreateBytes(n int, alphabets ...byte) []byte { + return utils.RandomCreateBytes(n, alphabets...) +} diff --git a/pkg/adapter/utils/rand_test.go b/pkg/adapter/utils/rand_test.go new file mode 100644 index 00000000..6c238b5e --- /dev/null +++ b/pkg/adapter/utils/rand_test.go @@ -0,0 +1,33 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestRand_01(t *testing.T) { + bs0 := RandomCreateBytes(16) + bs1 := RandomCreateBytes(16) + + t.Log(string(bs0), string(bs1)) + if string(bs0) == string(bs1) { + t.FailNow() + } + + bs0 = RandomCreateBytes(4, []byte(`a`)...) + + if string(bs0) != "aaaa" { + t.FailNow() + } +} diff --git a/pkg/adapter/utils/safemap.go b/pkg/adapter/utils/safemap.go new file mode 100644 index 00000000..13e7bb46 --- /dev/null +++ b/pkg/adapter/utils/safemap.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// BeeMap is a map with lock +type BeeMap utils.BeeMap + +// NewBeeMap return new safemap +func NewBeeMap() *BeeMap { + return (*BeeMap)(utils.NewBeeMap()) +} + +// Get from maps return the k's value +func (m *BeeMap) Get(k interface{}) interface{} { + return (*utils.BeeMap)(m).Get(k) +} + +// Set Maps the given key and value. Returns false +// if the key is already in the map and changes nothing. +func (m *BeeMap) Set(k interface{}, v interface{}) bool { + return (*utils.BeeMap)(m).Set(k, v) +} + +// Check Returns true if k is exist in the map. +func (m *BeeMap) Check(k interface{}) bool { + return (*utils.BeeMap)(m).Check(k) +} + +// Delete the given key and value. +func (m *BeeMap) Delete(k interface{}) { + (*utils.BeeMap)(m).Delete(k) +} + +// Items returns all items in safemap. +func (m *BeeMap) Items() map[interface{}]interface{} { + return (*utils.BeeMap)(m).Items() +} + +// Count returns the number of items within the map. +func (m *BeeMap) Count() int { + return (*utils.BeeMap)(m).Count() +} diff --git a/pkg/adapter/utils/safemap_test.go b/pkg/adapter/utils/safemap_test.go new file mode 100644 index 00000000..65085195 --- /dev/null +++ b/pkg/adapter/utils/safemap_test.go @@ -0,0 +1,89 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +var safeMap *BeeMap + +func TestNewBeeMap(t *testing.T) { + safeMap = NewBeeMap() + if safeMap == nil { + t.Fatal("expected to return non-nil BeeMap", "got", safeMap) + } +} + +func TestSet(t *testing.T) { + safeMap = NewBeeMap() + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } +} + +func TestReSet(t *testing.T) { + safeMap := NewBeeMap() + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } + // set diff value + if ok := safeMap.Set("astaxie", -1); !ok { + t.Error("expected", true, "got", false) + } + + // set same value + if ok := safeMap.Set("astaxie", -1); ok { + t.Error("expected", false, "got", true) + } +} + +func TestCheck(t *testing.T) { + if exists := safeMap.Check("astaxie"); !exists { + t.Error("expected", true, "got", false) + } +} + +func TestGet(t *testing.T) { + if val := safeMap.Get("astaxie"); val.(int) != 1 { + t.Error("expected value", 1, "got", val) + } +} + +func TestDelete(t *testing.T) { + safeMap.Delete("astaxie") + if exists := safeMap.Check("astaxie"); exists { + t.Error("expected element to be deleted") + } +} + +func TestItems(t *testing.T) { + safeMap := NewBeeMap() + safeMap.Set("astaxie", "hello") + for k, v := range safeMap.Items() { + key := k.(string) + value := v.(string) + if key != "astaxie" { + t.Error("expected the key should be astaxie") + } + if value != "hello" { + t.Error("expected the value should be hello") + } + } +} + +func TestCount(t *testing.T) { + if count := safeMap.Count(); count != 0 { + t.Error("expected count to be", 0, "got", count) + } +} diff --git a/pkg/adapter/utils/slice.go b/pkg/adapter/utils/slice.go new file mode 100644 index 00000000..24d19ad2 --- /dev/null +++ b/pkg/adapter/utils/slice.go @@ -0,0 +1,101 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +type reducetype func(interface{}) interface{} +type filtertype func(interface{}) bool + +// InSlice checks given string in string slice or not. +func InSlice(v string, sl []string) bool { + return utils.InSlice(v, sl) +} + +// InSliceIface checks given interface in interface slice. +func InSliceIface(v interface{}, sl []interface{}) bool { + return utils.InSliceIface(v, sl) +} + +// SliceRandList generate an int slice from min to max. +func SliceRandList(min, max int) []int { + return utils.SliceRandList(min, max) +} + +// SliceMerge merges interface slices to one slice. +func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) { + return utils.SliceMerge(slice1, slice2) +} + +// SliceReduce generates a new slice after parsing every value by reduce function +func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) { + return utils.SliceReduce(slice, func(i interface{}) interface{} { + return a(i) + }) +} + +// SliceRand returns random one from slice. +func SliceRand(a []interface{}) (b interface{}) { + return utils.SliceRand(a) +} + +// SliceSum sums all values in int64 slice. +func SliceSum(intslice []int64) (sum int64) { + return utils.SliceSum(intslice) +} + +// SliceFilter generates a new slice after filter function. +func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) { + return utils.SliceFilter(slice, func(i interface{}) bool { + return a(i) + }) +} + +// SliceDiff returns diff slice of slice1 - slice2. +func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) { + return utils.SliceDiff(slice1, slice2) +} + +// SliceIntersect returns slice that are present in all the slice1 and slice2. +func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) { + return utils.SliceIntersect(slice1, slice2) +} + +// SliceChunk separates one slice to some sized slice. +func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) { + return utils.SliceChunk(slice, size) +} + +// SliceRange generates a new slice from begin to end with step duration of int64 number. +func SliceRange(start, end, step int64) (intslice []int64) { + return utils.SliceRange(start, end, step) +} + +// SlicePad prepends size number of val into slice. +func SlicePad(slice []interface{}, size int, val interface{}) []interface{} { + return utils.SlicePad(slice, size, val) +} + +// SliceUnique cleans repeated values in slice. +func SliceUnique(slice []interface{}) (uniqueslice []interface{}) { + return utils.SliceUnique(slice) +} + +// SliceShuffle shuffles a slice. +func SliceShuffle(slice []interface{}) []interface{} { + return utils.SliceShuffle(slice) +} diff --git a/pkg/adapter/utils/slice_test.go b/pkg/adapter/utils/slice_test.go new file mode 100644 index 00000000..142dec96 --- /dev/null +++ b/pkg/adapter/utils/slice_test.go @@ -0,0 +1,29 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +func TestInSlice(t *testing.T) { + sl := []string{"A", "b"} + if !InSlice("A", sl) { + t.Error("should be true") + } + if InSlice("B", sl) { + t.Error("should be false") + } +} diff --git a/pkg/adapter/utils/utils.go b/pkg/adapter/utils/utils.go new file mode 100644 index 00000000..1f3bcd31 --- /dev/null +++ b/pkg/adapter/utils/utils.go @@ -0,0 +1,10 @@ +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// GetGOPATHs returns all paths in GOPATH variable. +func GetGOPATHs() []string { + return utils.GetGOPATHs() +} From 5b3dd7e50f4fde914c2a919ce35f77b4d5c19fe0 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 6 Sep 2020 13:33:43 +0800 Subject: [PATCH 150/207] Adapter: orm --- pkg/adapter/orm/cmd.go | 28 ++ pkg/adapter/orm/db.go | 24 + pkg/adapter/orm/db_alias.go | 124 +++++ pkg/adapter/orm/models.go | 25 + pkg/adapter/orm/models_boot.go | 40 ++ pkg/adapter/orm/models_fields.go | 625 ++++++++++++++++++++++++ pkg/adapter/orm/orm.go | 314 ++++++++++++ pkg/adapter/orm/orm_conds.go | 83 ++++ pkg/adapter/orm/orm_log.go | 32 ++ pkg/adapter/orm/orm_queryset.go | 32 ++ pkg/adapter/orm/qb.go | 27 + pkg/adapter/orm/qb_mysql.go | 150 ++++++ pkg/adapter/orm/qb_tidb.go | 147 ++++++ pkg/adapter/orm/query_setter_adapter.go | 34 ++ pkg/adapter/orm/types.go | 150 ++++++ pkg/adapter/orm/utils.go | 286 +++++++++++ pkg/adapter/orm/utils_test.go | 70 +++ pkg/client/orm/db_alias.go | 60 ++- pkg/client/orm/orm.go | 6 +- 19 files changed, 2227 insertions(+), 30 deletions(-) create mode 100644 pkg/adapter/orm/cmd.go create mode 100644 pkg/adapter/orm/db.go create mode 100644 pkg/adapter/orm/db_alias.go create mode 100644 pkg/adapter/orm/models.go create mode 100644 pkg/adapter/orm/models_boot.go create mode 100644 pkg/adapter/orm/models_fields.go create mode 100644 pkg/adapter/orm/orm.go create mode 100644 pkg/adapter/orm/orm_conds.go create mode 100644 pkg/adapter/orm/orm_log.go create mode 100644 pkg/adapter/orm/orm_queryset.go create mode 100644 pkg/adapter/orm/qb.go create mode 100644 pkg/adapter/orm/qb_mysql.go create mode 100644 pkg/adapter/orm/qb_tidb.go create mode 100644 pkg/adapter/orm/query_setter_adapter.go create mode 100644 pkg/adapter/orm/types.go create mode 100644 pkg/adapter/orm/utils.go create mode 100644 pkg/adapter/orm/utils_test.go diff --git a/pkg/adapter/orm/cmd.go b/pkg/adapter/orm/cmd.go new file mode 100644 index 00000000..6fee237c --- /dev/null +++ b/pkg/adapter/orm/cmd.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// RunCommand listen for orm command and then run it if command arguments passed. +func RunCommand() { + orm.RunCommand() +} + +func RunSyncdb(name string, force bool, verbose bool) error { + return orm.RunSyncdb(name, force, verbose) +} diff --git a/pkg/adapter/orm/db.go b/pkg/adapter/orm/db.go new file mode 100644 index 00000000..74bca8c0 --- /dev/null +++ b/pkg/adapter/orm/db.go @@ -0,0 +1,24 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +var ( + // ErrMissPK missing pk error + ErrMissPK = orm.ErrMissPK +) diff --git a/pkg/adapter/orm/db_alias.go b/pkg/adapter/orm/db_alias.go new file mode 100644 index 00000000..2ecc80e5 --- /dev/null +++ b/pkg/adapter/orm/db_alias.go @@ -0,0 +1,124 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "time" + + "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/pkg/client/orm/hints" + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// DriverType database driver constant int. +type DriverType orm.DriverType + +// Enum the Database driver +const ( + _ DriverType = iota // int enum type + DRMySQL = orm.DRMySQL + DRSqlite = orm.DRSqlite // sqlite + DROracle = orm.DROracle // oracle + DRPostgres = orm.DRPostgres // pgsql + DRTiDB = orm.DRTiDB // TiDB +) + +type DB orm.DB + +func (d *DB) Begin() (*sql.Tx, error) { + return (*orm.DB)(d).Begin() +} + +func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return (*orm.DB)(d).BeginTx(ctx, opts) +} + +func (d *DB) Prepare(query string) (*sql.Stmt, error) { + return (*orm.DB)(d).Prepare(query) +} + +func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return (*orm.DB)(d).PrepareContext(ctx, query) +} + +func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + return (*orm.DB)(d).Exec(query, args...) +} + +func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return (*orm.DB)(d).ExecContext(ctx, query, args...) +} + +func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { + return (*orm.DB)(d).Query(query, args...) +} + +func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return (*orm.DB)(d).QueryContext(ctx, query, args...) +} + +func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { + return (*orm.DB)(d).QueryRow(query, args) +} + +func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + return (*orm.DB)(d).QueryRowContext(ctx, query, args...) +} + +// AddAliasWthDB add a aliasName for the drivename +func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { + return orm.AddAliasWthDB(aliasName, driverName, db) +} + +// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. +func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { + opts := make([]utils.KV, 0, 2) + if len(params) > 0 { + opts = append(opts, hints.MaxIdleConnections(params[0])) + } + + if len(params) > 1 { + opts = append(opts, hints.MaxOpenConnections(params[1])) + } + return orm.RegisterDataBase(aliasName, driverName, dataSource, opts...) +} + +// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. +func RegisterDriver(driverName string, typ DriverType) error { + return orm.RegisterDriver(driverName, orm.DriverType(typ)) +} + +// SetDataBaseTZ Change the database default used timezone +func SetDataBaseTZ(aliasName string, tz *time.Location) error { + return orm.SetDataBaseTZ(aliasName, tz) +} + +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +func SetMaxIdleConns(aliasName string, maxIdleConns int) { + orm.SetMaxIdleConns(aliasName, maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +func SetMaxOpenConns(aliasName string, maxOpenConns int) { + orm.SetMaxOpenConns(aliasName, maxOpenConns) +} + +// GetDB Get *sql.DB from registered database by db alias name. +// Use "default" as alias name if you not set. +func GetDB(aliasNames ...string) (*sql.DB, error) { + return orm.GetDB(aliasNames...) +} diff --git a/pkg/adapter/orm/models.go b/pkg/adapter/orm/models.go new file mode 100644 index 00000000..3215f5b5 --- /dev/null +++ b/pkg/adapter/orm/models.go @@ -0,0 +1,25 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// ResetModelCache Clean model cache. Then you can re-RegisterModel. +// Common use this api for test case. +func ResetModelCache() { + orm.ResetModelCache() +} diff --git a/pkg/adapter/orm/models_boot.go b/pkg/adapter/orm/models_boot.go new file mode 100644 index 00000000..8888ef65 --- /dev/null +++ b/pkg/adapter/orm/models_boot.go @@ -0,0 +1,40 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// RegisterModel register models +func RegisterModel(models ...interface{}) { + orm.RegisterModel(models...) +} + +// RegisterModelWithPrefix register models with a prefix +func RegisterModelWithPrefix(prefix string, models ...interface{}) { + orm.RegisterModelWithPrefix(prefix, models) +} + +// RegisterModelWithSuffix register models with a suffix +func RegisterModelWithSuffix(suffix string, models ...interface{}) { + orm.RegisterModelWithSuffix(suffix, models...) +} + +// BootStrap bootstrap models. +// make all model parsed and can not add more models +func BootStrap() { + orm.BootStrap() +} diff --git a/pkg/adapter/orm/models_fields.go b/pkg/adapter/orm/models_fields.go new file mode 100644 index 00000000..666a97dc --- /dev/null +++ b/pkg/adapter/orm/models_fields.go @@ -0,0 +1,625 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "time" + + "github.com/astaxie/beego/pkg/client/orm" +) + +// Define the Type enum +const ( + TypeBooleanField = orm.TypeBooleanField + TypeVarCharField = orm.TypeVarCharField + TypeCharField = orm.TypeCharField + TypeTextField = orm.TypeTextField + TypeTimeField = orm.TypeTimeField + TypeDateField = orm.TypeDateField + TypeDateTimeField = orm.TypeDateTimeField + TypeBitField = orm.TypeBitField + TypeSmallIntegerField = orm.TypeSmallIntegerField + TypeIntegerField = orm.TypeIntegerField + TypeBigIntegerField = orm.TypeBigIntegerField + TypePositiveBitField = orm.TypePositiveBitField + TypePositiveSmallIntegerField = orm.TypePositiveSmallIntegerField + TypePositiveIntegerField = orm.TypePositiveIntegerField + TypePositiveBigIntegerField = orm.TypePositiveBigIntegerField + TypeFloatField = orm.TypeFloatField + TypeDecimalField = orm.TypeDecimalField + TypeJSONField = orm.TypeJSONField + TypeJsonbField = orm.TypeJsonbField + RelForeignKey = orm.RelForeignKey + RelOneToOne = orm.RelOneToOne + RelManyToMany = orm.RelManyToMany + RelReverseOne = orm.RelReverseOne + RelReverseMany = orm.RelReverseMany +) + +// Define some logic enum +const ( + IsIntegerField = orm.IsIntegerField + IsPositiveIntegerField = orm.IsPositiveIntegerField + IsRelField = orm.IsRelField + IsFieldType = orm.IsFieldType +) + +// BooleanField A true/false field. +type BooleanField orm.BooleanField + +// Value return the BooleanField +func (e BooleanField) Value() bool { + return orm.BooleanField(e).Value() +} + +// Set will set the BooleanField +func (e *BooleanField) Set(d bool) { + (*orm.BooleanField)(e).Set(d) +} + +// String format the Bool to string +func (e *BooleanField) String() string { + return (*orm.BooleanField)(e).String() +} + +// FieldType return BooleanField the type +func (e *BooleanField) FieldType() int { + return (*orm.BooleanField)(e).FieldType() +} + +// SetRaw set the interface to bool +func (e *BooleanField) SetRaw(value interface{}) error { + return (*orm.BooleanField)(e).SetRaw(value) +} + +// RawValue return the current value +func (e *BooleanField) RawValue() interface{} { + return (*orm.BooleanField)(e).RawValue() +} + +// verify the BooleanField implement the Fielder interface +var _ Fielder = new(BooleanField) + +// CharField A string field +// required values tag: size +// The size is enforced at the database level and in models’s validation. +// eg: `orm:"size(120)"` +type CharField orm.CharField + +// Value return the CharField's Value +func (e CharField) Value() string { + return orm.CharField(e).Value() +} + +// Set CharField value +func (e *CharField) Set(d string) { + (*orm.CharField)(e).Set(d) +} + +// String return the CharField +func (e *CharField) String() string { + return (*orm.CharField)(e).String() +} + +// FieldType return the enum type +func (e *CharField) FieldType() int { + return (*orm.CharField)(e).FieldType() +} + +// SetRaw set the interface to string +func (e *CharField) SetRaw(value interface{}) error { + return (*orm.CharField)(e).SetRaw(value) +} + +// RawValue return the CharField value +func (e *CharField) RawValue() interface{} { + return (*orm.CharField)(e).RawValue() +} + +// verify CharField implement Fielder +var _ Fielder = new(CharField) + +// TimeField A time, represented in go by a time.Time instance. +// only time values like 10:00:00 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type TimeField orm.TimeField + +// Value return the time.Time +func (e TimeField) Value() time.Time { + return orm.TimeField(e).Value() +} + +// Set set the TimeField's value +func (e *TimeField) Set(d time.Time) { + (*orm.TimeField)(e).Set(d) +} + +// String convert time to string +func (e *TimeField) String() string { + return (*orm.TimeField)(e).String() +} + +// FieldType return enum type Date +func (e *TimeField) FieldType() int { + return (*orm.TimeField)(e).FieldType() +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *TimeField) SetRaw(value interface{}) error { + return (*orm.TimeField)(e).SetRaw(value) +} + +// RawValue return time value +func (e *TimeField) RawValue() interface{} { + return (*orm.TimeField)(e).RawValue() +} + +var _ Fielder = new(TimeField) + +// DateField A date, represented in go by a time.Time instance. +// only date values like 2006-01-02 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type DateField orm.DateField + +// Value return the time.Time +func (e DateField) Value() time.Time { + return orm.DateField(e).Value() +} + +// Set set the DateField's value +func (e *DateField) Set(d time.Time) { + (*orm.DateField)(e).Set(d) +} + +// String convert datetime to string +func (e *DateField) String() string { + return (*orm.DateField)(e).String() +} + +// FieldType return enum type Date +func (e *DateField) FieldType() int { + return (*orm.DateField)(e).FieldType() +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *DateField) SetRaw(value interface{}) error { + return (*orm.DateField)(e).SetRaw(value) +} + +// RawValue return Date value +func (e *DateField) RawValue() interface{} { + return (*orm.DateField)(e).RawValue() +} + +// verify DateField implement fielder interface +var _ Fielder = new(DateField) + +// DateTimeField A date, represented in go by a time.Time instance. +// datetime values like 2006-01-02 15:04:05 +// Takes the same extra arguments as DateField. +type DateTimeField orm.DateTimeField + +// Value return the datetime value +func (e DateTimeField) Value() time.Time { + return orm.DateTimeField(e).Value() +} + +// Set set the time.Time to datetime +func (e *DateTimeField) Set(d time.Time) { + (*orm.DateTimeField)(e).Set(d) +} + +// String return the time's String +func (e *DateTimeField) String() string { + return (*orm.DateTimeField)(e).String() +} + +// FieldType return the enum TypeDateTimeField +func (e *DateTimeField) FieldType() int { + return (*orm.DateTimeField)(e).FieldType() +} + +// SetRaw convert the string or time.Time to DateTimeField +func (e *DateTimeField) SetRaw(value interface{}) error { + return (*orm.DateTimeField)(e).SetRaw(value) +} + +// RawValue return the datetime value +func (e *DateTimeField) RawValue() interface{} { + return (*orm.DateTimeField)(e).RawValue() +} + +// verify datetime implement fielder +var _ Fielder = new(DateTimeField) + +// FloatField A floating-point number represented in go by a float32 value. +type FloatField orm.FloatField + +// Value return the FloatField value +func (e FloatField) Value() float64 { + return orm.FloatField(e).Value() +} + +// Set the Float64 +func (e *FloatField) Set(d float64) { + (*orm.FloatField)(e).Set(d) +} + +// String return the string +func (e *FloatField) String() string { + return (*orm.FloatField)(e).String() +} + +// FieldType return the enum type +func (e *FloatField) FieldType() int { + return (*orm.FloatField)(e).FieldType() +} + +// SetRaw converter interface Float64 float32 or string to FloatField +func (e *FloatField) SetRaw(value interface{}) error { + return (*orm.FloatField)(e).SetRaw(value) +} + +// RawValue return the FloatField value +func (e *FloatField) RawValue() interface{} { + return (*orm.FloatField)(e).RawValue() +} + +// verify FloatField implement Fielder +var _ Fielder = new(FloatField) + +// SmallIntegerField -32768 to 32767 +type SmallIntegerField orm.SmallIntegerField + +// Value return int16 value +func (e SmallIntegerField) Value() int16 { + return orm.SmallIntegerField(e).Value() +} + +// Set the SmallIntegerField value +func (e *SmallIntegerField) Set(d int16) { + (*orm.SmallIntegerField)(e).Set(d) +} + +// String convert smallint to string +func (e *SmallIntegerField) String() string { + return (*orm.SmallIntegerField)(e).String() +} + +// FieldType return enum type SmallIntegerField +func (e *SmallIntegerField) FieldType() int { + return (*orm.SmallIntegerField)(e).FieldType() +} + +// SetRaw convert interface int16/string to int16 +func (e *SmallIntegerField) SetRaw(value interface{}) error { + return (*orm.SmallIntegerField)(e).SetRaw(value) +} + +// RawValue return smallint value +func (e *SmallIntegerField) RawValue() interface{} { + return (*orm.SmallIntegerField)(e).RawValue() +} + +// verify SmallIntegerField implement Fielder +var _ Fielder = new(SmallIntegerField) + +// IntegerField -2147483648 to 2147483647 +type IntegerField orm.IntegerField + +// Value return the int32 +func (e IntegerField) Value() int32 { + return orm.IntegerField(e).Value() +} + +// Set IntegerField value +func (e *IntegerField) Set(d int32) { + (*orm.IntegerField)(e).Set(d) +} + +// String convert Int32 to string +func (e *IntegerField) String() string { + return (*orm.IntegerField)(e).String() +} + +// FieldType return the enum type +func (e *IntegerField) FieldType() int { + return (*orm.IntegerField)(e).FieldType() +} + +// SetRaw convert interface int32/string to int32 +func (e *IntegerField) SetRaw(value interface{}) error { + return (*orm.IntegerField)(e).SetRaw(value) +} + +// RawValue return IntegerField value +func (e *IntegerField) RawValue() interface{} { + return (*orm.IntegerField)(e).RawValue() +} + +// verify IntegerField implement Fielder +var _ Fielder = new(IntegerField) + +// BigIntegerField -9223372036854775808 to 9223372036854775807. +type BigIntegerField orm.BigIntegerField + +// Value return int64 +func (e BigIntegerField) Value() int64 { + return orm.BigIntegerField(e).Value() +} + +// Set the BigIntegerField value +func (e *BigIntegerField) Set(d int64) { + (*orm.BigIntegerField)(e).Set(d) +} + +// String convert BigIntegerField to string +func (e *BigIntegerField) String() string { + return (*orm.BigIntegerField)(e).String() +} + +// FieldType return enum type +func (e *BigIntegerField) FieldType() int { + return (*orm.BigIntegerField)(e).FieldType() +} + +// SetRaw convert interface int64/string to int64 +func (e *BigIntegerField) SetRaw(value interface{}) error { + return (*orm.BigIntegerField)(e).SetRaw(value) +} + +// RawValue return BigIntegerField value +func (e *BigIntegerField) RawValue() interface{} { + return (*orm.BigIntegerField)(e).RawValue() +} + +// verify BigIntegerField implement Fielder +var _ Fielder = new(BigIntegerField) + +// PositiveSmallIntegerField 0 to 65535 +type PositiveSmallIntegerField orm.PositiveSmallIntegerField + +// Value return uint16 +func (e PositiveSmallIntegerField) Value() uint16 { + return orm.PositiveSmallIntegerField(e).Value() +} + +// Set PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) Set(d uint16) { + (*orm.PositiveSmallIntegerField)(e).Set(d) +} + +// String convert uint16 to string +func (e *PositiveSmallIntegerField) String() string { + return (*orm.PositiveSmallIntegerField)(e).String() +} + +// FieldType return enum type +func (e *PositiveSmallIntegerField) FieldType() int { + return (*orm.PositiveSmallIntegerField)(e).FieldType() +} + +// SetRaw convert Interface uint16/string to uint16 +func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { + return (*orm.PositiveSmallIntegerField)(e).SetRaw(value) +} + +// RawValue returns PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) RawValue() interface{} { + return (*orm.PositiveSmallIntegerField)(e).RawValue() +} + +// verify PositiveSmallIntegerField implement Fielder +var _ Fielder = new(PositiveSmallIntegerField) + +// PositiveIntegerField 0 to 4294967295 +type PositiveIntegerField orm.PositiveIntegerField + +// Value return PositiveIntegerField value. Uint32 +func (e PositiveIntegerField) Value() uint32 { + return orm.PositiveIntegerField(e).Value() +} + +// Set the PositiveIntegerField value +func (e *PositiveIntegerField) Set(d uint32) { + (*orm.PositiveIntegerField)(e).Set(d) +} + +// String convert PositiveIntegerField to string +func (e *PositiveIntegerField) String() string { + return (*orm.PositiveIntegerField)(e).String() +} + +// FieldType return enum type +func (e *PositiveIntegerField) FieldType() int { + return (*orm.PositiveIntegerField)(e).FieldType() +} + +// SetRaw convert interface uint32/string to Uint32 +func (e *PositiveIntegerField) SetRaw(value interface{}) error { + return (*orm.PositiveIntegerField)(e).SetRaw(value) +} + +// RawValue return the PositiveIntegerField Value +func (e *PositiveIntegerField) RawValue() interface{} { + return (*orm.PositiveIntegerField)(e).RawValue() +} + +// verify PositiveIntegerField implement Fielder +var _ Fielder = new(PositiveIntegerField) + +// PositiveBigIntegerField 0 to 18446744073709551615 +type PositiveBigIntegerField orm.PositiveBigIntegerField + +// Value return uint64 +func (e PositiveBigIntegerField) Value() uint64 { + return orm.PositiveBigIntegerField(e).Value() +} + +// Set PositiveBigIntegerField value +func (e *PositiveBigIntegerField) Set(d uint64) { + (*orm.PositiveBigIntegerField)(e).Set(d) +} + +// String convert PositiveBigIntegerField to string +func (e *PositiveBigIntegerField) String() string { + return (*orm.PositiveBigIntegerField)(e).String() +} + +// FieldType return enum type +func (e *PositiveBigIntegerField) FieldType() int { + return (*orm.PositiveBigIntegerField)(e).FieldType() +} + +// SetRaw convert interface uint64/string to Uint64 +func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { + return (*orm.PositiveBigIntegerField)(e).SetRaw(value) +} + +// RawValue return PositiveBigIntegerField value +func (e *PositiveBigIntegerField) RawValue() interface{} { + return (*orm.PositiveBigIntegerField)(e).RawValue() +} + +// verify PositiveBigIntegerField implement Fielder +var _ Fielder = new(PositiveBigIntegerField) + +// TextField A large text field. +type TextField orm.TextField + +// Value return TextField value +func (e TextField) Value() string { + return orm.TextField(e).Value() +} + +// Set the TextField value +func (e *TextField) Set(d string) { + (*orm.TextField)(e).Set(d) +} + +// String convert TextField to string +func (e *TextField) String() string { + return (*orm.TextField)(e).String() +} + +// FieldType return enum type +func (e *TextField) FieldType() int { + return (*orm.TextField)(e).FieldType() +} + +// SetRaw convert interface string to string +func (e *TextField) SetRaw(value interface{}) error { + return (*orm.TextField)(e).SetRaw(value) +} + +// RawValue return TextField value +func (e *TextField) RawValue() interface{} { + return (*orm.TextField)(e).RawValue() +} + +// verify TextField implement Fielder +var _ Fielder = new(TextField) + +// JSONField postgres json field. +type JSONField orm.JSONField + +// Value return JSONField value +func (j JSONField) Value() string { + return orm.JSONField(j).Value() +} + +// Set the JSONField value +func (j *JSONField) Set(d string) { + (*orm.JSONField)(j).Set(d) +} + +// String convert JSONField to string +func (j *JSONField) String() string { + return (*orm.JSONField)(j).String() +} + +// FieldType return enum type +func (j *JSONField) FieldType() int { + return (*orm.JSONField)(j).FieldType() +} + +// SetRaw convert interface string to string +func (j *JSONField) SetRaw(value interface{}) error { + return (*orm.JSONField)(j).SetRaw(value) +} + +// RawValue return JSONField value +func (j *JSONField) RawValue() interface{} { + return (*orm.JSONField)(j).RawValue() +} + +// verify JSONField implement Fielder +var _ Fielder = new(JSONField) + +// JsonbField postgres json field. +type JsonbField orm.JsonbField + +// Value return JsonbField value +func (j JsonbField) Value() string { + return orm.JsonbField(j).Value() +} + +// Set the JsonbField value +func (j *JsonbField) Set(d string) { + (*orm.JsonbField)(j).Set(d) +} + +// String convert JsonbField to string +func (j *JsonbField) String() string { + return (*orm.JsonbField)(j).String() +} + +// FieldType return enum type +func (j *JsonbField) FieldType() int { + return (*orm.JsonbField)(j).FieldType() +} + +// SetRaw convert interface string to string +func (j *JsonbField) SetRaw(value interface{}) error { + return (*orm.JsonbField)(j).SetRaw(value) +} + +// RawValue return JsonbField value +func (j *JsonbField) RawValue() interface{} { + return (*orm.JsonbField)(j).RawValue() +} + +// verify JsonbField implement Fielder +var _ Fielder = new(JsonbField) diff --git a/pkg/adapter/orm/orm.go b/pkg/adapter/orm/orm.go new file mode 100644 index 00000000..f8463ea2 --- /dev/null +++ b/pkg/adapter/orm/orm.go @@ -0,0 +1,314 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +// Package orm provide ORM for MySQL/PostgreSQL/sqlite +// Simple Usage +// +// package main +// +// import ( +// "fmt" +// "github.com/astaxie/beego/orm" +// _ "github.com/go-sql-driver/mysql" // import your used driver +// ) +// +// // Model Struct +// type User struct { +// Id int `orm:"auto"` +// Name string `orm:"size(100)"` +// } +// +// func init() { +// orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) +// } +// +// func main() { +// o := orm.NewOrm() +// user := User{Name: "slene"} +// // insert +// id, err := o.Insert(&user) +// // update +// user.Name = "astaxie" +// num, err := o.Update(&user) +// // read one +// u := User{Id: user.Id} +// err = o.Read(&u) +// // delete +// num, err = o.Delete(&u) +// } +// +// more docs: http://beego.me/docs/mvc/model/overview.md +package orm + +import ( + "context" + "database/sql" + "errors" + + "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/pkg/client/orm/hints" + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// DebugQueries define the debug +const ( + DebugQueries = iota +) + +// Define common vars +var ( + Debug = orm.Debug + DebugLog = orm.DebugLog + DefaultRowsLimit = orm.DefaultRowsLimit + DefaultRelsDepth = orm.DefaultRelsDepth + DefaultTimeLoc = orm.DefaultTimeLoc + ErrTxHasBegan = errors.New(" transaction already begin") + ErrTxDone = errors.New(" transaction not begin") + ErrMultiRows = errors.New(" return multi rows") + ErrNoRows = errors.New(" no row found") + ErrStmtClosed = errors.New(" stmt already closed") + ErrArgs = errors.New(" args error may be empty") + ErrNotImplement = errors.New("have not implement") +) + +type ormer struct { + delegate orm.Ormer + txDelegate orm.TxOrmer + isTx bool +} + +var _ Ormer = new(ormer) + +// read data to model +func (o *ormer) Read(md interface{}, cols ...string) error { + if o.isTx { + return o.txDelegate.Read(md, cols...) + } + return o.delegate.Read(md, cols...) +} + +// read data to model, like Read(), but use "SELECT FOR UPDATE" form +func (o *ormer) ReadForUpdate(md interface{}, cols ...string) error { + if o.isTx { + return o.txDelegate.ReadForUpdate(md, cols...) + } + return o.delegate.ReadForUpdate(md, cols...) +} + +// Try to read a row from the database, or insert one if it doesn't exist +func (o *ormer) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + if o.isTx { + return o.txDelegate.ReadOrCreate(md, col1, cols...) + } + return o.delegate.ReadOrCreate(md, col1, cols...) +} + +// insert model data to database +func (o *ormer) Insert(md interface{}) (int64, error) { + if o.isTx { + return o.txDelegate.Insert(md) + } + return o.delegate.Insert(md) +} + +// insert some models to database +func (o *ormer) InsertMulti(bulk int, mds interface{}) (int64, error) { + if o.isTx { + return o.txDelegate.InsertMulti(bulk, mds) + } + return o.delegate.InsertMulti(bulk, mds) +} + +// InsertOrUpdate data to database +func (o *ormer) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { + if o.isTx { + return o.txDelegate.InsertOrUpdate(md, colConflitAndArgs...) + } + return o.delegate.InsertOrUpdate(md, colConflitAndArgs...) +} + +// update model to database. +// cols set the columns those want to update. +func (o *ormer) Update(md interface{}, cols ...string) (int64, error) { + if o.isTx { + return o.txDelegate.Update(md, cols...) + } + return o.delegate.Update(md, cols...) +} + +// delete model in database +// cols shows the delete conditions values read from. default is pk +func (o *ormer) Delete(md interface{}, cols ...string) (int64, error) { + if o.isTx { + return o.txDelegate.Delete(md, cols...) + } + return o.delegate.Delete(md, cols...) +} + +// create a models to models queryer +func (o *ormer) QueryM2M(md interface{}, name string) QueryM2Mer { + if o.isTx { + return o.txDelegate.QueryM2M(md, name) + } + return o.delegate.QueryM2M(md, name) +} + +// load related models to md model. +// args are limit, offset int and order string. +// +// example: +// orm.LoadRelated(post,"Tags") +// for _,tag := range post.Tags{...} +// +// make sure the relation is defined in model struct tags. +func (o *ormer) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { + kvs := make([]utils.KV, 0, 4) + for i, arg := range args { + switch i { + case 0: + if v, ok := arg.(bool); ok { + if v { + kvs = append(kvs, hints.DefaultRelDepth()) + } + } else if v, ok := arg.(int); ok { + kvs = append(kvs, hints.RelDepth(v)) + } + case 1: + kvs = append(kvs, hints.Limit(orm.ToInt64(arg))) + case 2: + kvs = append(kvs, hints.Offset(orm.ToInt64(arg))) + case 3: + kvs = append(kvs, hints.Offset(orm.ToInt64(arg))) + } + } + if o.isTx { + return o.txDelegate.LoadRelated(md, name, kvs...) + } + return o.delegate.LoadRelated(md, name, kvs...) +} + +// return a QuerySeter for table operations. +// table name can be string or struct. +// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), +func (o *ormer) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { + if o.isTx { + return o.txDelegate.QueryTable(ptrStructOrTableName) + } + return o.delegate.QueryTable(ptrStructOrTableName) +} + +// switch to another registered database driver by given name. +func (o *ormer) Using(name string) error { + if o.isTx { + return ErrTxHasBegan + } + o.delegate = orm.NewOrmUsingDB(name) + return nil +} + +// begin transaction +func (o *ormer) Begin() error { + if o.isTx { + return ErrTxHasBegan + } + return o.BeginTx(context.Background(), nil) +} + +func (o *ormer) BeginTx(ctx context.Context, opts *sql.TxOptions) error { + if o.isTx { + return ErrTxHasBegan + } + txOrmer, err := o.delegate.BeginWithCtxAndOpts(ctx, opts) + if err != nil { + return err + } + o.txDelegate = txOrmer + o.isTx = true + return nil +} + +// commit transaction +func (o *ormer) Commit() error { + if !o.isTx { + return ErrTxDone + } + err := o.txDelegate.Commit() + if err == nil { + o.isTx = false + o.txDelegate = nil + } else if err == sql.ErrTxDone { + return ErrTxDone + } + return err +} + +// rollback transaction +func (o *ormer) Rollback() error { + if !o.isTx { + return ErrTxDone + } + err := o.txDelegate.Rollback() + if err == nil { + o.isTx = false + o.txDelegate = nil + } else if err == sql.ErrTxDone { + return ErrTxDone + } + return err +} + +// return a raw query seter for raw sql string. +func (o *ormer) Raw(query string, args ...interface{}) RawSeter { + if o.isTx { + return o.txDelegate.Raw(query, args...) + } + return o.delegate.Raw(query, args...) +} + +// return current using database Driver +func (o *ormer) Driver() Driver { + if o.isTx { + return o.txDelegate.Driver() + } + return o.delegate.Driver() +} + +// return sql.DBStats for current database +func (o *ormer) DBStats() *sql.DBStats { + if o.isTx { + return o.txDelegate.DBStats() + } + return o.delegate.DBStats() +} + +// NewOrm create new orm +func NewOrm() Ormer { + o := orm.NewOrm() + return &ormer{ + delegate: o, + } +} + +// NewOrmWithDB create a new ormer object with specify *sql.DB for query +func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { + o, err := orm.NewOrmWithDB(driverName, aliasName, db) + if err != nil { + return nil, err + } + return &ormer{ + delegate: o, + }, nil +} diff --git a/pkg/adapter/orm/orm_conds.go b/pkg/adapter/orm/orm_conds.go new file mode 100644 index 00000000..986b4858 --- /dev/null +++ b/pkg/adapter/orm/orm_conds.go @@ -0,0 +1,83 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// ExprSep define the expression separation +const ( + ExprSep = "__" +) + +// Condition struct. +// work for WHERE conditions. +type Condition orm.Condition + +// NewCondition return new condition struct +func NewCondition() *Condition { + return (*Condition)(orm.NewCondition()) +} + +// Raw add raw sql to condition +func (c Condition) Raw(expr string, sql string) *Condition { + return (*Condition)((orm.Condition)(c).Raw(expr, sql)) +} + +// And add expression to condition +func (c Condition) And(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).And(expr, args...)) +} + +// AndNot add NOT expression to condition +func (c Condition) AndNot(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).AndNot(expr, args...)) +} + +// AndCond combine a condition to current condition +func (c *Condition) AndCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).AndCond((*orm.Condition)(cond))) +} + +// AndNotCond combine a AND NOT condition to current condition +func (c *Condition) AndNotCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).AndNotCond((*orm.Condition)(cond))) +} + +// Or add OR expression to condition +func (c Condition) Or(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).Or(expr, args...)) +} + +// OrNot add OR NOT expression to condition +func (c Condition) OrNot(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).OrNot(expr, args...)) +} + +// OrCond combine a OR condition to current condition +func (c *Condition) OrCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).OrCond((*orm.Condition)(cond))) +} + +// OrNotCond combine a OR NOT condition to current condition +func (c *Condition) OrNotCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).OrNotCond((*orm.Condition)(cond))) +} + +// IsEmpty check the condition arguments are empty or not. +func (c *Condition) IsEmpty() bool { + return (*orm.Condition)(c).IsEmpty() +} diff --git a/pkg/adapter/orm/orm_log.go b/pkg/adapter/orm/orm_log.go new file mode 100644 index 00000000..6b2b4a9b --- /dev/null +++ b/pkg/adapter/orm/orm_log.go @@ -0,0 +1,32 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "io" + + "github.com/astaxie/beego/pkg/client/orm" +) + +// Log implement the log.Logger +type Log orm.Log + +// costomer log func +var LogFunc = orm.LogFunc + +// NewLog set io.Writer to create a Logger. +func NewLog(out io.Writer) *Log { + return (*Log)(orm.NewLog(out)) +} diff --git a/pkg/adapter/orm/orm_queryset.go b/pkg/adapter/orm/orm_queryset.go new file mode 100644 index 00000000..5f211644 --- /dev/null +++ b/pkg/adapter/orm/orm_queryset.go @@ -0,0 +1,32 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// define Col operations +const ( + ColAdd = orm.ColAdd + ColMinus = orm.ColMinus + ColMultiply = orm.ColMultiply + ColExcept = orm.ColExcept + ColBitAnd = orm.ColBitAnd + ColBitRShift = orm.ColBitRShift + ColBitLShift = orm.ColBitLShift + ColBitXOR = orm.ColBitXOR + ColBitOr = orm.ColBitOr +) diff --git a/pkg/adapter/orm/qb.go b/pkg/adapter/orm/qb.go new file mode 100644 index 00000000..90b97797 --- /dev/null +++ b/pkg/adapter/orm/qb.go @@ -0,0 +1,27 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// QueryBuilder is the Query builder interface +type QueryBuilder orm.QueryBuilder + +// NewQueryBuilder return the QueryBuilder +func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { + return orm.NewQueryBuilder(driver) +} diff --git a/pkg/adapter/orm/qb_mysql.go b/pkg/adapter/orm/qb_mysql.go new file mode 100644 index 00000000..9566068f --- /dev/null +++ b/pkg/adapter/orm/qb_mysql.go @@ -0,0 +1,150 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// CommaSpace is the separation +const CommaSpace = orm.CommaSpace + +// MySQLQueryBuilder is the SQL build +type MySQLQueryBuilder orm.MySQLQueryBuilder + +// Select will join the fields +func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Select(fields...) +} + +// ForUpdate add the FOR UPDATE clause +func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).ForUpdate() +} + +// From join the tables +func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).From(tables...) +} + +// InnerJoin INNER JOIN the table +func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).InnerJoin(table) +} + +// LeftJoin LEFT JOIN the table +func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).LeftJoin(table) +} + +// RightJoin RIGHT JOIN the table +func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).RightJoin(table) +} + +// On join with on cond +func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).On(cond) +} + +// Where join the Where cond +func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Where(cond) +} + +// And join the and cond +func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).And(cond) +} + +// Or join the or cond +func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Or(cond) +} + +// In join the IN (vals) +func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).In(vals...) +} + +// OrderBy join the Order by fields +func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).OrderBy(fields...) +} + +// Asc join the asc +func (qb *MySQLQueryBuilder) Asc() QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Asc() +} + +// Desc join the desc +func (qb *MySQLQueryBuilder) Desc() QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Desc() +} + +// Limit join the limit num +func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Limit(limit) +} + +// Offset join the offset num +func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Offset(offset) +} + +// GroupBy join the Group by fields +func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).GroupBy(fields...) +} + +// Having join the Having cond +func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Having(cond) +} + +// Update join the update table +func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Update(tables...) +} + +// Set join the set kv +func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Set(kv...) +} + +// Delete join the Delete tables +func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Delete(tables...) +} + +// InsertInto join the insert SQL +func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).InsertInto(table, fields...) +} + +// Values join the Values(vals) +func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Values(vals...) +} + +// Subquery join the sub as alias +func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { + return (*orm.MySQLQueryBuilder)(qb).Subquery(sub, alias) +} + +// String join all Tokens +func (qb *MySQLQueryBuilder) String() string { + return (*orm.MySQLQueryBuilder)(qb).String() +} diff --git a/pkg/adapter/orm/qb_tidb.go b/pkg/adapter/orm/qb_tidb.go new file mode 100644 index 00000000..05c91a26 --- /dev/null +++ b/pkg/adapter/orm/qb_tidb.go @@ -0,0 +1,147 @@ +// Copyright 2015 TiDB Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// TiDBQueryBuilder is the SQL build +type TiDBQueryBuilder orm.TiDBQueryBuilder + +// Select will join the fields +func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Select(fields...) +} + +// ForUpdate add the FOR UPDATE clause +func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).ForUpdate() +} + +// From join the tables +func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).From(tables...) +} + +// InnerJoin INNER JOIN the table +func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).InnerJoin(table) +} + +// LeftJoin LEFT JOIN the table +func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).LeftJoin(table) +} + +// RightJoin RIGHT JOIN the table +func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).RightJoin(table) +} + +// On join with on cond +func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).On(cond) +} + +// Where join the Where cond +func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Where(cond) +} + +// And join the and cond +func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).And(cond) +} + +// Or join the or cond +func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Or(cond) +} + +// In join the IN (vals) +func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).In(vals...) +} + +// OrderBy join the Order by fields +func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).OrderBy(fields...) +} + +// Asc join the asc +func (qb *TiDBQueryBuilder) Asc() QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Asc() +} + +// Desc join the desc +func (qb *TiDBQueryBuilder) Desc() QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Desc() +} + +// Limit join the limit num +func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Limit(limit) +} + +// Offset join the offset num +func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Offset(offset) +} + +// GroupBy join the Group by fields +func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).GroupBy(fields...) +} + +// Having join the Having cond +func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Having(cond) +} + +// Update join the update table +func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Update(tables...) +} + +// Set join the set kv +func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Set(kv...) +} + +// Delete join the Delete tables +func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Delete(tables...) +} + +// InsertInto join the insert SQL +func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).InsertInto(table, fields...) +} + +// Values join the Values(vals) +func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Values(vals...) +} + +// Subquery join the sub as alias +func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string { + return (*orm.TiDBQueryBuilder)(qb).Subquery(sub, alias) +} + +// String join all Tokens +func (qb *TiDBQueryBuilder) String() string { + return (*orm.TiDBQueryBuilder)(qb).String() +} diff --git a/pkg/adapter/orm/query_setter_adapter.go b/pkg/adapter/orm/query_setter_adapter.go new file mode 100644 index 00000000..cc24ef6b --- /dev/null +++ b/pkg/adapter/orm/query_setter_adapter.go @@ -0,0 +1,34 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +type baseQuerySetter struct { +} + +func (b *baseQuerySetter) ForceIndex(indexes ...string) orm.QuerySeter { + panic("you should not invoke this method.") +} + +func (b *baseQuerySetter) UseIndex(indexes ...string) orm.QuerySeter { + panic("you should not invoke this method.") +} + +func (b *baseQuerySetter) IgnoreIndex(indexes ...string) orm.QuerySeter { + panic("you should not invoke this method.") +} diff --git a/pkg/adapter/orm/types.go b/pkg/adapter/orm/types.go new file mode 100644 index 00000000..3372e301 --- /dev/null +++ b/pkg/adapter/orm/types.go @@ -0,0 +1,150 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + + "github.com/astaxie/beego/pkg/client/orm" +) + +// Params stores the Params +type Params orm.Params + +// ParamsList stores paramslist +type ParamsList orm.ParamsList + +// Driver define database driver +type Driver orm.Driver + +// Fielder define field info +type Fielder orm.Fielder + +// Ormer define the orm interface +type Ormer interface { + // read data to model + // for example: + // this will find User by Id field + // u = &User{Id: user.Id} + // err = Ormer.Read(u) + // this will find User by UserName field + // u = &User{UserName: "astaxie", Password: "pass"} + // err = Ormer.Read(u, "UserName") + Read(md interface{}, cols ...string) error + // Like Read(), but with "FOR UPDATE" clause, useful in transaction. + // Some databases are not support this feature. + ReadForUpdate(md interface{}, cols ...string) error + // Try to read a row from the database, or insert one if it doesn't exist + ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) + // insert model data to database + // for example: + // user := new(User) + // id, err = Ormer.Insert(user) + // user must be a pointer and Insert will set user's pk field + Insert(interface{}) (int64, error) + // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") + // if colu type is integer : can use(+-*/), string : convert(colu,"value") + // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") + // if colu type is integer : can use(+-*/), string : colu || "value" + InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) + // insert some models to database + InsertMulti(bulk int, mds interface{}) (int64, error) + // update model to database. + // cols set the columns those want to update. + // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns + // for example: + // user := User{Id: 2} + // user.Langs = append(user.Langs, "zh-CN", "en-US") + // user.Extra.Name = "beego" + // user.Extra.Data = "orm" + // num, err = Ormer.Update(&user, "Langs", "Extra") + Update(md interface{}, cols ...string) (int64, error) + // delete model in database + Delete(md interface{}, cols ...string) (int64, error) + // load related models to md model. + // args are limit, offset int and order string. + // + // example: + // Ormer.LoadRelated(post,"Tags") + // for _,tag := range post.Tags{...} + // args[0] bool true useDefaultRelsDepth ; false depth 0 + // args[0] int loadRelationDepth + // args[1] int limit default limit 1000 + // args[2] int offset default offset 0 + // args[3] string order for example : "-Id" + // make sure the relation is defined in model struct tags. + LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) + // create a models to models queryer + // for example: + // post := Post{Id: 4} + // m2m := Ormer.QueryM2M(&post, "Tags") + QueryM2M(md interface{}, name string) QueryM2Mer + // return a QuerySeter for table operations. + // table name can be string or struct. + // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), + QueryTable(ptrStructOrTableName interface{}) QuerySeter + // switch to another registered database driver by given name. + Using(name string) error + // begin transaction + // for example: + // o := NewOrm() + // err := o.Begin() + // ... + // err = o.Rollback() + Begin() error + // begin transaction with provided context and option + // the provided context is used until the transaction is committed or rolled back. + // if the context is canceled, the transaction will be rolled back. + // the provided TxOptions is optional and may be nil if defaults should be used. + // if a non-default isolation level is used that the driver doesn't support, an error will be returned. + // for example: + // o := NewOrm() + // err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + // ... + // err = o.Rollback() + BeginTx(ctx context.Context, opts *sql.TxOptions) error + // commit transaction + Commit() error + // rollback transaction + Rollback() error + // return a raw query seter for raw sql string. + // for example: + // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() + // // update user testing's name to slene + Raw(query string, args ...interface{}) RawSeter + Driver() Driver + DBStats() *sql.DBStats +} + +// Inserter insert prepared statement +type Inserter orm.Inserter + +// QuerySeter query seter +type QuerySeter orm.QuerySeter + +// QueryM2Mer model to model query struct +// all operations are on the m2m table only, will not affect the origin model table +type QueryM2Mer orm.QueryM2Mer + +// RawPreparer raw query statement +type RawPreparer orm.RawPreparer + +// RawSeter raw query seter +// create From Ormer.Raw +// for example: +// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) +// rs := Ormer.Raw(sql, 1) +type RawSeter orm.RawSeter diff --git a/pkg/adapter/orm/utils.go b/pkg/adapter/orm/utils.go new file mode 100644 index 00000000..16d0e4e5 --- /dev/null +++ b/pkg/adapter/orm/utils.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "time" + + "github.com/astaxie/beego/pkg/client/orm" +) + +type fn func(string) string + +var ( + nameStrategyMap = map[string]fn{ + defaultNameStrategy: snakeString, + SnakeAcronymNameStrategy: snakeStringWithAcronym, + } + defaultNameStrategy = "snakeString" + SnakeAcronymNameStrategy = "snakeStringWithAcronym" + nameStrategy = defaultNameStrategy +) + +// StrTo is the target string +type StrTo orm.StrTo + +// Set string +func (f *StrTo) Set(v string) { + (*orm.StrTo)(f).Set(v) +} + +// Clear string +func (f *StrTo) Clear() { + (*orm.StrTo)(f).Clear() +} + +// Exist check string exist +func (f StrTo) Exist() bool { + return orm.StrTo(f).Exist() +} + +// Bool string to bool +func (f StrTo) Bool() (bool, error) { + return orm.StrTo(f).Bool() +} + +// Float32 string to float32 +func (f StrTo) Float32() (float32, error) { + return orm.StrTo(f).Float32() +} + +// Float64 string to float64 +func (f StrTo) Float64() (float64, error) { + return orm.StrTo(f).Float64() +} + +// Int string to int +func (f StrTo) Int() (int, error) { + return orm.StrTo(f).Int() +} + +// Int8 string to int8 +func (f StrTo) Int8() (int8, error) { + return orm.StrTo(f).Int8() +} + +// Int16 string to int16 +func (f StrTo) Int16() (int16, error) { + return orm.StrTo(f).Int16() +} + +// Int32 string to int32 +func (f StrTo) Int32() (int32, error) { + return orm.StrTo(f).Int32() +} + +// Int64 string to int64 +func (f StrTo) Int64() (int64, error) { + return orm.StrTo(f).Int64() +} + +// Uint string to uint +func (f StrTo) Uint() (uint, error) { + return orm.StrTo(f).Uint() +} + +// Uint8 string to uint8 +func (f StrTo) Uint8() (uint8, error) { + return orm.StrTo(f).Uint8() +} + +// Uint16 string to uint16 +func (f StrTo) Uint16() (uint16, error) { + return orm.StrTo(f).Uint16() +} + +// Uint32 string to uint32 +func (f StrTo) Uint32() (uint32, error) { + return orm.StrTo(f).Uint32() +} + +// Uint64 string to uint64 +func (f StrTo) Uint64() (uint64, error) { + return orm.StrTo(f).Uint64() +} + +// String string to string +func (f StrTo) String() string { + return orm.StrTo(f).String() +} + +// ToStr interface to string +func ToStr(value interface{}, args ...int) (s string) { + switch v := value.(type) { + case bool: + s = strconv.FormatBool(v) + case float32: + s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) + case float64: + s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) + case int: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int8: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int16: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int32: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int64: + s = strconv.FormatInt(v, argInt(args).Get(0, 10)) + case uint: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint8: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint16: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint32: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint64: + s = strconv.FormatUint(v, argInt(args).Get(0, 10)) + case string: + s = v + case []byte: + s = string(v) + default: + s = fmt.Sprintf("%v", v) + } + return s +} + +// ToInt64 interface to int64 +func ToInt64(value interface{}) (d int64) { + val := reflect.ValueOf(value) + switch value.(type) { + case int, int8, int16, int32, int64: + d = val.Int() + case uint, uint8, uint16, uint32, uint64: + d = int64(val.Uint()) + default: + panic(fmt.Errorf("ToInt64 need numeric not `%T`", value)) + } + return +} + +func snakeStringWithAcronym(s string) string { + data := make([]byte, 0, len(s)*2) + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + before := false + after := false + if i > 0 { + before = s[i-1] >= 'a' && s[i-1] <= 'z' + } + if i+1 < num { + after = s[i+1] >= 'a' && s[i+1] <= 'z' + } + if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { + data = append(data, '_') + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + +// snake string, XxYy to xx_yy , XxYY to xx_y_y +func snakeString(s string) string { + data := make([]byte, 0, len(s)*2) + j := false + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + if i > 0 && d >= 'A' && d <= 'Z' && j { + data = append(data, '_') + } + if d != '_' { + j = true + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + +// SetNameStrategy set different name strategy +func SetNameStrategy(s string) { + if SnakeAcronymNameStrategy != s { + nameStrategy = defaultNameStrategy + } + nameStrategy = s +} + +// camel string, xx_yy to XxYy +func camelString(s string) string { + data := make([]byte, 0, len(s)) + flag, num := true, len(s)-1 + for i := 0; i <= num; i++ { + d := s[i] + if d == '_' { + flag = true + continue + } else if flag { + if d >= 'a' && d <= 'z' { + d = d - 32 + } + flag = false + } + data = append(data, d) + } + return string(data[:]) +} + +type argString []string + +// get string by index from string slice +func (a argString) Get(i int, args ...string) (r string) { + if i >= 0 && i < len(a) { + r = a[i] + } else if len(args) > 0 { + r = args[0] + } + return +} + +type argInt []int + +// get int by index from int slice +func (a argInt) Get(i int, args ...int) (r int) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +// parse time to string with location +func timeParse(dateString, format string) (time.Time, error) { + tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) + return tp, err +} + +// get pointer indirect type +func indirectType(v reflect.Type) reflect.Type { + switch v.Kind() { + case reflect.Ptr: + return indirectType(v.Elem()) + default: + return v + } +} diff --git a/pkg/adapter/orm/utils_test.go b/pkg/adapter/orm/utils_test.go new file mode 100644 index 00000000..7d94cada --- /dev/null +++ b/pkg/adapter/orm/utils_test.go @@ -0,0 +1,70 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" +) + +func TestCamelString(t *testing.T) { + snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} + + answer := make(map[string]string) + for i, v := range snake { + answer[v] = camel[i] + } + + for _, v := range snake { + res := camelString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeString(t *testing.T) { + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeStringWithAcronym(t *testing.T) { + camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeStringWithAcronym(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} diff --git a/pkg/client/orm/db_alias.go b/pkg/client/orm/db_alias.go index 8a5cfb10..c72f29c4 100644 --- a/pkg/client/orm/db_alias.go +++ b/pkg/client/orm/db_alias.go @@ -400,22 +400,47 @@ func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...utils.KV detectTZ(al) kvs.IfContains(hints.KeyMaxIdleConnections, func(value interface{}) { - if m, ok := value.(int); ok { - SetMaxIdleConns(al, m) - } + al.SetMaxIdleConns(value.(int)) }).IfContains(hints.KeyMaxOpenConnections, func(value interface{}) { - if m, ok := value.(int); ok { - SetMaxOpenConns(al, m) - } + al.SetMaxOpenConns(value.(int)) }).IfContains(hints.KeyConnMaxLifetime, func(value interface{}) { - if m, ok := value.(time.Duration); ok { - SetConnMaxLifetime(al, m) - } + al.SetConnMaxLifetime(value.(time.Duration)) }) return al, nil } +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +// Deprecated you should not use this, we will remove it in the future +func SetMaxIdleConns(aliasName string, maxIdleConns int) { + al := getDbAlias(aliasName) + al.SetMaxIdleConns(maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +// Deprecated you should not use this, we will remove it in the future +func SetMaxOpenConns(aliasName string, maxOpenConns int) { + al := getDbAlias(aliasName) + al.SetMaxIdleConns(maxOpenConns) +} + +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +func (al *alias) SetMaxIdleConns(maxIdleConns int) { + al.MaxIdleConns = maxIdleConns + al.DB.DB.SetMaxIdleConns(maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +func (al *alias) SetMaxOpenConns(maxOpenConns int) { + al.MaxOpenConns = maxOpenConns + al.DB.DB.SetMaxOpenConns(maxOpenConns) +} + +func (al *alias) SetConnMaxLifetime(lifeTime time.Duration) { + al.ConnMaxLifetime = lifeTime + al.DB.DB.SetConnMaxLifetime(lifeTime) +} + // AddAliasWthDB add a aliasName for the drivename func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...utils.KV) error { _, err := addAliasWthDB(aliasName, driverName, db, params...) @@ -476,23 +501,6 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { return nil } -// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name -func SetMaxIdleConns(al *alias, maxIdleConns int) { - al.MaxIdleConns = maxIdleConns - al.DB.DB.SetMaxIdleConns(maxIdleConns) -} - -// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name -func SetMaxOpenConns(al *alias, maxOpenConns int) { - al.MaxOpenConns = maxOpenConns - al.DB.DB.SetMaxOpenConns(maxOpenConns) -} - -func SetConnMaxLifetime(al *alias, lifeTime time.Duration) { - al.ConnMaxLifetime = lifeTime - al.DB.DB.SetConnMaxLifetime(lifeTime) -} - // GetDB Get *sql.DB from registered database by db alias name. // Use "default" as alias name if you not set. func GetDB(aliasNames ...string) (*sql.DB, error) { diff --git a/pkg/client/orm/orm.go b/pkg/client/orm/orm.go index 634b1892..95bbcb31 100644 --- a/pkg/client/orm/orm.go +++ b/pkg/client/orm/orm.go @@ -311,9 +311,7 @@ func (o *ormBase) LoadRelated(md interface{}, name string, args ...utils.KV) (in return o.LoadRelatedWithCtx(context.Background(), md, name, args...) } func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { - _, fi, ind, qseter := o.queryRelated(md, name) - - qs := qseter.(*querySet) + _, fi, ind, qs := o.queryRelated(md, name) var relDepth int var limit, offset int64 @@ -377,7 +375,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s } // get QuerySeter for related models to md model -func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { +func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, *querySet) { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) From 3acda41bc7be4494c9d925b06ab954683bbc01a1 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 6 Sep 2020 15:21:07 +0800 Subject: [PATCH 151/207] Fix UT --- pkg/adapter/app.go | 3 +- pkg/adapter/cache/cache_test.go | 191 ------------------- pkg/adapter/flash.go | 2 +- pkg/adapter/metric/prometheus_test.go | 2 +- pkg/adapter/plugins/cors/cors_test.go | 253 -------------------------- pkg/adapter/router.go | 3 + pkg/adapter/utils/file_test.go | 75 -------- pkg/server/web/app.go | 8 +- pkg/server/web/filter.go | 11 +- pkg/server/web/namespace.go | 2 +- pkg/server/web/router_test.go | 28 +-- pkg/task/task_test.go | 14 +- 12 files changed, 41 insertions(+), 551 deletions(-) delete mode 100644 pkg/adapter/cache/cache_test.go delete mode 100644 pkg/adapter/plugins/cors/cors_test.go delete mode 100644 pkg/adapter/utils/file_test.go diff --git a/pkg/adapter/app.go b/pkg/adapter/app.go index 64280a7b..c1046c79 100644 --- a/pkg/adapter/app.go +++ b/pkg/adapter/app.go @@ -255,7 +255,8 @@ func Handler(rootpath string, h http.Handler, options ...interface{}) *App { // beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { + opts := oldToNewFilterOpts(params) return (*App)(web.InsertFilter(pattern, pos, func(ctx *context.Context) { filter((*context2.Context)(ctx)) - }, params...)) + }, opts...)) } diff --git a/pkg/adapter/cache/cache_test.go b/pkg/adapter/cache/cache_test.go deleted file mode 100644 index 470c0a43..00000000 --- a/pkg/adapter/cache/cache_test.go +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "os" - "sync" - "testing" - "time" -) - -func TestCacheIncr(t *testing.T) { - bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - //timeoutDuration := 10 * time.Second - - bm.Put("edwardhey", 0, time.Second*20) - wg := sync.WaitGroup{} - wg.Add(10) - for i := 0; i < 10; i++ { - go func() { - defer wg.Done() - bm.Incr("edwardhey") - }() - } - wg.Wait() - if bm.Get("edwardhey").(int) != 10 { - t.Error("Incr err") - } -} - -func TestCache(t *testing.T) { - bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - - time.Sleep(30 * time.Second) - - if bm.IsExist("astaxie") { - t.Error("check err") - } - - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test GetMulti - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } - - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } -} - -func TestFileCache(t *testing.T) { - bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } - - //test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } - - os.RemoveAll("cache") -} diff --git a/pkg/adapter/flash.go b/pkg/adapter/flash.go index e5e1c187..02e75ed6 100644 --- a/pkg/adapter/flash.go +++ b/pkg/adapter/flash.go @@ -28,7 +28,7 @@ func NewFlash() *FlashData { // Set message to flash func (fd *FlashData) Set(key string, msg string, args ...interface{}) { - (*web.FlashData)(fd).Set(key, msg, args) + (*web.FlashData)(fd).Set(key, msg, args...) } // Success writes success message to flash. diff --git a/pkg/adapter/metric/prometheus_test.go b/pkg/adapter/metric/prometheus_test.go index d82a6dec..87286e02 100644 --- a/pkg/adapter/metric/prometheus_test.go +++ b/pkg/adapter/metric/prometheus_test.go @@ -22,7 +22,7 @@ import ( "github.com/prometheus/client_golang/prometheus" - "github.com/astaxie/beego/context" + "github.com/astaxie/beego/pkg/adapter/context" ) func TestPrometheusMiddleWare(t *testing.T) { diff --git a/pkg/adapter/plugins/cors/cors_test.go b/pkg/adapter/plugins/cors/cors_test.go deleted file mode 100644 index 34039143..00000000 --- a/pkg/adapter/plugins/cors/cors_test.go +++ /dev/null @@ -1,253 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cors - -import ( - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" -) - -// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header -type HTTPHeaderGuardRecorder struct { - *httptest.ResponseRecorder - savedHeaderMap http.Header -} - -// NewRecorder return HttpHeaderGuardRecorder -func NewRecorder() *HTTPHeaderGuardRecorder { - return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil} -} - -func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) { - gr.ResponseRecorder.WriteHeader(code) - gr.savedHeaderMap = gr.ResponseRecorder.Header() -} - -func (gr *HTTPHeaderGuardRecorder) Header() http.Header { - if gr.savedHeaderMap != nil { - // headers were written. clone so we don't get updates - clone := make(http.Header) - for k, v := range gr.savedHeaderMap { - clone[k] = v - } - return clone - } - return gr.ResponseRecorder.Header() -} - -func Test_AllowAll(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - if recorder.HeaderMap.Get(headerAllowOrigin) != "*" { - t.Errorf("Allow-Origin header should be *") - } -} - -func Test_AllowRegexMatch(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"}, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - origin := "https://bar.foo.com" - r, _ := http.NewRequest("PUT", "/foo", nil) - r.Header.Add("Origin", origin) - handler.ServeHTTP(recorder, r) - - headerValue := recorder.HeaderMap.Get(headerAllowOrigin) - if headerValue != origin { - t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue) - } -} - -func Test_AllowRegexNoMatch(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowOrigins: []string{"https://*.foo.com"}, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - origin := "https://ww.foo.com.evil.com" - r, _ := http.NewRequest("PUT", "/foo", nil) - r.Header.Add("Origin", origin) - handler.ServeHTTP(recorder, r) - - headerValue := recorder.HeaderMap.Get(headerAllowOrigin) - if headerValue != "" { - t.Errorf("Allow-Origin header should not exist, found %v", headerValue) - } -} - -func Test_OtherHeaders(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowCredentials: true, - AllowMethods: []string{"PATCH", "GET"}, - AllowHeaders: []string{"Origin", "X-whatever"}, - ExposeHeaders: []string{"Content-Length", "Hello"}, - MaxAge: 5 * time.Minute, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials) - methodsVal := recorder.HeaderMap.Get(headerAllowMethods) - headersVal := recorder.HeaderMap.Get(headerAllowHeaders) - exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders) - maxAgeVal := recorder.HeaderMap.Get(headerMaxAge) - - if credentialsVal != "true" { - t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal) - } - - if methodsVal != "PATCH,GET" { - t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal) - } - - if headersVal != "Origin,X-whatever" { - t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal) - } - - if exposedHeadersVal != "Content-Length,Hello" { - t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal) - } - - if maxAgeVal != "300" { - t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal) - } -} - -func Test_DefaultAllowHeaders(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - headersVal := recorder.HeaderMap.Get(headerAllowHeaders) - if headersVal != "Origin,Accept,Content-Type,Authorization" { - t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal) - } -} - -func Test_Preflight(t *testing.T) { - recorder := NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowMethods: []string{"PUT", "PATCH"}, - AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"}, - })) - - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(200) - }) - - r, _ := http.NewRequest("OPTIONS", "/foo", nil) - r.Header.Add(headerRequestMethod, "PUT") - r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive") - handler.ServeHTTP(recorder, r) - - headers := recorder.Header() - methodsVal := headers.Get(headerAllowMethods) - headersVal := headers.Get(headerAllowHeaders) - originVal := headers.Get(headerAllowOrigin) - - if methodsVal != "PUT,PATCH" { - t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal) - } - - if !strings.Contains(headersVal, "X-whatever") { - t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal) - } - - if !strings.Contains(headersVal, "x-casesensitive") { - t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal) - } - - if originVal != "*" { - t.Errorf("Allow-Origin is expected to be *, found %v", originVal) - } - - if recorder.Code != http.StatusOK { - t.Errorf("Status code is expected to be 200, found %d", recorder.Code) - } -} - -func Benchmark_WithoutCORS(b *testing.B) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - beego.BConfig.RunMode = beego.PROD - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - b.ResetTimer() - r, _ := http.NewRequest("PUT", "/foo", nil) - for i := 0; i < b.N; i++ { - handler.ServeHTTP(recorder, r) - } -} - -func Benchmark_WithCORS(b *testing.B) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - beego.BConfig.RunMode = beego.PROD - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowCredentials: true, - AllowMethods: []string{"PATCH", "GET"}, - AllowHeaders: []string{"Origin", "X-whatever"}, - MaxAge: 5 * time.Minute, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - b.ResetTimer() - r, _ := http.NewRequest("PUT", "/foo", nil) - for i := 0; i < b.N; i++ { - handler.ServeHTTP(recorder, r) - } -} diff --git a/pkg/adapter/router.go b/pkg/adapter/router.go index 5a36fbee..8e8d9fdb 100644 --- a/pkg/adapter/router.go +++ b/pkg/adapter/router.go @@ -249,6 +249,9 @@ func oldToNewFilterOpts(params []bool) []web.FilterOpt { opts := make([]web.FilterOpt, 0, 4) if len(params) > 0 { opts = append(opts, web.WithReturnOnOutput(params[0])) + } else { + // the default value should be true + opts = append(opts, web.WithReturnOnOutput(true)) } if len(params) > 1 { opts = append(opts, web.WithResetParams(params[1])) diff --git a/pkg/adapter/utils/file_test.go b/pkg/adapter/utils/file_test.go deleted file mode 100644 index b2644157..00000000 --- a/pkg/adapter/utils/file_test.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "path/filepath" - "reflect" - "testing" -) - -var noExistedFile = "/tmp/not_existed_file" - -func TestSelfPath(t *testing.T) { - path := SelfPath() - if path == "" { - t.Error("path cannot be empty") - } - t.Logf("SelfPath: %s", path) -} - -func TestSelfDir(t *testing.T) { - dir := SelfDir() - t.Logf("SelfDir: %s", dir) -} - -func TestFileExists(t *testing.T) { - if !FileExists("./file.go") { - t.Errorf("./file.go should exists, but it didn't") - } - - if FileExists(noExistedFile) { - t.Errorf("Weird, how could this file exists: %s", noExistedFile) - } -} - -func TestSearchFile(t *testing.T) { - path, err := SearchFile(filepath.Base(SelfPath()), SelfDir()) - if err != nil { - t.Error(err) - } - t.Log(path) - - _, err = SearchFile(noExistedFile, ".") - if err == nil { - t.Errorf("err shouldnt be nil, got path: %s", SelfDir()) - } -} - -func TestGrepFile(t *testing.T) { - _, err := GrepFile("", noExistedFile) - if err == nil { - t.Error("expect file-not-existed error, but got nothing") - } - - path := filepath.Join(".", "testdata", "grepe.test") - lines, err := GrepFile(`^\s*[^#]+`, path) - if err != nil { - t.Error(err) - } - if !reflect.DeepEqual(lines, []string{"hello", "world"}) { - t.Errorf("expect [hello world], but receive %v", lines) - } -} diff --git a/pkg/server/web/app.go b/pkg/server/web/app.go index ad3ff663..7511c7fe 100644 --- a/pkg/server/web/app.go +++ b/pkg/server/web/app.go @@ -492,15 +492,15 @@ func Handler(rootpath string, h http.Handler, options ...interface{}) *App { // The pos means action constant including // beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) -func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { - BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) +func InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) *App { + BeeApp.Handlers.InsertFilter(pattern, pos, filter, opts...) return BeeApp } // InsertFilterChain adds a FilterFunc built by filterChain. // This filter will be executed before all filters. // the filter's behavior is like stack -func InsertFilterChain(pattern string, filterChain FilterChain, params ...bool) *App { - BeeApp.Handlers.InsertFilterChain(pattern, filterChain, params...) +func InsertFilterChain(pattern string, filterChain FilterChain, opts ...FilterOpt) *App { + BeeApp.Handlers.InsertFilterChain(pattern, filterChain, opts...) return BeeApp } diff --git a/pkg/server/web/filter.go b/pkg/server/web/filter.go index e10faafc..9aab48d6 100644 --- a/pkg/server/web/filter.go +++ b/pkg/server/web/filter.go @@ -45,13 +45,14 @@ type FilterRouter struct { // 2. determining whether or not params need to be reset. func newFilterRouter(pattern string, filter FilterFunc, opts ...FilterOpt) *FilterRouter { mr := &FilterRouter{ - tree: NewTree(), - pattern: pattern, - filterFunc: filter, - returnOnOutput: true, + tree: NewTree(), + pattern: pattern, + filterFunc: filter, } - fos := &filterOpts{} + fos := &filterOpts{ + returnOnOutput: true, + } for _, o := range opts { o(fos) diff --git a/pkg/server/web/namespace.go b/pkg/server/web/namespace.go index e59f38c5..a792aa60 100644 --- a/pkg/server/web/namespace.go +++ b/pkg/server/web/namespace.go @@ -91,7 +91,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { a = FinishRouter } for _, f := range filter { - n.handlers.InsertFilter("*", a, f) + n.handlers.InsertFilter("*", a, f, WithReturnOnOutput(true)) } return n } diff --git a/pkg/server/web/router_test.go b/pkg/server/web/router_test.go index 14ad1484..33b75703 100644 --- a/pkg/server/web/router_test.go +++ b/pkg/server/web/router_test.go @@ -423,7 +423,7 @@ func TestInsertFilter(t *testing.T) { testName := "TestInsertFilter" mux := NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(true)) if !mux.filters[BeforeRouter][0].returnOnOutput { t.Errorf( "%s: passing no variadic params should set returnOnOutput to true", @@ -436,7 +436,7 @@ func TestInsertFilter(t *testing.T) { } mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, false) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(false)) if mux.filters[BeforeRouter][0].returnOnOutput { t.Errorf( "%s: passing false as 1st variadic param should set returnOnOutput to false", @@ -444,7 +444,7 @@ func TestInsertFilter(t *testing.T) { } mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, true, true) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(true), WithResetParams(true)) if !mux.filters[BeforeRouter][0].resetParams { t.Errorf( "%s: passing true as 2nd variadic param should set resetParams to true", @@ -461,7 +461,7 @@ func TestParamResetFilter(t *testing.T) { mux := NewControllerRegister() - mux.InsertFilter("*", BeforeExec, beegoResetParams, true, true) + mux.InsertFilter("*", BeforeExec, beegoResetParams, WithReturnOnOutput(true), WithResetParams(true)) mux.Get(route, beegoHandleResetParams) @@ -514,8 +514,8 @@ func TestFilterBeforeExec(t *testing.T) { url := "/beforeExec" mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoBeforeExec1) + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, BeforeExec, beegoBeforeExec1, WithReturnOnOutput(true)) mux.Get(url, beegoFilterFunc) @@ -542,7 +542,7 @@ func TestFilterAfterExec(t *testing.T) { mux := NewControllerRegister() mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoAfterExec1, false) + mux.InsertFilter(url, AfterExec, beegoAfterExec1, WithReturnOnOutput(false)) mux.Get(url, beegoFilterFunc) @@ -570,10 +570,10 @@ func TestFilterFinishRouter(t *testing.T) { url := "/finishRouter" mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoFilterNoOutput) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1) + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, AfterExec, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(true)) mux.Get(url, beegoFilterFunc) @@ -604,7 +604,7 @@ func TestFilterFinishRouterMultiFirstOnly(t *testing.T) { url := "/finishRouterMultiFirstOnly" mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(false)) mux.InsertFilter(url, FinishRouter, beegoFinishRouter2) mux.Get(url, beegoFilterFunc) @@ -631,8 +631,8 @@ func TestFilterFinishRouterMulti(t *testing.T) { url := "/finishRouterMulti" mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(false)) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, WithReturnOnOutput(false)) mux.Get(url, beegoFilterFunc) diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go index 9f73ce46..488729dc 100644 --- a/pkg/task/task_test.go +++ b/pkg/task/task_test.go @@ -15,6 +15,7 @@ package task import ( + "context" "errors" "fmt" "sync" @@ -25,7 +26,10 @@ import ( ) func TestParse(t *testing.T) { - tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) + tk := NewTask("taska", "0/30 * * * * *", func(ctx context.Context) error { + fmt.Println("hello world") + return nil + }) err := tk.Run(nil) if err != nil { t.Fatal(err) @@ -39,9 +43,9 @@ func TestParse(t *testing.T) { func TestSpec(t *testing.T) { wg := &sync.WaitGroup{} wg.Add(2) - tk1 := NewTask("tk1", "0 12 * * * *", func() error { fmt.Println("tk1"); return nil }) - tk2 := NewTask("tk2", "0,10,20 * * * * *", func() error { fmt.Println("tk2"); wg.Done(); return nil }) - tk3 := NewTask("tk3", "0 10 * * * *", func() error { fmt.Println("tk3"); wg.Done(); return nil }) + tk1 := NewTask("tk1", "0 12 * * * *", func(ctx context.Context) error { fmt.Println("tk1"); return nil }) + tk2 := NewTask("tk2", "0,10,20 * * * * *", func(ctx context.Context) error { fmt.Println("tk2"); wg.Done(); return nil }) + tk3 := NewTask("tk3", "0 10 * * * *", func(ctx context.Context) error { fmt.Println("tk3"); wg.Done(); return nil }) AddTask("tk1", tk1) AddTask("tk2", tk2) @@ -58,7 +62,7 @@ func TestSpec(t *testing.T) { func TestTask_Run(t *testing.T) { cnt := -1 - task := func() error { + task := func(ctx context.Context) error { cnt++ fmt.Printf("Hello, world! %d \n", cnt) return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) From 6bf01eaeca8b0e8ef4cb9c35e5159a8ae55e9401 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 7 Sep 2020 20:36:54 +0800 Subject: [PATCH 152/207] Move pr 3784 here --- pkg/client/orm/orm_raw.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pkg/client/orm/orm_raw.go b/pkg/client/orm/orm_raw.go index c2539147..e11e97fa 100644 --- a/pkg/client/orm/orm_raw.go +++ b/pkg/client/orm/orm_raw.go @@ -330,6 +330,8 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { return err } + structTagMap := make(map[reflect.StructTag]map[string]string) + defer rows.Close() if rows.Next() { @@ -396,7 +398,12 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { recursiveSetField(f) } - _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + // thanks @Gazeboxu. + tags := structTagMap[fe.Tag] + if tags == nil { + _, tags = parseStructTag(fe.Tag.Get(defaultStructTagName)) + structTagMap[fe.Tag] = tags + } var col string if col = tags["column"]; col == "" { col = nameStrategyMap[nameStrategy](fe.Name) From 0f50b07a20b25be899e573dbb6b50d464f9001d1 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 7 Sep 2020 21:40:20 +0800 Subject: [PATCH 153/207] allow users to ignore some table when run orm commands --- pkg/client/orm/cmd.go | 6 +++++ pkg/client/orm/models.go | 2 +- pkg/client/orm/models_utils.go | 12 ++++++++++ pkg/client/orm/models_utils_test.go | 35 +++++++++++++++++++++++++++++ pkg/client/orm/types.go | 5 +++++ 5 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 pkg/client/orm/models_utils_test.go diff --git a/pkg/client/orm/cmd.go b/pkg/client/orm/cmd.go index e03fc0ee..b0661971 100644 --- a/pkg/client/orm/cmd.go +++ b/pkg/client/orm/cmd.go @@ -142,6 +142,12 @@ func (d *commandSyncDb) Run() error { } for i, mi := range modelCache.allOrdered() { + + if !isApplicableTableForDB(mi.addrField, d.al.Name) { + fmt.Printf("table `%s` is not applicable to database '%s'\n", mi.table, d.al.Name) + continue + } + if tables[mi.table] { if !d.noInfo { fmt.Printf("table `%s` already exists, skip\n", mi.table) diff --git a/pkg/client/orm/models.go b/pkg/client/orm/models.go index a7de10f7..24f564ab 100644 --- a/pkg/client/orm/models.go +++ b/pkg/client/orm/models.go @@ -414,7 +414,7 @@ func (mc *_modelCache) getDbDropSQL(al *alias) (queries []string, err error) { for _, mi := range modelCache.allOrdered() { queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) } - return queries,nil + return queries, nil } //getDbCreateSQL get database scheme creation sql queries diff --git a/pkg/client/orm/models_utils.go b/pkg/client/orm/models_utils.go index 6fca59a9..950ca243 100644 --- a/pkg/client/orm/models_utils.go +++ b/pkg/client/orm/models_utils.go @@ -107,6 +107,18 @@ func getTableUnique(val reflect.Value) [][]string { return nil } +// get whether the table needs to be created for the database alias +func isApplicableTableForDB(val reflect.Value, db string) bool { + fun := val.MethodByName("IsApplicableTableForDB") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{reflect.ValueOf(db)}) + if len(vals) > 0 && vals[0].Kind() == reflect.Bool { + return vals[0].Bool() + } + } + return true +} + // get snaked column name func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { column := col diff --git a/pkg/client/orm/models_utils_test.go b/pkg/client/orm/models_utils_test.go new file mode 100644 index 00000000..0a6995b3 --- /dev/null +++ b/pkg/client/orm/models_utils_test.go @@ -0,0 +1,35 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type NotApplicableModel struct { + Id int +} + +func (n *NotApplicableModel) IsApplicableTableForDB(db string) bool { + return db == "default" +} + +func Test_IsApplicableTableForDB(t *testing.T) { + assert.False(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "defa")) + assert.True(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "default")) +} diff --git a/pkg/client/orm/types.go b/pkg/client/orm/types.go index eb34e759..b0c793b7 100644 --- a/pkg/client/orm/types.go +++ b/pkg/client/orm/types.go @@ -75,6 +75,11 @@ type TableUniqueI interface { TableUnique() [][]string } +// IsApplicableTableForDB if return false, we won't create table to this db +type IsApplicableTableForDB interface { + IsApplicableTableForDB(db string) bool +} + // Driver define database driver type Driver interface { Name() string From f580a714d5748d86d2c2ad6915030253162c2aa5 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 8 Sep 2020 20:47:39 +0800 Subject: [PATCH 154/207] Optimize orm by using BDOption rather than hints --- pkg/adapter/orm/db_alias.go | 8 ++- pkg/client/orm/db_alias.go | 77 +++++++++++++++++---------- pkg/client/orm/db_alias_test.go | 16 +++--- pkg/client/orm/hints/db_hints.go | 30 +---------- pkg/client/orm/hints/db_hints_test.go | 28 ---------- pkg/client/orm/models_test.go | 4 +- pkg/client/orm/orm.go | 2 +- 7 files changed, 63 insertions(+), 102 deletions(-) diff --git a/pkg/adapter/orm/db_alias.go b/pkg/adapter/orm/db_alias.go index 2ecc80e5..b1f1a724 100644 --- a/pkg/adapter/orm/db_alias.go +++ b/pkg/adapter/orm/db_alias.go @@ -20,8 +20,6 @@ import ( "time" "github.com/astaxie/beego/pkg/client/orm" - "github.com/astaxie/beego/pkg/client/orm/hints" - "github.com/astaxie/beego/pkg/infrastructure/utils" ) // DriverType database driver constant int. @@ -86,13 +84,13 @@ func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { - opts := make([]utils.KV, 0, 2) + opts := make([]orm.DBOption, 0, 2) if len(params) > 0 { - opts = append(opts, hints.MaxIdleConnections(params[0])) + opts = append(opts, orm.MaxIdleConnections(params[0])) } if len(params) > 1 { - opts = append(opts, hints.MaxOpenConnections(params[1])) + opts = append(opts, orm.MaxOpenConnections(params[1])) } return orm.RegisterDataBase(aliasName, driverName, dataSource, opts...) } diff --git a/pkg/client/orm/db_alias.go b/pkg/client/orm/db_alias.go index c72f29c4..29e0904c 100644 --- a/pkg/client/orm/db_alias.go +++ b/pkg/client/orm/db_alias.go @@ -21,9 +21,6 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/client/orm/hints" - "github.com/astaxie/beego/pkg/infrastructure/utils" - lru "github.com/hashicorp/golang-lru" ) @@ -278,6 +275,7 @@ type alias struct { MaxIdleConns int MaxOpenConns int ConnMaxLifetime time.Duration + StmtCacheSize int DB *DB DbBaser dbBaser TZ *time.Location @@ -340,7 +338,7 @@ func detectTZ(al *alias) { } } -func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...utils.KV) (*alias, error) { +func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) { existErr := fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) if _, ok := dataBaseCache.get(aliasName); ok { return nil, existErr @@ -358,32 +356,35 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...utils.KV) return al, nil } -func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...utils.KV) (*alias, error) { - kvs := utils.NewKVs(params...) +func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) { + + al := &alias{} + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + } + + for _, p := range params { + p(al) + } var stmtCache *lru.Cache var stmtCacheSize int - maxStmtCacheSize := kvs.GetValueOr(hints.KeyMaxStmtCacheSize, 0).(int) - if maxStmtCacheSize > 0 { - _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) + if al.StmtCacheSize > 0 { + _stmtCache, errC := newStmtDecoratorLruWithEvict(al.StmtCacheSize) if errC != nil { return nil, errC } else { stmtCache = _stmtCache - stmtCacheSize = maxStmtCacheSize + stmtCacheSize = al.StmtCacheSize } } - al := new(alias) al.Name = aliasName al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: stmtCache, - stmtDecoratorsLimit: stmtCacheSize, - } + al.DB.stmtDecorators = stmtCache + al.DB.stmtDecoratorsLimit = stmtCacheSize if dr, ok := drivers[driverName]; ok { al.DbBaser = dbBasers[dr] @@ -399,14 +400,6 @@ func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...utils.KV detectTZ(al) - kvs.IfContains(hints.KeyMaxIdleConnections, func(value interface{}) { - al.SetMaxIdleConns(value.(int)) - }).IfContains(hints.KeyMaxOpenConnections, func(value interface{}) { - al.SetMaxOpenConns(value.(int)) - }).IfContains(hints.KeyConnMaxLifetime, func(value interface{}) { - al.SetConnMaxLifetime(value.(time.Duration)) - }) - return al, nil } @@ -442,13 +435,13 @@ func (al *alias) SetConnMaxLifetime(lifeTime time.Duration) { } // AddAliasWthDB add a aliasName for the drivename -func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...utils.KV) error { +func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) error { _, err := addAliasWthDB(aliasName, driverName, db, params...) return err } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -func RegisterDataBase(aliasName, driverName, dataSource string, params ...utils.KV) error { +func RegisterDataBase(aliasName, driverName, dataSource string, params ...DBOption) error { var ( err error db *sql.DB @@ -561,3 +554,33 @@ func newStmtDecoratorLruWithEvict(cacheSize int) (*lru.Cache, error) { } return cache, nil } + +type DBOption func(al *alias) + +// MaxIdleConnections return a hint about MaxIdleConnections +func MaxIdleConnections(maxIdleConn int) DBOption { + return func(al *alias) { + al.SetMaxIdleConns(maxIdleConn) + } +} + +// MaxOpenConnections return a hint about MaxOpenConnections +func MaxOpenConnections(maxOpenConn int) DBOption { + return func(al *alias) { + al.SetMaxOpenConns(maxOpenConn) + } +} + +// ConnMaxLifetime return a hint about ConnMaxLifetime +func ConnMaxLifetime(v time.Duration) DBOption { + return func(al *alias) { + al.SetConnMaxLifetime(v) + } +} + +// MaxStmtCacheSize return a hint about MaxStmtCacheSize +func MaxStmtCacheSize(v int) DBOption { + return func(al *alias) { + al.StmtCacheSize = v + } +} diff --git a/pkg/client/orm/db_alias_test.go b/pkg/client/orm/db_alias_test.go index 0043ba76..6275cb2a 100644 --- a/pkg/client/orm/db_alias_test.go +++ b/pkg/client/orm/db_alias_test.go @@ -18,16 +18,14 @@ import ( "testing" "time" - "github.com/astaxie/beego/pkg/client/orm/hints" - "github.com/stretchr/testify/assert" ) func TestRegisterDataBase(t *testing.T) { err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, - hints.MaxIdleConnections(20), - hints.MaxOpenConnections(300), - hints.ConnMaxLifetime(time.Minute)) + MaxIdleConnections(20), + MaxOpenConnections(300), + ConnMaxLifetime(time.Minute)) assert.Nil(t, err) al := getDbAlias("test-params") @@ -39,7 +37,7 @@ func TestRegisterDataBase(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(-1)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(-1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -49,7 +47,7 @@ func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(0)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(0)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -59,7 +57,7 @@ func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(1)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -69,7 +67,7 @@ func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(841)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(841)) assert.Nil(t, err) al := getDbAlias(aliasName) diff --git a/pkg/client/orm/hints/db_hints.go b/pkg/client/orm/hints/db_hints.go index 4d199312..7340bd07 100644 --- a/pkg/client/orm/hints/db_hints.go +++ b/pkg/client/orm/hints/db_hints.go @@ -15,20 +15,12 @@ package hints import ( - "time" - "github.com/astaxie/beego/pkg/infrastructure/utils" ) const ( - //db level - KeyMaxIdleConnections = iota - KeyMaxOpenConnections - KeyConnMaxLifetime - KeyMaxStmtCacheSize - //query level - KeyForceIndex + KeyForceIndex = iota KeyUseIndex KeyIgnoreIndex KeyForUpdate @@ -57,26 +49,6 @@ func (s *Hint) GetValue() interface{} { var _ utils.KV = new(Hint) -// MaxIdleConnections return a hint about MaxIdleConnections -func MaxIdleConnections(v int) *Hint { - return NewHint(KeyMaxIdleConnections, v) -} - -// MaxOpenConnections return a hint about MaxOpenConnections -func MaxOpenConnections(v int) *Hint { - return NewHint(KeyMaxOpenConnections, v) -} - -// ConnMaxLifetime return a hint about ConnMaxLifetime -func ConnMaxLifetime(v time.Duration) *Hint { - return NewHint(KeyConnMaxLifetime, v) -} - -// MaxStmtCacheSize return a hint about MaxStmtCacheSize -func MaxStmtCacheSize(v int) *Hint { - return NewHint(KeyMaxStmtCacheSize, v) -} - // ForceIndex return a hint about ForceIndex func ForceIndex(indexes ...string) *Hint { return NewHint(KeyForceIndex, indexes) diff --git a/pkg/client/orm/hints/db_hints_test.go b/pkg/client/orm/hints/db_hints_test.go index 4e962a8f..510f9f16 100644 --- a/pkg/client/orm/hints/db_hints_test.go +++ b/pkg/client/orm/hints/db_hints_test.go @@ -48,34 +48,6 @@ func TestNewHint_float(t *testing.T) { assert.Equal(t, hint.GetValue(), value) } -func TestMaxOpenConnections(t *testing.T) { - i := 887423 - hint := MaxOpenConnections(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyMaxOpenConnections) -} - -func TestConnMaxLifetime(t *testing.T) { - i := time.Hour - hint := ConnMaxLifetime(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyConnMaxLifetime) -} - -func TestMaxIdleConnections(t *testing.T) { - i := 42316 - hint := MaxIdleConnections(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyMaxIdleConnections) -} - -func TestMaxStmtCacheSize(t *testing.T) { - i := 94157 - hint := MaxStmtCacheSize(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyMaxStmtCacheSize) -} - func TestForceIndex(t *testing.T) { s := []string{`f_index1`, `f_index2`, `f_index3`} hint := ForceIndex(s...) diff --git a/pkg/client/orm/models_test.go b/pkg/client/orm/models_test.go index 81ba30df..f0044f6d 100644 --- a/pkg/client/orm/models_test.go +++ b/pkg/client/orm/models_test.go @@ -22,8 +22,6 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/client/orm/hints" - _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" @@ -529,7 +527,7 @@ func init() { os.Exit(2) } - err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, hints.MaxIdleConnections(20)) + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, MaxIdleConnections(20)) if err != nil { panic(fmt.Sprintf("can not register database: %v", err)) diff --git a/pkg/client/orm/orm.go b/pkg/client/orm/orm.go index 95bbcb31..bfb710d1 100644 --- a/pkg/client/orm/orm.go +++ b/pkg/client/orm/orm.go @@ -601,7 +601,7 @@ func NewOrmUsingDB(aliasName string) Ormer { } // NewOrmWithDB create a new ormer object with specify *sql.DB for query -func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...utils.KV) (Ormer, error) { +func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...DBOption) (Ormer, error) { al, err := newAliasWithDb(aliasName, driverName, db, params...) if err != nil { return nil, err From 8982f5d70236f6083740c3de66ae2a58607eb260 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Wed, 9 Sep 2020 00:23:57 +0100 Subject: [PATCH 155/207] Add unit tests for custom log formatter Also moved is Colorful check to WriteMsg function to make the interface for user's using the custom logging formatting simpler. The user does not have to check if the text is colorful now, the WriteMsg function handles it. --- pkg/logs/console.go | 10 ++---- .../logformattertest/log_formatter_test.go | 36 +++++++++++++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) create mode 100644 pkg/logs/logformattertest/log_formatter_test.go diff --git a/pkg/logs/console.go b/pkg/logs/console.go index 34114e4a..a3e5fb5a 100644 --- a/pkg/logs/console.go +++ b/pkg/logs/console.go @@ -58,10 +58,6 @@ type consoleWriter struct { func (c *consoleWriter) Format(lm *LogMsg) string { msg := lm.Msg - if c.Colorful { - msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) - } - h, _, _ := formatTimeHeader(lm.When) bytes := append(append(h, msg...), '\n') @@ -105,13 +101,13 @@ func (c *consoleWriter) WriteMsg(lm *LogMsg) error { if lm.Level > c.Level { return nil } - // fmt.Printf("Formatted: %s\n\n", c.fmtter.Format(lm)) + + msg := "" + if c.Colorful { lm.Msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) } - msg := "" - if c.customFormatter != nil { msg = c.customFormatter(lm) } else { diff --git a/pkg/logs/logformattertest/log_formatter_test.go b/pkg/logs/logformattertest/log_formatter_test.go new file mode 100644 index 00000000..2d99a8e6 --- /dev/null +++ b/pkg/logs/logformattertest/log_formatter_test.go @@ -0,0 +1,36 @@ +package logformattertest + +import ( + "fmt" + "testing" + + "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/logs" +) + +func customFormatter(lm *logs.LogMsg) string { + return fmt.Sprintf("[CUSTOM CONSOLE LOGGING] %s", lm.Msg) +} + +func globalFormatter(lm *logs.LogMsg) string { + return fmt.Sprintf("[GLOBAL] %s", lm.Msg) +} + +func TestCustomLoggingFormatter(t *testing.T) { + // beego.BConfig.Log.AccessLogs = true + + logs.SetLoggerWithOpts("console", []string{`{"color":true}`}, common.SimpleKV{Key: "formatter", Value: customFormatter}) + + // Message will be formatted by the customFormatter with colorful text set to true + logs.Informational("Test message") +} + +func TestGlobalLoggingFormatter(t *testing.T) { + logs.SetGlobalFormatter(globalFormatter) + + logs.SetLogger("console", `{"color":true}`) + + // Message will be formatted by globalFormatter + logs.Informational("Test message") + +} From 00e44952ffda640ee6f1918522dfa2f6939ef939 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Wed, 9 Sep 2020 19:04:34 +0800 Subject: [PATCH 156/207] optimize modelCache --- pkg/client/orm/filter_orm_decorator.go | 8 ++++---- pkg/client/orm/models.go | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pkg/client/orm/filter_orm_decorator.go b/pkg/client/orm/filter_orm_decorator.go index 5a49e395..3271c520 100644 --- a/pkg/client/orm/filter_orm_decorator.go +++ b/pkg/client/orm/filter_orm_decorator.go @@ -17,6 +17,7 @@ package orm import ( "context" "database/sql" + "errors" "reflect" "time" @@ -32,7 +33,6 @@ var _ TxOrmer = new(filterOrmDecorator) type filterOrmDecorator struct { ormer - modelCacheHandler TxBeginner TxCommitter @@ -44,15 +44,15 @@ type filterOrmDecorator struct { } func (f *filterOrmDecorator) RegisterModels(models ...interface{}) (err error) { - return f.modelCacheHandler.RegisterModels(models...) + return errors.New(`not callable`) } func (f *filterOrmDecorator) RegisterModelsWithPrefix(prefix string, models ...interface{}) (err error) { - return f.modelCacheHandler.RegisterModelsWithPrefix(prefix, models...) + return errors.New(`not callable`) } func (f *filterOrmDecorator) RegisterModelsWithSuffix(suffix string, models ...interface{}) (err error) { - return f.modelCacheHandler.RegisterModelsWithSuffix(suffix, models...) + return errors.New(`not callable`) } func NewFilterOrmDecorator(delegate Ormer, filterChains ...FilterChain) Ormer { diff --git a/pkg/client/orm/models.go b/pkg/client/orm/models.go index 55ba5a73..b38ea9e5 100644 --- a/pkg/client/orm/models.go +++ b/pkg/client/orm/models.go @@ -349,7 +349,7 @@ end: fmt.Println(err) debug.PrintStack() } - modelCache.done = true + mc.done = true return } @@ -432,14 +432,14 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m //getDbDropSQL get database scheme drop sql queries func (mc *_modelCache) getDbDropSQL(al *alias) (queries []string, err error) { - if len(modelCache.cache) == 0 { + if len(mc.cache) == 0 { err = errors.New("no Model found, need register your model") return } Q := al.DbBaser.TableQuote() - for _, mi := range modelCache.allOrdered() { + for _, mi := range mc.allOrdered() { queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) } return queries,nil @@ -447,7 +447,7 @@ func (mc *_modelCache) getDbDropSQL(al *alias) (queries []string, err error) { //getDbCreateSQL get database scheme creation sql queries func (mc *_modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes map[string][]dbIndex, err error) { - if len(modelCache.cache) == 0 { + if len(mc.cache) == 0 { err = errors.New("no Model found, need register your model") return } @@ -458,7 +458,7 @@ func (mc *_modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes tableIndexes = make(map[string][]dbIndex) - for _, mi := range modelCache.allOrdered() { + for _, mi := range mc.allOrdered() { sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) From 63cd8e4e15de50618bf0da81e59c0789864e4975 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Fri, 11 Sep 2020 21:10:12 +0800 Subject: [PATCH 157/207] refactor log module --- pkg/infrastructure/logs/alils/alils.go | 56 ++--- pkg/infrastructure/logs/conn.go | 52 ++--- pkg/infrastructure/logs/console.go | 63 +++-- pkg/infrastructure/logs/es/es.go | 65 +++--- pkg/infrastructure/logs/file.go | 65 +++--- pkg/infrastructure/logs/file_test.go | 5 + pkg/infrastructure/logs/formatter.go | 34 +++ pkg/infrastructure/logs/jianliao.go | 58 ++--- pkg/infrastructure/logs/log.go | 215 ++++++------------ pkg/infrastructure/logs/log_formatter_test.go | 35 --- pkg/infrastructure/logs/multifile.go | 43 ++-- pkg/infrastructure/logs/slack.go | 43 ++-- pkg/infrastructure/logs/smtp.go | 34 +-- 13 files changed, 346 insertions(+), 422 deletions(-) create mode 100644 pkg/infrastructure/logs/formatter.go delete mode 100644 pkg/infrastructure/logs/log_formatter_test.go diff --git a/pkg/infrastructure/logs/alils/alils.go b/pkg/infrastructure/logs/alils/alils.go index 03e97045..0689aae0 100644 --- a/pkg/infrastructure/logs/alils/alils.go +++ b/pkg/infrastructure/logs/alils/alils.go @@ -2,12 +2,14 @@ package alils import ( "encoding/json" + "fmt" "strings" "sync" - "github.com/astaxie/beego/pkg/infrastructure/logs" - "github.com/astaxie/beego/pkg/infrastructure/utils" "github.com/gogo/protobuf/proto" + "github.com/pkg/errors" + + "github.com/astaxie/beego/pkg/infrastructure/logs" ) const ( @@ -28,40 +30,35 @@ type Config struct { Source string `json:"source"` Level int `json:"level"` FlushWhen int `json:"flush_when"` + Formatter string `json:"formatter"` } // aliLSWriter implements LoggerInterface. // Writes messages in keep-live tcp connection. type aliLSWriter struct { - store *LogStore - group []*LogGroup - withMap bool - groupMap map[string]*LogGroup - lock *sync.Mutex - customFormatter func(*logs.LogMsg) string + store *LogStore + group []*LogGroup + withMap bool + groupMap map[string]*LogGroup + lock *sync.Mutex Config + formatter logs.LogFormatter } // NewAliLS creates a new Logger func NewAliLS() logs.Logger { alils := new(aliLSWriter) alils.Level = logs.LevelTrace + alils.formatter = alils return alils } // Init parses config and initializes struct -func (c *aliLSWriter) Init(jsonConfig string, opts ...utils.KV) error { - - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := logs.GetFormatter(elem) - if err != nil { - return err - } - c.customFormatter = formatter - } +func (c *aliLSWriter) Init(config string) error { + err := json.Unmarshal([]byte(config), c) + if err != nil { + return err } - json.Unmarshal([]byte(jsonConfig), c) if c.FlushWhen > CacheSize { c.FlushWhen = CacheSize @@ -110,11 +107,23 @@ func (c *aliLSWriter) Init(jsonConfig string, opts ...utils.KV) error { c.lock = &sync.Mutex{} + if len(c.Formatter) > 0 { + fmtr, ok := logs.GetFormatter(c.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", c.Formatter)) + } + c.formatter = fmtr + } + return nil } func (c *aliLSWriter) Format(lm *logs.LogMsg) string { - return lm.Msg + return lm.OldStyleFormat() +} + +func (c *aliLSWriter) SetFormatter(f logs.LogFormatter) { + c.formatter = f } // WriteMsg writes a message in connection. @@ -145,11 +154,7 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { lg = c.group[0] } - if c.customFormatter != nil { - content = c.customFormatter(lm) - } else { - content = c.Format(lm) - } + content = c.formatter.Format(lm) c1 := &LogContent{ Key: proto.String("msg"), @@ -170,7 +175,6 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { if len(lg.Logs) >= c.FlushWhen { c.flush(lg) } - return nil } diff --git a/pkg/infrastructure/logs/conn.go b/pkg/infrastructure/logs/conn.go index f7d44d7f..1fd71be7 100644 --- a/pkg/infrastructure/logs/conn.go +++ b/pkg/infrastructure/logs/conn.go @@ -16,51 +16,55 @@ package logs import ( "encoding/json" + "fmt" "io" "net" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/pkg/errors" ) // connWriter implements LoggerInterface. // Writes messages in keep-live tcp connection. type connWriter struct { - lg *logWriter - innerWriter io.WriteCloser - customFormatter func(*LogMsg) string - ReconnectOnMsg bool `json:"reconnectOnMsg"` - Reconnect bool `json:"reconnect"` - Net string `json:"net"` - Addr string `json:"addr"` - Level int `json:"level"` + lg *logWriter + innerWriter io.WriteCloser + formatter LogFormatter + Formatter string `json:"formatter"` + ReconnectOnMsg bool `json:"reconnectOnMsg"` + Reconnect bool `json:"reconnect"` + Net string `json:"net"` + Addr string `json:"addr"` + Level int `json:"level"` } // NewConn creates new ConnWrite returning as LoggerInterface. func NewConn() Logger { conn := new(connWriter) conn.Level = LevelTrace + conn.formatter = conn return conn } func (c *connWriter) Format(lm *LogMsg) string { - return lm.Msg + return lm.OldStyleFormat() } // Init initializes a connection writer with json config. // json config only needs they "level" key -func (c *connWriter) Init(jsonConfig string, opts ...utils.KV) error { - - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := GetFormatter(elem) - if err != nil { - return err - } - c.customFormatter = formatter +func (c *connWriter) Init(config string) error { + res := json.Unmarshal([]byte(config), c) + if res == nil && len(c.Formatter) > 0 { + fmtr, ok := GetFormatter(c.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", c.Formatter)) } + c.formatter = fmtr } + return res +} - return json.Unmarshal([]byte(jsonConfig), c) +func (c *connWriter) SetFormatter(f LogFormatter) { + c.formatter = f } // WriteMsg writes message in connection. @@ -80,13 +84,7 @@ func (c *connWriter) WriteMsg(lm *LogMsg) error { defer c.innerWriter.Close() } - msg := "" - if c.customFormatter != nil { - msg = c.customFormatter(lm) - } else { - msg = c.Format(lm) - - } + msg := c.formatter.Format(lm) _, err := c.lg.writeln(msg) if err != nil { diff --git a/pkg/infrastructure/logs/console.go b/pkg/infrastructure/logs/console.go index 802d79f5..f99ef11b 100644 --- a/pkg/infrastructure/logs/console.go +++ b/pkg/infrastructure/logs/console.go @@ -16,11 +16,11 @@ package logs import ( "encoding/json" + "fmt" "os" "strings" - "github.com/astaxie/beego/pkg/infrastructure/utils" - + "github.com/pkg/errors" "github.com/shiena/ansicolor" ) @@ -49,20 +49,25 @@ var colors = []brush{ // consoleWriter implements LoggerInterface and writes messages to terminal. type consoleWriter struct { - lg *logWriter - customFormatter func(*LogMsg) string - Level int `json:"level"` - Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color + lg *logWriter + formatter LogFormatter + Formatter string `json:"formatter"` + Level int `json:"level"` + Colorful bool `json:"color"` // this filed is useful only when system's terminal supports color } func (c *consoleWriter) Format(lm *LogMsg) string { - msg := lm.Msg - + msg := lm.OldStyleFormat() + if c.Colorful { + msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) + } h, _, _ := formatTimeHeader(lm.When) bytes := append(append(h, msg...), '\n') - return string(bytes) +} +func (c *consoleWriter) SetFormatter(f LogFormatter) { + c.formatter = f } // NewConsole creates ConsoleWriter returning as LoggerInterface. @@ -72,28 +77,27 @@ func NewConsole() Logger { Level: LevelDebug, Colorful: true, } + cw.formatter = cw return cw } // Init initianlizes the console logger. // jsonConfig must be in the format '{"level":LevelTrace}' -func (c *consoleWriter) Init(jsonConfig string, opts ...utils.KV) error { +func (c *consoleWriter) Init(config string) error { - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := GetFormatter(elem) - if err != nil { - return err - } - c.customFormatter = formatter - } - } - - if len(jsonConfig) == 0 { + if len(config) == 0 { return nil } - return json.Unmarshal([]byte(jsonConfig), c) + res := json.Unmarshal([]byte(config), c) + if res == nil && len(c.Formatter) > 0 { + fmtr, ok := GetFormatter(c.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", c.Formatter)) + } + c.formatter = fmtr + } + return res } // WriteMsg writes message in console. @@ -101,20 +105,7 @@ func (c *consoleWriter) WriteMsg(lm *LogMsg) error { if lm.Level > c.Level { return nil } - - msg := "" - - if c.Colorful { - lm.Msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) - } - - if c.customFormatter != nil { - msg = c.customFormatter(lm) - } else { - msg = c.Format(lm) - - } - + msg := c.formatter.Format(lm) c.lg.writeln(msg) return nil } diff --git a/pkg/infrastructure/logs/es/es.go b/pkg/infrastructure/logs/es/es.go index 857a1a34..438a6da6 100644 --- a/pkg/infrastructure/logs/es/es.go +++ b/pkg/infrastructure/logs/es/es.go @@ -13,7 +13,6 @@ import ( "github.com/elastic/go-elasticsearch/v6/esapi" "github.com/astaxie/beego/pkg/infrastructure/logs" - "github.com/astaxie/beego/pkg/infrastructure/utils" ) // NewES returns a LoggerInterface @@ -32,29 +31,34 @@ func NewES() logs.Logger { // import _ "github.com/astaxie/beego/logs/es" type esLogger struct { *elasticsearch.Client - DSN string `json:"dsn"` - Level int `json:"level"` - customFormatter func(*logs.LogMsg) string + DSN string `json:"dsn"` + Level int `json:"level"` + formatter logs.LogFormatter + Formatter string `json:"formatter"` } func (el *esLogger) Format(lm *logs.LogMsg) string { - return lm.Msg + + msg := lm.OldStyleFormat() + idx := LogDocument{ + Timestamp: lm.When.Format(time.RFC3339), + Msg: msg, + } + body, err := json.Marshal(idx) + if err != nil { + return msg + } + return string(body) +} + +func (el *esLogger) SetFormatter(f logs.LogFormatter) { + el.formatter = f } // {"dsn":"http://localhost:9200/","level":1} -func (el *esLogger) Init(jsonConfig string, opts ...utils.KV) error { +func (el *esLogger) Init(config string) error { - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := logs.GetFormatter(elem) - if err != nil { - return err - } - el.customFormatter = formatter - } - } - - err := json.Unmarshal([]byte(jsonConfig), el) + err := json.Unmarshal([]byte(config), el) if err != nil { return err } @@ -73,6 +77,13 @@ func (el *esLogger) Init(jsonConfig string, opts ...utils.KV) error { } el.Client = conn } + if len(el.Formatter) > 0 { + fmtr, ok := logs.GetFormatter(el.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", el.Formatter)) + } + el.formatter = fmtr + } return nil } @@ -82,28 +93,14 @@ func (el *esLogger) WriteMsg(lm *logs.LogMsg) error { return nil } - msg := "" - if el.customFormatter != nil { - msg = el.customFormatter(lm) - } else { - msg = el.Format(lm) - } + msg := el.formatter.Format(lm) - idx := LogDocument{ - Timestamp: lm.When.Format(time.RFC3339), - Msg: msg, - } - - body, err := json.Marshal(idx) - if err != nil { - return err - } req := esapi.IndexRequest{ Index: fmt.Sprintf("%04d.%02d.%02d", lm.When.Year(), lm.When.Month(), lm.When.Day()), DocumentType: "logs", - Body: strings.NewReader(string(body)), + Body: strings.NewReader(msg), } - _, err = req.Do(context.Background(), el.Client) + _, err := req.Do(context.Background(), el.Client) return err } diff --git a/pkg/infrastructure/logs/file.go b/pkg/infrastructure/logs/file.go index 0c96918c..b01be357 100644 --- a/pkg/infrastructure/logs/file.go +++ b/pkg/infrastructure/logs/file.go @@ -27,8 +27,6 @@ import ( "strings" "sync" "time" - - "github.com/astaxie/beego/pkg/infrastructure/utils" ) // fileLogWriter implements LoggerInterface. @@ -62,8 +60,6 @@ type fileLogWriter struct { hourlyOpenDate int hourlyOpenTime time.Time - customFormatter func(*LogMsg) string - Rotate bool `json:"rotate"` Level int `json:"level"` @@ -73,6 +69,9 @@ type fileLogWriter struct { RotatePerm string `json:"rotateperm"` fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix + + formatter LogFormatter + Formatter string `json:"formatter"` } // newFileWriter creates a FileLogWriter returning as LoggerInterface. @@ -90,11 +89,19 @@ func newFileWriter() Logger { MaxFiles: 999, MaxSize: 1 << 28, } + w.formatter = w return w } func (w *fileLogWriter) Format(lm *LogMsg) string { - return lm.Msg + msg := lm.OldStyleFormat() + hd, _, _ := formatTimeHeader(lm.When) + msg = fmt.Sprintf("%s %s\n", string(hd), msg) + return msg +} + +func (w *fileLogWriter) SetFormatter(f LogFormatter) { + w.formatter = f } // Init file logger with json config. @@ -108,19 +115,9 @@ func (w *fileLogWriter) Format(lm *LogMsg) string { // "rotate":true, // "perm":"0600" // } -func (w *fileLogWriter) Init(jsonConfig string, opts ...utils.KV) error { +func (w *fileLogWriter) Init(config string) error { - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := GetFormatter(elem) - if err != nil { - return err - } - w.customFormatter = formatter - } - } - - err := json.Unmarshal([]byte(jsonConfig), w) + err := json.Unmarshal([]byte(config), w) if err != nil { return err } @@ -132,6 +129,14 @@ func (w *fileLogWriter) Init(jsonConfig string, opts ...utils.KV) error { if w.suffix == "" { w.suffix = ".log" } + + if len(w.Formatter) > 0 { + fmtr, ok := GetFormatter(w.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", w.Formatter)) + } + w.formatter = fmtr + } err = w.startLogger() return err } @@ -149,13 +154,13 @@ func (w *fileLogWriter) startLogger() error { return w.initFd() } -func (w *fileLogWriter) needRotateDaily(size int, day int) bool { +func (w *fileLogWriter) needRotateDaily(day int) bool { return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || (w.Daily && day != w.dailyOpenDate) } -func (w *fileLogWriter) needRotateHourly(size int, hour int) bool { +func (w *fileLogWriter) needRotateHourly(hour int) bool { return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || (w.Hourly && hour != w.hourlyOpenDate) @@ -167,31 +172,25 @@ func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { if lm.Level > w.Level { return nil } - hd, d, h := formatTimeHeader(lm.When) - msg := "" - if w.customFormatter != nil { - msg = w.customFormatter(lm) - } else { - msg = w.Format(lm) - } + _, d, h := formatTimeHeader(lm.When) - msg = fmt.Sprintf("%s %s\n", string(hd), msg) + msg := w.formatter.Format(lm) if w.Rotate { w.RLock() - if w.needRotateHourly(len(lm.Msg), h) { + if w.needRotateHourly(h) { w.RUnlock() w.Lock() - if w.needRotateHourly(len(lm.Msg), h) { + if w.needRotateHourly(h) { if err := w.doRotate(lm.When); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } } w.Unlock() - } else if w.needRotateDaily(len(lm.Msg), d) { + } else if w.needRotateDaily(d) { w.RUnlock() w.Lock() - if w.needRotateDaily(len(lm.Msg), d) { + if w.needRotateDaily(d) { if err := w.doRotate(lm.When); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } @@ -263,7 +262,7 @@ func (w *fileLogWriter) dailyRotate(openTime time.Time) { tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) <-tm.C w.Lock() - if w.needRotateDaily(0, time.Now().Day()) { + if w.needRotateDaily(time.Now().Day()) { if err := w.doRotate(time.Now()); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } @@ -278,7 +277,7 @@ func (w *fileLogWriter) hourlyRotate(openTime time.Time) { tm := time.NewTimer(time.Duration(nextHour.UnixNano() - openTime.UnixNano() + 100)) <-tm.C w.Lock() - if w.needRotateHourly(0, time.Now().Hour()) { + if w.needRotateHourly(time.Now().Hour()) { if err := w.doRotate(time.Now()); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } diff --git a/pkg/infrastructure/logs/file_test.go b/pkg/infrastructure/logs/file_test.go index 7f2a3590..494d0a9e 100644 --- a/pkg/infrastructure/logs/file_test.go +++ b/pkg/infrastructure/logs/file_test.go @@ -268,6 +268,7 @@ func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { Perm: "0660", RotatePerm: "0440", } + fw.formatter = fw if daily { fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) @@ -308,6 +309,8 @@ func testFileDailyRotate(t *testing.T, fn1, fn2 string) { Perm: "0660", RotatePerm: "0440", } + fw.formatter = fw + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) fw.dailyOpenDate = fw.dailyOpenTime.Day() @@ -340,6 +343,8 @@ func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { Perm: "0660", RotatePerm: "0440", } + + fw.formatter = fw fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) fw.hourlyOpenDate = fw.hourlyOpenTime.Hour() diff --git a/pkg/infrastructure/logs/formatter.go b/pkg/infrastructure/logs/formatter.go new file mode 100644 index 00000000..b2599f2d --- /dev/null +++ b/pkg/infrastructure/logs/formatter.go @@ -0,0 +1,34 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +var formatterMap = make(map[string]LogFormatter, 4) + +type LogFormatter interface { + Format(lm *LogMsg) string +} + +// RegisterFormatter register an formatter. Usually you should use this to extend your custom formatter +// for example: +// RegisterFormatter("my-fmt", &MyFormatter{}) +// logs.SetFormatter(Console, `{"formatter": "my-fmt"}`) +func RegisterFormatter(name string, fmtr LogFormatter) { + formatterMap[name] = fmtr +} + +func GetFormatter(name string) (LogFormatter, bool) { + res, ok := formatterMap[name] + return res, ok +} diff --git a/pkg/infrastructure/logs/jianliao.go b/pkg/infrastructure/logs/jianliao.go index 88750125..9757a7d5 100644 --- a/pkg/infrastructure/logs/jianliao.go +++ b/pkg/infrastructure/logs/jianliao.go @@ -6,42 +6,49 @@ import ( "net/http" "net/url" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/pkg/errors" ) // JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook type JLWriter struct { - AuthorName string `json:"authorname"` - Title string `json:"title"` - WebhookURL string `json:"webhookurl"` - RedirectURL string `json:"redirecturl,omitempty"` - ImageURL string `json:"imageurl,omitempty"` - Level int `json:"level"` - customFormatter func(*LogMsg) string + AuthorName string `json:"authorname"` + Title string `json:"title"` + WebhookURL string `json:"webhookurl"` + RedirectURL string `json:"redirecturl,omitempty"` + ImageURL string `json:"imageurl,omitempty"` + Level int `json:"level"` + + formatter LogFormatter + Formatter string `json:"formatter"` } // newJLWriter creates jiaoliao writer. func newJLWriter() Logger { - return &JLWriter{Level: LevelTrace} + res := &JLWriter{Level: LevelTrace} + res.formatter = res + return res } // Init JLWriter with json config string -func (s *JLWriter) Init(jsonConfig string, opts ...utils.KV) error { - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := GetFormatter(elem) - if err != nil { - return err - } - s.customFormatter = formatter - } - } +func (s *JLWriter) Init(config string) error { - return json.Unmarshal([]byte(jsonConfig), s) + res := json.Unmarshal([]byte(config), s) + if res == nil && len(s.Formatter) > 0 { + fmtr, ok := GetFormatter(s.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", s.Formatter)) + } + s.formatter = fmtr + } + return res } func (s *JLWriter) Format(lm *LogMsg) string { - return lm.Msg + return lm.OldStyleFormat() +} + +func (s *JLWriter) SetFormatter(f LogFormatter) { + s.formatter = f } // WriteMsg writes message in smtp writer. @@ -51,14 +58,7 @@ func (s *JLWriter) WriteMsg(lm *LogMsg) error { return nil } - text := "" - - if s.customFormatter != nil { - text = fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), s.customFormatter(lm)) - } else { - text = fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), s.Format(lm)) - - } + text := s.formatter.Format(lm) form := url.Values{} form.Add("authorName", s.AuthorName) diff --git a/pkg/infrastructure/logs/log.go b/pkg/infrastructure/logs/log.go index 2d400eba..480cecab 100644 --- a/pkg/infrastructure/logs/log.go +++ b/pkg/infrastructure/logs/log.go @@ -38,13 +38,12 @@ import ( "log" "os" "path" - "reflect" "runtime" "strings" "sync" "time" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/pkg/errors" ) // RFC5424 log message levels. @@ -87,11 +86,11 @@ type newLoggerFunc func() Logger // Logger defines the behavior of a log provider. type Logger interface { - Init(config string, opts ...utils.KV) error + Init(config string) error WriteMsg(lm *LogMsg) error - Format(lm *LogMsg) string Destroy() Flush() + SetFormatter(f LogFormatter) } var adapters = make(map[string]newLoggerFunc) @@ -118,7 +117,6 @@ type BeeLogger struct { init bool enableFuncCallDepth bool loggerFuncCallDepth int - globalFormatter func(*LogMsg) string enableFullFilePath bool asynchronous bool prefix string @@ -127,6 +125,7 @@ type BeeLogger struct { signalChan chan string wg sync.WaitGroup outputs []*nameLogger + globalFormatter string } const defaultAsyncMsgLen = 1e3 @@ -137,15 +136,15 @@ type nameLogger struct { } type LogMsg struct { - Level int - Msg string - When time.Time - FilePath string - LineNumber int -} - -type LogFormatter interface { - Format(lm *LogMsg) string + Level int + Msg string + When time.Time + FilePath string + LineNumber int + Args []interface{} + Prefix string + enableFullFilePath bool + enableFuncCallDepth bool } var logMsgPool *sync.Pool @@ -188,8 +187,25 @@ func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { return bl } -func Format(lm *LogMsg) string { - return lm.Msg +// OldStyleFormat you should never invoke this +func (lm *LogMsg) OldStyleFormat() string { + msg := lm.Msg + + if len(lm.Args) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, lm.Args...) + } + + msg = lm.Prefix + " " + msg + + if lm.enableFuncCallDepth { + if !lm.enableFullFilePath { + _, lm.FilePath = path.Split(lm.FilePath) + } + msg = fmt.Sprintf("[%s:%d] %s", lm.FilePath, lm.LineNumber, msg) + } + + msg = levelPrefix[lm.Level] + " " + msg + return msg } // SetLogger provides a given logger adapter into BeeLogger with config string. @@ -208,16 +224,18 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { } lg := logAdapter() - var err error // Global formatter overrides the default set formatter - // but not adapter specific formatters set with logs.SetLoggerWithOpts() - if bl.globalFormatter != nil { - err = lg.Init(config, &utils.SimpleKV{Key: "formatter", Value: bl.globalFormatter}) - } else { - err = lg.Init(config) + if len(bl.globalFormatter) > 0 { + fmtr, ok := GetFormatter(bl.globalFormatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", bl.globalFormatter)) + } + lg.SetFormatter(fmtr) } + err := lg.Init(config) + if err != nil { fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) return err @@ -287,46 +305,34 @@ func (bl *BeeLogger) Write(p []byte) (n int, err error) { return 0, err } -func (bl *BeeLogger) writeMsg(lm *LogMsg, v ...interface{}) error { +func (bl *BeeLogger) writeMsg(lm *LogMsg) error { if !bl.init { bl.lock.Lock() bl.setLogger(AdapterConsole) bl.lock.Unlock() } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) - } - - lm.Msg = bl.prefix + " " + lm.Msg - var ( file string line int ok bool ) - if bl.enableFuncCallDepth { - _, file, line, ok = runtime.Caller(bl.loggerFuncCallDepth) - if !ok { - file = "???" - line = 0 - } - - if !bl.enableFullFilePath { - _, file = path.Split(file) - } - lm.FilePath = file - lm.LineNumber = line - lm.Msg = fmt.Sprintf("[%s:%d] %s", lm.FilePath, lm.LineNumber, lm.Msg) + _, file, line, ok = runtime.Caller(bl.loggerFuncCallDepth) + if !ok { + file = "???" + line = 0 } + lm.FilePath = file + lm.LineNumber = line + + lm.enableFullFilePath = bl.enableFullFilePath + lm.enableFuncCallDepth = bl.enableFuncCallDepth // set level info in front of filename info if lm.Level == levelLoggerImpl { // set to emergency to ensure all log will be print out correctly lm.Level = LevelEmergency - } else { - lm.Msg = levelPrefix[lm.Level] + " " + lm.Msg } if bl.asynchronous { @@ -334,6 +340,10 @@ func (bl *BeeLogger) writeMsg(lm *LogMsg, v ...interface{}) error { logM.Level = lm.Level logM.Msg = lm.Msg logM.When = lm.When + logM.Args = lm.Args + logM.FilePath = lm.FilePath + logM.LineNumber = lm.LineNumber + logM.Prefix = lm.Prefix if bl.outputs != nil { bl.msgChan <- lm } else { @@ -404,84 +414,14 @@ func (bl *BeeLogger) startLogger() { } } -// Get the formatter from the opts common.SimpleKV structure -// Looks for a key: "formatter" with value: func(*LogMsg) string -func GetFormatter(opts utils.KV) (func(*LogMsg) string, error) { - if strings.ToLower(opts.GetKey().(string)) == "formatter" { - formatterInterface := reflect.ValueOf(opts.GetValue()).Interface() - formatterFunc := formatterInterface.(func(*LogMsg) string) - return formatterFunc, nil - } - - return nil, fmt.Errorf("no \"formatter\" key given in simpleKV") -} - -// SetLoggerWithOpts sets a log adapter with a user defined logging format. Config must be valid JSON -// such as: {"interval":360} -func (bl *BeeLogger) setLoggerWithOpts(adapterName string, opts utils.KV, configs ...string) error { - config := append(configs, "{}")[0] - for _, l := range bl.outputs { - if l.name == adapterName { - return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName) - } - } - - logAdapter, ok := adapters[adapterName] - if !ok { - return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) - } - - if opts.GetKey() == nil { - return fmt.Errorf("No SimpleKV struct set for %s log adapter", adapterName) - } - - lg := logAdapter() - err := lg.Init(config, opts) - if err != nil { - fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) - return err - } - - bl.outputs = append(bl.outputs, &nameLogger{ - name: adapterName, - Logger: lg, - }) - - return nil -} - -// SetLogger provides a given logger adapter into BeeLogger with config string. -func (bl *BeeLogger) SetLoggerWithOpts(adapterName string, opts utils.KV, configs ...string) error { - bl.lock.Lock() - defer bl.lock.Unlock() - if !bl.init { - bl.outputs = []*nameLogger{} - bl.init = true - } - return bl.setLoggerWithOpts(adapterName, opts, configs...) -} - -// SetLoggerWIthOpts sets a given log adapter with a custom log adapter. -// Log Adapter must be given in the form common.SimpleKV{Key: "formatter": Value: struct.FormatFunc} -// where FormatFunc has the signature func(*LogMsg) string -// func SetLoggerWithOpts(adapter string, config []string, formatterFunc func(*LogMsg) string) error { -func SetLoggerWithOpts(adapter string, config []string, opts utils.KV) error { - err := beeLogger.SetLoggerWithOpts(adapter, opts, config...) - if err != nil { - log.Fatal(err) - } - return nil - -} - -func (bl *BeeLogger) setGlobalFormatter(fmtter func(*LogMsg) string) error { +func (bl *BeeLogger) setGlobalFormatter(fmtter string) error { bl.globalFormatter = fmtter return nil } // SetGlobalFormatter sets the global formatter for all log adapters -// This overrides and other individually set adapter -func SetGlobalFormatter(fmtter func(*LogMsg) string) error { +// don't forget to register the formatter by invoking RegisterFormatter +func SetGlobalFormatter(fmtter string) error { return beeLogger.setGlobalFormatter(fmtter) } @@ -513,11 +453,8 @@ func (bl *BeeLogger) Alert(format string, v ...interface{}) { Level: LevelAlert, Msg: format, When: time.Now(), + Args: v, } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) - } - bl.writeMsg(lm) } @@ -530,9 +467,7 @@ func (bl *BeeLogger) Critical(format string, v ...interface{}) { Level: LevelCritical, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -547,9 +482,7 @@ func (bl *BeeLogger) Error(format string, v ...interface{}) { Level: LevelError, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -564,9 +497,7 @@ func (bl *BeeLogger) Warning(format string, v ...interface{}) { Level: LevelWarn, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -581,9 +512,7 @@ func (bl *BeeLogger) Notice(format string, v ...interface{}) { Level: LevelNotice, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -598,9 +527,7 @@ func (bl *BeeLogger) Informational(format string, v ...interface{}) { Level: LevelInfo, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -615,9 +542,7 @@ func (bl *BeeLogger) Debug(format string, v ...interface{}) { Level: LevelDebug, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -633,9 +558,7 @@ func (bl *BeeLogger) Warn(format string, v ...interface{}) { Level: LevelWarn, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -651,9 +574,7 @@ func (bl *BeeLogger) Info(format string, v ...interface{}) { Level: LevelInfo, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) @@ -669,9 +590,7 @@ func (bl *BeeLogger) Trace(format string, v ...interface{}) { Level: LevelDebug, Msg: format, When: time.Now(), - } - if len(v) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, v...) + Args: v, } bl.writeMsg(lm) diff --git a/pkg/infrastructure/logs/log_formatter_test.go b/pkg/infrastructure/logs/log_formatter_test.go deleted file mode 100644 index 73281cf6..00000000 --- a/pkg/infrastructure/logs/log_formatter_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package logs - -import ( - "fmt" - "testing" - - "github.com/astaxie/beego/pkg/infrastructure/utils" -) - -func customFormatter(lm *LogMsg) string { - return fmt.Sprintf("[CUSTOM CONSOLE LOGGING] %s", lm.Msg) -} - -func globalFormatter(lm *LogMsg) string { - return fmt.Sprintf("[GLOBAL] %s", lm.Msg) -} - -func TestCustomLoggingFormatter(t *testing.T) { - // beego.BConfig.Log.AccessLogs = true - - SetLoggerWithOpts("console", []string{`{"color":true}`}, &utils.SimpleKV{Key: "formatter", Value: customFormatter}) - - // Message will be formatted by the customFormatter with colorful text set to true - Informational("Test message") -} - -func TestGlobalLoggingFormatter(t *testing.T) { - SetGlobalFormatter(globalFormatter) - - SetLogger("console", `{"color":true}`) - - // Message will be formatted by globalFormatter - Informational("Test message") - -} diff --git a/pkg/infrastructure/logs/multifile.go b/pkg/infrastructure/logs/multifile.go index bf589b91..79178211 100644 --- a/pkg/infrastructure/logs/multifile.go +++ b/pkg/infrastructure/logs/multifile.go @@ -16,8 +16,6 @@ package logs import ( "encoding/json" - - "github.com/astaxie/beego/pkg/infrastructure/utils" ) // A filesLogWriter manages several fileLogWriter @@ -26,10 +24,9 @@ import ( // and write the error-level logs to project.error.log and write the debug-level logs to project.debug.log // the rotate attribute also acts like fileLogWriter type multiFileLogWriter struct { - writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter - fullLogWriter *fileLogWriter - Separate []string `json:"separate"` - customFormatter func(*LogMsg) string + writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter + fullLogWriter *fileLogWriter + Separate []string `json:"separate"` } var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"} @@ -47,30 +44,27 @@ var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning // "separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"], // } -func (f *multiFileLogWriter) Init(jsonConfig string, opts ...utils.KV) error { - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := GetFormatter(elem) - if err != nil { - return err - } - f.customFormatter = formatter - } - } +func (f *multiFileLogWriter) Init(config string) error { writer := newFileWriter().(*fileLogWriter) - err := writer.Init(jsonConfig) + err := writer.Init(config) if err != nil { return err } f.fullLogWriter = writer f.writers[LevelDebug+1] = writer - //unmarshal "separate" field to f.Separate - json.Unmarshal([]byte(jsonConfig), f) + // unmarshal "separate" field to f.Separate + err = json.Unmarshal([]byte(config), f) + if err != nil { + return err + } jsonMap := map[string]interface{}{} - json.Unmarshal([]byte(jsonConfig), &jsonMap) + err = json.Unmarshal([]byte(config), &jsonMap) + if err != nil { + return err + } for i := LevelEmergency; i < LevelDebug+1; i++ { for _, v := range f.Separate { @@ -91,7 +85,11 @@ func (f *multiFileLogWriter) Init(jsonConfig string, opts ...utils.KV) error { } func (f *multiFileLogWriter) Format(lm *LogMsg) string { - return lm.Msg + return lm.OldStyleFormat() +} + +func (f *multiFileLogWriter) SetFormatter(fmt LogFormatter) { + f.fullLogWriter.SetFormatter(f) } func (f *multiFileLogWriter) Destroy() { @@ -126,7 +124,8 @@ func (f *multiFileLogWriter) Flush() { // newFilesWriter create a FileLogWriter returning as LoggerInterface. func newFilesWriter() Logger { - return &multiFileLogWriter{} + res := &multiFileLogWriter{} + return res } func init() { diff --git a/pkg/infrastructure/logs/slack.go b/pkg/infrastructure/logs/slack.go index d56b9acd..b6e2f170 100644 --- a/pkg/infrastructure/logs/slack.go +++ b/pkg/infrastructure/logs/slack.go @@ -6,35 +6,46 @@ import ( "net/http" "net/url" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/pkg/errors" ) // SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook type SLACKWriter struct { - WebhookURL string `json:"webhookurl"` - Level int `json:"level"` - UseCustomFormatter bool - CustomFormatter func(*LogMsg) string + WebhookURL string `json:"webhookurl"` + Level int `json:"level"` + formatter LogFormatter + Formatter string `json:"formatter"` } // newSLACKWriter creates jiaoliao writer. func newSLACKWriter() Logger { - return &SLACKWriter{Level: LevelTrace} + res := &SLACKWriter{Level: LevelTrace} + res.formatter = res + return res } func (s *SLACKWriter) Format(lm *LogMsg) string { - return lm.Msg + text := fmt.Sprintf("{\"text\": \"%s %s\"}", lm.When.Format("2006-01-02 15:04:05"), lm.OldStyleFormat()) + return text +} + +func (s *SLACKWriter) SetFormatter(f LogFormatter) { + s.formatter = f } // Init SLACKWriter with json config string -func (s *SLACKWriter) Init(jsonConfig string, opts ...utils.KV) error { - // if elem != nil { - // s.UseCustomFormatter = true - // s.CustomFormatter = elem - // } - // } +func (s *SLACKWriter) Init(config string) error { + res := json.Unmarshal([]byte(config), s) - return json.Unmarshal([]byte(jsonConfig), s) + if res == nil && len(s.Formatter) > 0 { + fmtr, ok := GetFormatter(s.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", s.Formatter)) + } + s.formatter = fmtr + } + + return res } // WriteMsg write message in smtp writer. @@ -44,10 +55,8 @@ func (s *SLACKWriter) WriteMsg(lm *LogMsg) error { return nil } msg := s.Format(lm) - text := fmt.Sprintf("{\"text\": \"%s %s\"}", lm.When.Format("2006-01-02 15:04:05"), msg) - form := url.Values{} - form.Add("payload", text) + form.Add("payload", msg) resp, err := http.PostForm(s.WebhookURL, form) if err != nil { diff --git a/pkg/infrastructure/logs/smtp.go b/pkg/infrastructure/logs/smtp.go index 904a89df..40891a7c 100644 --- a/pkg/infrastructure/logs/smtp.go +++ b/pkg/infrastructure/logs/smtp.go @@ -22,7 +22,7 @@ import ( "net/smtp" "strings" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/pkg/errors" ) // SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server. @@ -34,12 +34,15 @@ type SMTPWriter struct { FromAddress string `json:"fromAddress"` RecipientAddresses []string `json:"sendTos"` Level int `json:"level"` - customFormatter func(*LogMsg) string + formatter LogFormatter + Formatter string `json:"formatter"` } // NewSMTPWriter creates the smtp writer. func newSMTPWriter() Logger { - return &SMTPWriter{Level: LevelTrace} + res := &SMTPWriter{Level: LevelTrace} + res.formatter = res + return res } // Init smtp writer with json config. @@ -53,19 +56,16 @@ func newSMTPWriter() Logger { // "sendTos":["email1","email2"], // "level":LevelError // } -func (s *SMTPWriter) Init(jsonConfig string, opts ...utils.KV) error { - - for _, elem := range opts { - if elem.GetKey() == "formatter" { - formatter, err := GetFormatter(elem) - if err != nil { - return err - } - s.customFormatter = formatter +func (s *SMTPWriter) Init(config string) error { + res := json.Unmarshal([]byte(config), s) + if res == nil && len(s.Formatter) > 0 { + fmtr, ok := GetFormatter(s.Formatter) + if !ok { + return errors.New(fmt.Sprintf("the formatter with name: %s not found", s.Formatter)) } + s.formatter = fmtr } - - return json.Unmarshal([]byte(jsonConfig), s) + return res } func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { @@ -80,6 +80,10 @@ func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { ) } +func (s *SMTPWriter) SetFormatter(f LogFormatter) { + s.formatter = f +} + func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error { client, err := smtp.Dial(hostAddressWithPort) if err != nil { @@ -129,7 +133,7 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd } func (s *SMTPWriter) Format(lm *LogMsg) string { - return lm.Msg + return lm.OldStyleFormat() } // WriteMsg writes message in smtp writer. From b575fa1ebe076bcc21d3bb73434aea26f0043191 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Fri, 11 Sep 2020 23:48:21 +0800 Subject: [PATCH 158/207] fix 4219 --- pkg/server/web/context/input.go | 2 +- pkg/server/web/router_test.go | 17 +++++++++++++++++ pkg/server/web/tree.go | 24 ++++++++++++------------ 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/pkg/server/web/context/input.go b/pkg/server/web/context/input.go index a6fec774..f8657f84 100644 --- a/pkg/server/web/context/input.go +++ b/pkg/server/web/context/input.go @@ -89,7 +89,7 @@ func (input *BeegoInput) URI() string { // URL returns the request url path (without query, string and fragment). func (input *BeegoInput) URL() string { - return input.Context.Request.URL.EscapedPath() + return input.Context.Request.URL.Path } // Site returns the base site url as scheme://domain type. diff --git a/pkg/server/web/router_test.go b/pkg/server/web/router_test.go index 33b75703..2863da3a 100644 --- a/pkg/server/web/router_test.go +++ b/pkg/server/web/router_test.go @@ -212,6 +212,23 @@ func TestAutoExtFunc(t *testing.T) { } } +func TestEscape(t *testing.T) { + + r, _ := http.NewRequest("GET", "/search/%E4%BD%A0%E5%A5%BD", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Get("/search/:keyword(.+)", func(ctx *context.Context) { + value := ctx.Input.Param(":keyword") + ctx.Output.Body([]byte(value)) + }) + handler.ServeHTTP(w, r) + str := w.Body.String() + if str != "你好" { + t.Errorf("incorrect, %s", str) + } +} + func TestRouteOk(t *testing.T) { r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil) diff --git a/pkg/server/web/tree.go b/pkg/server/web/tree.go index 7213a0c6..55f68076 100644 --- a/pkg/server/web/tree.go +++ b/pkg/server/web/tree.go @@ -33,13 +33,13 @@ var ( // wildcard stores params // leaves store the endpoint information type Tree struct { - //prefix set for static router + // prefix set for static router prefix string - //search fix route first + // search fix route first fixrouters []*Tree - //if set, failure to match fixrouters search then search wildcard + // if set, failure to match fixrouters search then search wildcard wildcard *Tree - //if set, failure to match wildcard search + // if set, failure to match wildcard search leaves []*leafInfo } @@ -69,13 +69,13 @@ func (t *Tree) addtree(segments []string, tree *Tree, wildcards []string, reg st filterTreeWithPrefix(tree, wildcards, reg) } } - //Rule: /login/*/access match /login/2009/11/access - //if already has *, and when loop the access, should as a regexpStr + // Rule: /login/*/access match /login/2009/11/access + // if already has *, and when loop the access, should as a regexpStr if !iswild && utils.InSlice(":splat", wildcards) { iswild = true regexpStr = seg } - //Rule: /user/:id/* + // Rule: /user/:id/* if seg == "*" && len(wildcards) > 0 && reg == "" { regexpStr = "(.+)" } @@ -222,13 +222,13 @@ func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, t.addseg(segments[1:], route, wildcards, reg) params = params[1:] } - //Rule: /login/*/access match /login/2009/11/access - //if already has *, and when loop the access, should as a regexpStr + // Rule: /login/*/access match /login/2009/11/access + // if already has *, and when loop the access, should as a regexpStr if !iswild && utils.InSlice(":splat", wildcards) { iswild = true regexpStr = seg } - //Rule: /user/:id/* + // Rule: /user/:id/* if seg == "*" && len(wildcards) > 0 && reg == "" { regexpStr = "(.+)" } @@ -393,7 +393,7 @@ type leafInfo struct { } func (leaf *leafInfo) match(treePattern string, wildcardValues []string, ctx *context.Context) (ok bool) { - //fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) + // fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) if leaf.regexps == nil { if len(wildcardValues) == 0 && len(leaf.wildcards) == 0 { // static path return true @@ -500,7 +500,7 @@ func splitSegment(key string) (bool, []string, string) { continue } if start { - //:id:int and :name:string + // :id:int and :name:string if v == ':' { if len(key) >= i+4 { if key[i+1:i+4] == "int" { From c6c9ad46f945990d170d9c502074f5cb1dd55705 Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Wed, 2 Sep 2020 17:49:38 +0800 Subject: [PATCH 159/207] PostgresQueryBuilder --- pkg/client/orm/orm_test.go | 37 ++++++ pkg/client/orm/qb.go | 2 +- pkg/client/orm/qb_mysql.go | 58 ++++----- pkg/client/orm/qb_postgres.go | 221 ++++++++++++++++++++++++++++++++++ pkg/client/orm/qb_tidb.go | 165 +------------------------ 5 files changed, 291 insertions(+), 192 deletions(-) create mode 100644 pkg/client/orm/qb_postgres.go diff --git a/pkg/client/orm/orm_test.go b/pkg/client/orm/orm_test.go index 8c4bf55d..6a480d8c 100644 --- a/pkg/client/orm/orm_test.go +++ b/pkg/client/orm/orm_test.go @@ -2620,3 +2620,40 @@ func TestStrPkInsert(t *testing.T) { throwFailNow(t, AssertIs(vForTesting2.Value, value2)) } } + +func TestPSQueryBuilder(t *testing.T) { + // only test postgres + if dORM.Driver().Type() != 4 { + return + } + + var user User + var l []userProfile + o := NewOrm() + + qb, err := NewQueryBuilder("postgres") + if err != nil { + throwFailNow(t, err) + } + qb.Select("user.id", "user.user_name"). + From("user").Where("id = ?").OrderBy("user_name"). + Desc().Limit(1).Offset(0) + sql := qb.String() + err = o.Raw(sql, 2).QueryRow(&user) + if err != nil { + throwFailNow(t, err) + } + throwFail(t, AssertIs(user.UserName, "slene")) + + qb.Select("*"). + From("user_profile").InnerJoin("user"). + On("user_profile.id = user.id") + sql = qb.String() + num, err := o.Raw(sql).QueryRows(&l) + if err != nil { + throwFailNow(t, err) + } + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(l[0].UserName, "astaxie")) + throwFailNow(t, AssertIs(l[0].Age, 30)) +} diff --git a/pkg/client/orm/qb.go b/pkg/client/orm/qb.go index e0655a17..c82d2255 100644 --- a/pkg/client/orm/qb.go +++ b/pkg/client/orm/qb.go @@ -52,7 +52,7 @@ func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { } else if driver == "tidb" { qb = new(TiDBQueryBuilder) } else if driver == "postgres" { - err = errors.New("postgres query builder is not supported yet") + qb = new(PostgresQueryBuilder) } else if driver == "sqlite" { err = errors.New("sqlite query builder is not supported yet") } else { diff --git a/pkg/client/orm/qb_mysql.go b/pkg/client/orm/qb_mysql.go index 23bdc9ee..19130496 100644 --- a/pkg/client/orm/qb_mysql.go +++ b/pkg/client/orm/qb_mysql.go @@ -25,144 +25,144 @@ const CommaSpace = ", " // MySQLQueryBuilder is the SQL build type MySQLQueryBuilder struct { - Tokens []string + tokens []string } // Select will join the fields func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) + qb.tokens = append(qb.tokens, "SELECT", strings.Join(fields, CommaSpace)) return qb } // ForUpdate add the FOR UPDATE clause func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder { - qb.Tokens = append(qb.Tokens, "FOR UPDATE") + qb.tokens = append(qb.tokens, "FOR UPDATE") return qb } // From join the tables func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) + qb.tokens = append(qb.tokens, "FROM", strings.Join(tables, CommaSpace)) return qb } // InnerJoin INNER JOIN the table func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INNER JOIN", table) + qb.tokens = append(qb.tokens, "INNER JOIN", table) return qb } // LeftJoin LEFT JOIN the table func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) + qb.tokens = append(qb.tokens, "LEFT JOIN", table) return qb } // RightJoin RIGHT JOIN the table func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) + qb.tokens = append(qb.tokens, "RIGHT JOIN", table) return qb } // On join with on cond func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ON", cond) + qb.tokens = append(qb.tokens, "ON", cond) return qb } // Where join the Where cond func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "WHERE", cond) + qb.tokens = append(qb.tokens, "WHERE", cond) return qb } // And join the and cond func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "AND", cond) + qb.tokens = append(qb.tokens, "AND", cond) return qb } // Or join the or cond func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OR", cond) + qb.tokens = append(qb.tokens, "OR", cond) return qb } // In join the IN (vals) func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") + qb.tokens = append(qb.tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") return qb } // OrderBy join the Order by fields func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) + qb.tokens = append(qb.tokens, "ORDER BY", strings.Join(fields, CommaSpace)) return qb } // Asc join the asc func (qb *MySQLQueryBuilder) Asc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "ASC") + qb.tokens = append(qb.tokens, "ASC") return qb } // Desc join the desc func (qb *MySQLQueryBuilder) Desc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "DESC") + qb.tokens = append(qb.tokens, "DESC") return qb } // Limit join the limit num func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) + qb.tokens = append(qb.tokens, "LIMIT", strconv.Itoa(limit)) return qb } // Offset join the offset num func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) + qb.tokens = append(qb.tokens, "OFFSET", strconv.Itoa(offset)) return qb } // GroupBy join the Group by fields func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) + qb.tokens = append(qb.tokens, "GROUP BY", strings.Join(fields, CommaSpace)) return qb } // Having join the Having cond func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "HAVING", cond) + qb.tokens = append(qb.tokens, "HAVING", cond) return qb } // Update join the update table func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) + qb.tokens = append(qb.tokens, "UPDATE", strings.Join(tables, CommaSpace)) return qb } // Set join the set kv func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) + qb.tokens = append(qb.tokens, "SET", strings.Join(kv, CommaSpace)) return qb } // Delete join the Delete tables func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "DELETE") + qb.tokens = append(qb.tokens, "DELETE") if len(tables) != 0 { - qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) + qb.tokens = append(qb.tokens, strings.Join(tables, CommaSpace)) } return qb } // InsertInto join the insert SQL func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INSERT INTO", table) + qb.tokens = append(qb.tokens, "INSERT INTO", table) if len(fields) != 0 { fieldsStr := strings.Join(fields, CommaSpace) - qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") + qb.tokens = append(qb.tokens, "(", fieldsStr, ")") } return qb } @@ -170,7 +170,7 @@ func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBui // Values join the Values(vals) func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { valsStr := strings.Join(vals, CommaSpace) - qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") + qb.tokens = append(qb.tokens, "VALUES", "(", valsStr, ")") return qb } @@ -179,7 +179,9 @@ func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { return fmt.Sprintf("(%s) AS %s", sub, alias) } -// String join all Tokens +// String join all tokens func (qb *MySQLQueryBuilder) String() string { - return strings.Join(qb.Tokens, " ") + s := strings.Join(qb.tokens, " ") + qb.tokens = qb.tokens[:0] + return s } diff --git a/pkg/client/orm/qb_postgres.go b/pkg/client/orm/qb_postgres.go new file mode 100644 index 00000000..eec784df --- /dev/null +++ b/pkg/client/orm/qb_postgres.go @@ -0,0 +1,221 @@ +package orm + +import ( + "fmt" + "strconv" + "strings" +) + +var quote string = `"` + +// PostgresQueryBuilder is the SQL build +type PostgresQueryBuilder struct { + tokens []string +} + +func processingStr(str []string) string { + s := strings.Join(str, `","`) + s = fmt.Sprintf("%s%s%s", quote, s, quote) + return s +} + +// Select will join the fields +func (qb *PostgresQueryBuilder) Select(fields ...string) QueryBuilder { + + var str string + n := len(fields) + + if fields[0] == "*" { + str = "*" + } else { + for i := 0; i < n; i++ { + sli := strings.Split(fields[i], ".") + s := strings.Join(sli, `"."`) + s = fmt.Sprintf("%s%s%s", quote, s, quote) + if n == 1 || i == n-1 { + str += s + } else { + str += s + "," + } + } + } + + qb.tokens = append(qb.tokens, "SELECT", str) + return qb +} + +// ForUpdate add the FOR UPDATE clause +func (qb *PostgresQueryBuilder) ForUpdate() QueryBuilder { + qb.tokens = append(qb.tokens, "FOR UPDATE") + return qb +} + +// From join the tables +func (qb *PostgresQueryBuilder) From(tables ...string) QueryBuilder { + str := processingStr(tables) + qb.tokens = append(qb.tokens, "FROM", str) + return qb +} + +// InnerJoin INNER JOIN the table +func (qb *PostgresQueryBuilder) InnerJoin(table string) QueryBuilder { + str := fmt.Sprintf("%s%s%s", quote, table, quote) + qb.tokens = append(qb.tokens, "INNER JOIN", str) + return qb +} + +// LeftJoin LEFT JOIN the table +func (qb *PostgresQueryBuilder) LeftJoin(table string) QueryBuilder { + str := fmt.Sprintf("%s%s%s", quote, table, quote) + qb.tokens = append(qb.tokens, "LEFT JOIN", str) + return qb +} + +// RightJoin RIGHT JOIN the table +func (qb *PostgresQueryBuilder) RightJoin(table string) QueryBuilder { + str := fmt.Sprintf("%s%s%s", quote, table, quote) + qb.tokens = append(qb.tokens, "RIGHT JOIN", str) + return qb +} + +// On join with on cond +func (qb *PostgresQueryBuilder) On(cond string) QueryBuilder { + + var str string + cond = strings.Replace(cond, " ", "", -1) + slice := strings.Split(cond, "=") + for i := 0; i < len(slice); i++ { + sli := strings.Split(slice[i], ".") + s := strings.Join(sli, `"."`) + s = fmt.Sprintf("%s%s%s", quote, s, quote) + if i == 0 { + str = s + " =" + " " + } else { + str += s + } + } + + qb.tokens = append(qb.tokens, "ON", str) + return qb +} + +// Where join the Where cond +func (qb *PostgresQueryBuilder) Where(cond string) QueryBuilder { + qb.tokens = append(qb.tokens, "WHERE", cond) + return qb +} + +// And join the and cond +func (qb *PostgresQueryBuilder) And(cond string) QueryBuilder { + qb.tokens = append(qb.tokens, "AND", cond) + return qb +} + +// Or join the or cond +func (qb *PostgresQueryBuilder) Or(cond string) QueryBuilder { + qb.tokens = append(qb.tokens, "OR", cond) + return qb +} + +// In join the IN (vals) +func (qb *PostgresQueryBuilder) In(vals ...string) QueryBuilder { + qb.tokens = append(qb.tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") + return qb +} + +// OrderBy join the Order by fields +func (qb *PostgresQueryBuilder) OrderBy(fields ...string) QueryBuilder { + str := processingStr(fields) + qb.tokens = append(qb.tokens, "ORDER BY", str) + return qb +} + +// Asc join the asc +func (qb *PostgresQueryBuilder) Asc() QueryBuilder { + qb.tokens = append(qb.tokens, "ASC") + return qb +} + +// Desc join the desc +func (qb *PostgresQueryBuilder) Desc() QueryBuilder { + qb.tokens = append(qb.tokens, "DESC") + return qb +} + +// Limit join the limit num +func (qb *PostgresQueryBuilder) Limit(limit int) QueryBuilder { + qb.tokens = append(qb.tokens, "LIMIT", strconv.Itoa(limit)) + return qb +} + +// Offset join the offset num +func (qb *PostgresQueryBuilder) Offset(offset int) QueryBuilder { + qb.tokens = append(qb.tokens, "OFFSET", strconv.Itoa(offset)) + return qb +} + +// GroupBy join the Group by fields +func (qb *PostgresQueryBuilder) GroupBy(fields ...string) QueryBuilder { + str := processingStr(fields) + qb.tokens = append(qb.tokens, "GROUP BY", str) + return qb +} + +// Having join the Having cond +func (qb *PostgresQueryBuilder) Having(cond string) QueryBuilder { + qb.tokens = append(qb.tokens, "HAVING", cond) + return qb +} + +// Update join the update table +func (qb *PostgresQueryBuilder) Update(tables ...string) QueryBuilder { + str := processingStr(tables) + qb.tokens = append(qb.tokens, "UPDATE", str) + return qb +} + +// Set join the set kv +func (qb *PostgresQueryBuilder) Set(kv ...string) QueryBuilder { + qb.tokens = append(qb.tokens, "SET", strings.Join(kv, CommaSpace)) + return qb +} + +// Delete join the Delete tables +func (qb *PostgresQueryBuilder) Delete(tables ...string) QueryBuilder { + qb.tokens = append(qb.tokens, "DELETE") + if len(tables) != 0 { + str := processingStr(tables) + qb.tokens = append(qb.tokens, str) + } + return qb +} + +// InsertInto join the insert SQL +func (qb *PostgresQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + str := fmt.Sprintf("%s%s%s", quote, table, quote) + qb.tokens = append(qb.tokens, "INSERT INTO", str) + if len(fields) != 0 { + fieldsStr := strings.Join(fields, CommaSpace) + qb.tokens = append(qb.tokens, "(", fieldsStr, ")") + } + return qb +} + +// Values join the Values(vals) +func (qb *PostgresQueryBuilder) Values(vals ...string) QueryBuilder { + valsStr := strings.Join(vals, CommaSpace) + qb.tokens = append(qb.tokens, "VALUES", "(", valsStr, ")") + return qb +} + +// Subquery join the sub as alias +func (qb *PostgresQueryBuilder) Subquery(sub string, alias string) string { + return fmt.Sprintf("(%s) AS %s", sub, alias) +} + +// String join all tokens +func (qb *PostgresQueryBuilder) String() string { + s := strings.Join(qb.tokens, " ") + qb.tokens = qb.tokens[:0] + return s +} diff --git a/pkg/client/orm/qb_tidb.go b/pkg/client/orm/qb_tidb.go index 87b3ae84..772edb5d 100644 --- a/pkg/client/orm/qb_tidb.go +++ b/pkg/client/orm/qb_tidb.go @@ -14,169 +14,8 @@ package orm -import ( - "fmt" - "strconv" - "strings" -) - // TiDBQueryBuilder is the SQL build type TiDBQueryBuilder struct { - Tokens []string -} - -// Select will join the fields -func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) - return qb -} - -// ForUpdate add the FOR UPDATE clause -func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder { - qb.Tokens = append(qb.Tokens, "FOR UPDATE") - return qb -} - -// From join the tables -func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) - return qb -} - -// InnerJoin INNER JOIN the table -func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INNER JOIN", table) - return qb -} - -// LeftJoin LEFT JOIN the table -func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) - return qb -} - -// RightJoin RIGHT JOIN the table -func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) - return qb -} - -// On join with on cond -func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ON", cond) - return qb -} - -// Where join the Where cond -func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "WHERE", cond) - return qb -} - -// And join the and cond -func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "AND", cond) - return qb -} - -// Or join the or cond -func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OR", cond) - return qb -} - -// In join the IN (vals) -func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") - return qb -} - -// OrderBy join the Order by fields -func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) - return qb -} - -// Asc join the asc -func (qb *TiDBQueryBuilder) Asc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "ASC") - return qb -} - -// Desc join the desc -func (qb *TiDBQueryBuilder) Desc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "DESC") - return qb -} - -// Limit join the limit num -func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) - return qb -} - -// Offset join the offset num -func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) - return qb -} - -// GroupBy join the Group by fields -func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) - return qb -} - -// Having join the Having cond -func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "HAVING", cond) - return qb -} - -// Update join the update table -func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) - return qb -} - -// Set join the set kv -func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) - return qb -} - -// Delete join the Delete tables -func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "DELETE") - if len(tables) != 0 { - qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) - } - return qb -} - -// InsertInto join the insert SQL -func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INSERT INTO", table) - if len(fields) != 0 { - fieldsStr := strings.Join(fields, CommaSpace) - qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") - } - return qb -} - -// Values join the Values(vals) -func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder { - valsStr := strings.Join(vals, CommaSpace) - qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") - return qb -} - -// Subquery join the sub as alias -func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string { - return fmt.Sprintf("(%s) AS %s", sub, alias) -} - -// String join all Tokens -func (qb *TiDBQueryBuilder) String() string { - return strings.Join(qb.Tokens, " ") + MySQLQueryBuilder + tokens []string } From 5618df8c76dd2df016c8cc19bfb6969cc1f0f452 Mon Sep 17 00:00:00 2001 From: l00427301 Date: Mon, 14 Sep 2020 16:18:24 +0800 Subject: [PATCH 160/207] Empty field in validator.Error when label struct tag is not declared #4222 --- pkg/infrastructure/validation/validation.go | 5 ++++ .../validation/validation_test.go | 25 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/pkg/infrastructure/validation/validation.go b/pkg/infrastructure/validation/validation.go index 190e0f0e..134e750e 100644 --- a/pkg/infrastructure/validation/validation.go +++ b/pkg/infrastructure/validation/validation.go @@ -269,6 +269,11 @@ func (v *Validation) apply(chk Validator, obj interface{}) *Result { Field := "" Label := "" parts := strings.Split(key, ".") + if len(parts) == 2 { + Field = parts[0] + Name = parts[1] + Label = Field + } if len(parts) == 3 { Field = parts[0] Name = parts[1] diff --git a/pkg/infrastructure/validation/validation_test.go b/pkg/infrastructure/validation/validation_test.go index b4b5b1b6..bca4f560 100644 --- a/pkg/infrastructure/validation/validation_test.go +++ b/pkg/infrastructure/validation/validation_test.go @@ -607,3 +607,28 @@ func TestCanSkipAlso(t *testing.T) { } } + +func TestFieldNoEmpty(t *testing.T) { + type User struct { + Name string `json:"name" valid:"Match(/^[a-zA-Z][a-zA-Z0-9._-]{0,31}$/)"` + } + u := User{ + Name: "*", + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should be passed") + } + if len(valid.Errors) == 0 { + t.Fatal("validation should be passed") + } + validErr := valid.Errors[0] + if len(validErr.Field) == 0 { + t.Fatal("validation should be passed") + } +} From b7bc57c4d155f4dc41ebae59ecafc1f4a89f8021 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Wed, 16 Sep 2020 19:46:14 +0800 Subject: [PATCH 161/207] delete interface --- pkg/client/orm/do_nothing_orm.go | 12 ------------ pkg/client/orm/filter_orm_decorator.go | 13 ------------- pkg/client/orm/models.go | 23 ----------------------- pkg/client/orm/orm.go | 15 --------------- pkg/client/orm/types.go | 1 - 5 files changed, 64 deletions(-) diff --git a/pkg/client/orm/do_nothing_orm.go b/pkg/client/orm/do_nothing_orm.go index 07c7fd74..e27e7f3a 100644 --- a/pkg/client/orm/do_nothing_orm.go +++ b/pkg/client/orm/do_nothing_orm.go @@ -30,18 +30,6 @@ var _ Ormer = new(DoNothingOrm) type DoNothingOrm struct { } -func (d *DoNothingOrm) RegisterModels(models ...interface{}) (err error) { - return nil -} - -func (d *DoNothingOrm) RegisterModelsWithPrefix(prefix string, models ...interface{}) (err error) { - return nil -} - -func (d *DoNothingOrm) RegisterModelsWithSuffix(suffix string, models ...interface{}) (err error) { - return nil -} - func (d *DoNothingOrm) Read(md interface{}, cols ...string) error { return nil } diff --git a/pkg/client/orm/filter_orm_decorator.go b/pkg/client/orm/filter_orm_decorator.go index 3271c520..d0c5c537 100644 --- a/pkg/client/orm/filter_orm_decorator.go +++ b/pkg/client/orm/filter_orm_decorator.go @@ -17,7 +17,6 @@ package orm import ( "context" "database/sql" - "errors" "reflect" "time" @@ -43,18 +42,6 @@ type filterOrmDecorator struct { txName string } -func (f *filterOrmDecorator) RegisterModels(models ...interface{}) (err error) { - return errors.New(`not callable`) -} - -func (f *filterOrmDecorator) RegisterModelsWithPrefix(prefix string, models ...interface{}) (err error) { - return errors.New(`not callable`) -} - -func (f *filterOrmDecorator) RegisterModelsWithSuffix(suffix string, models ...interface{}) (err error) { - return errors.New(`not callable`) -} - func NewFilterOrmDecorator(delegate Ormer, filterChains ...FilterChain) Ormer { res := &filterOrmDecorator{ ormer: delegate, diff --git a/pkg/client/orm/models.go b/pkg/client/orm/models.go index ea315c2b..19941d2e 100644 --- a/pkg/client/orm/models.go +++ b/pkg/client/orm/models.go @@ -36,15 +36,6 @@ var ( modelCache = NewModelCacheHandler() ) -type modelCacheHandler interface { - //RegisterModels register models without prefix or suffix - RegisterModels(models ...interface{}) (err error) - //RegisterModelsWithPrefix register models with prefix - RegisterModelsWithPrefix(prefix string, models ...interface{}) (err error) - //RegisterModelsWithSuffix register models with suffix - RegisterModelsWithSuffix(suffix string, models ...interface{}) (err error) -} - // model info collection type _modelCache struct { sync.RWMutex // only used outsite for bootStrap @@ -62,20 +53,6 @@ func NewModelCacheHandler() *_modelCache { } } -var _ modelCacheHandler = new(_modelCache) - -func (mc *_modelCache) RegisterModels(models ...interface{}) (err error) { - return mc.register(``, true, models...) -} - -func (mc *_modelCache) RegisterModelsWithPrefix(prefix string, models ...interface{}) (err error) { - return mc.register(prefix, true, models...) -} - -func (mc *_modelCache) RegisterModelsWithSuffix(suffix string, models ...interface{}) (err error) { - return mc.register(suffix, false, models...) -} - // get all model info func (mc *_modelCache) all() map[string]*modelInfo { m := make(map[string]*modelInfo, len(mc.cache)) diff --git a/pkg/client/orm/orm.go b/pkg/client/orm/orm.go index 557c788c..bfb710d1 100644 --- a/pkg/client/orm/orm.go +++ b/pkg/client/orm/orm.go @@ -496,23 +496,10 @@ func (o *ormBase) DBStats() *sql.DBStats { type orm struct { ormBase - modelCacheHandler } var _ Ormer = new(orm) -func (o *orm) RegisterModels(models ...interface{}) (err error) { - return o.modelCacheHandler.RegisterModels(models) -} - -func (o *orm) RegisterModelsWithPrefix(prefix string, models ...interface{}) (err error) { - return o.modelCacheHandler.RegisterModelsWithPrefix(prefix, models...) -} - -func (o *orm) RegisterModelsWithSuffix(suffix string, models ...interface{}) (err error) { - return o.modelCacheHandler.RegisterModelsWithSuffix(suffix, models...) -} - func (o *orm) Begin() (TxOrmer, error) { return o.BeginWithCtx(context.Background()) } @@ -633,8 +620,6 @@ func newDBWithAlias(al *alias) Ormer { o.db = al.DB } - o.modelCacheHandler = NewModelCacheHandler() - if len(globalFilterChains) > 0 { return NewFilterOrmDecorator(o, globalFilterChains...) } diff --git a/pkg/client/orm/types.go b/pkg/client/orm/types.go index b9c444eb..b0c793b7 100644 --- a/pkg/client/orm/types.go +++ b/pkg/client/orm/types.go @@ -219,7 +219,6 @@ type ormer interface { type Ormer interface { ormer - modelCacheHandler TxBeginner } From 6e638ef6c8c42b99060d1a2ab8f3b10cf22afa09 Mon Sep 17 00:00:00 2001 From: wangle <285273592@qq.com> Date: Sat, 19 Sep 2020 18:28:53 +0800 Subject: [PATCH 162/207] Provides a quick format method by PatternLogFormatter struct --- pkg/infrastructure/logs/formatter.go | 55 +++++++++++++++++++++++ pkg/infrastructure/logs/formatter_test.go | 28 ++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 pkg/infrastructure/logs/formatter_test.go diff --git a/pkg/infrastructure/logs/formatter.go b/pkg/infrastructure/logs/formatter.go index b2599f2d..981ecdbf 100644 --- a/pkg/infrastructure/logs/formatter.go +++ b/pkg/infrastructure/logs/formatter.go @@ -14,12 +14,39 @@ package logs +import ( + "path" + "strconv" +) + var formatterMap = make(map[string]LogFormatter, 4) type LogFormatter interface { Format(lm *LogMsg) string } +// PatternLogFormatter provides a quick format method +// for example: +// tes := PatternLogFormatter{Pattern: "%F:%n|%w %t>> %m", WhenFormat: "2006-01-02"} +// RegisterFormatter("tes", tes) +// SetGlobalFormatter("tes") +type PatternLogFormatter struct { + Pattern string + WhenFormat string +} + +func (p PatternLogFormatter) getWhenFormatter() string { + s := p.WhenFormat + if s == "" { + s = "2006/01/02 15:04:05.123" // default style + } + return s +} + +func (p PatternLogFormatter) Format(lm *LogMsg) string { + return p.ToString(lm) +} + // RegisterFormatter register an formatter. Usually you should use this to extend your custom formatter // for example: // RegisterFormatter("my-fmt", &MyFormatter{}) @@ -32,3 +59,31 @@ func GetFormatter(name string) (LogFormatter, bool) { res, ok := formatterMap[name] return res, ok } + +// 'w' when, 'm' msg,'f' filename,'F' full path,'n' line number +// 'l' level number, 't' prefix of level type, 'T' full name of level type +func (p PatternLogFormatter) ToString(lm *LogMsg) string { + s := []rune(p.Pattern) + m := map[rune]string{ + 'w': lm.When.Format(p.getWhenFormatter()), + 'm': lm.Msg, + 'n': strconv.Itoa(lm.LineNumber), + 'l': strconv.Itoa(lm.Level), + 't': levelPrefix[lm.Level-1], + 'T': levelNames[lm.Level-1], + 'F': lm.FilePath, + } + _, m['f'] = path.Split(lm.FilePath) + res := "" + for i := 0; i < len(s)-1; i++ { + if s[i] == '%' { + if k, ok := m[s[i+1]]; ok { + res += k + i++ + continue + } + } + res += string(s[i]) + } + return res +} diff --git a/pkg/infrastructure/logs/formatter_test.go b/pkg/infrastructure/logs/formatter_test.go new file mode 100644 index 00000000..6ae94a6a --- /dev/null +++ b/pkg/infrastructure/logs/formatter_test.go @@ -0,0 +1,28 @@ +package logs + +import ( + "strconv" + "testing" + "time" +) + +func TestPatternLogFormatter(t *testing.T) { + tes := PatternLogFormatter{ + Pattern: "%F:%n|%w%t>> %m", + WhenFormat: "2006-01-02", + } + when := time.Now() + lm := &LogMsg{ + Msg: "message", + FilePath: "/User/go/beego/main.go", + Level: LevelWarn, + LineNumber: 10, + When: when, + } + got := tes.ToString(lm) + want := lm.FilePath + ":" + strconv.Itoa(lm.LineNumber) + "|" + + when.Format(tes.WhenFormat) + levelPrefix[lm.Level-1] + ">> " + lm.Msg + if got != want { + t.Errorf("want %s, got %s", want, got) + } +} From 05c125ec2d4cb78aced77c16394dd2e2d87f802f Mon Sep 17 00:00:00 2001 From: wangle <285273592@qq.com> Date: Sat, 19 Sep 2020 20:18:09 +0800 Subject: [PATCH 163/207] change to pointer receiver --- pkg/infrastructure/logs/formatter.go | 8 ++++---- pkg/infrastructure/logs/formatter_test.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/infrastructure/logs/formatter.go b/pkg/infrastructure/logs/formatter.go index 981ecdbf..67500b2b 100644 --- a/pkg/infrastructure/logs/formatter.go +++ b/pkg/infrastructure/logs/formatter.go @@ -27,7 +27,7 @@ type LogFormatter interface { // PatternLogFormatter provides a quick format method // for example: -// tes := PatternLogFormatter{Pattern: "%F:%n|%w %t>> %m", WhenFormat: "2006-01-02"} +// tes := &PatternLogFormatter{Pattern: "%F:%n|%w %t>> %m", WhenFormat: "2006-01-02"} // RegisterFormatter("tes", tes) // SetGlobalFormatter("tes") type PatternLogFormatter struct { @@ -35,7 +35,7 @@ type PatternLogFormatter struct { WhenFormat string } -func (p PatternLogFormatter) getWhenFormatter() string { +func (p *PatternLogFormatter) getWhenFormatter() string { s := p.WhenFormat if s == "" { s = "2006/01/02 15:04:05.123" // default style @@ -43,7 +43,7 @@ func (p PatternLogFormatter) getWhenFormatter() string { return s } -func (p PatternLogFormatter) Format(lm *LogMsg) string { +func (p *PatternLogFormatter) Format(lm *LogMsg) string { return p.ToString(lm) } @@ -62,7 +62,7 @@ func GetFormatter(name string) (LogFormatter, bool) { // 'w' when, 'm' msg,'f' filename,'F' full path,'n' line number // 'l' level number, 't' prefix of level type, 'T' full name of level type -func (p PatternLogFormatter) ToString(lm *LogMsg) string { +func (p *PatternLogFormatter) ToString(lm *LogMsg) string { s := []rune(p.Pattern) m := map[rune]string{ 'w': lm.When.Format(p.getWhenFormatter()), diff --git a/pkg/infrastructure/logs/formatter_test.go b/pkg/infrastructure/logs/formatter_test.go index 6ae94a6a..f9320b8d 100644 --- a/pkg/infrastructure/logs/formatter_test.go +++ b/pkg/infrastructure/logs/formatter_test.go @@ -7,7 +7,7 @@ import ( ) func TestPatternLogFormatter(t *testing.T) { - tes := PatternLogFormatter{ + tes := &PatternLogFormatter{ Pattern: "%F:%n|%w%t>> %m", WhenFormat: "2006-01-02", } From 67f64afa8500098ff074c58b259a0bcfb144e7b5 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sat, 19 Sep 2020 21:45:37 +0800 Subject: [PATCH 164/207] movement for 4198 --- pkg/client/orm/db.go | 9 +++++---- pkg/client/orm/db_mysql.go | 9 +++++---- pkg/client/orm/orm_test.go | 11 +++++++++++ 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/pkg/client/orm/db.go b/pkg/client/orm/db.go index 820435ca..2bd1308f 100644 --- a/pkg/client/orm/db.go +++ b/pkg/client/orm/db.go @@ -38,10 +38,11 @@ var ( var ( operators = map[string]bool{ - "exact": true, - "iexact": true, - "contains": true, - "icontains": true, + "exact": true, + "iexact": true, + "strictexact": true, + "contains": true, + "icontains": true, // "regex": true, // "iregex": true, "gt": true, diff --git a/pkg/client/orm/db_mysql.go b/pkg/client/orm/db_mysql.go index f602fd0a..f674ab2b 100644 --- a/pkg/client/orm/db_mysql.go +++ b/pkg/client/orm/db_mysql.go @@ -22,10 +22,11 @@ import ( // mysql operators. var mysqlOperators = map[string]string{ - "exact": "= ?", - "iexact": "LIKE ?", - "contains": "LIKE BINARY ?", - "icontains": "LIKE ?", + "exact": "= ?", + "iexact": "LIKE ?", + "strictexact": "= BINARY ?", + "contains": "LIKE BINARY ?", + "icontains": "LIKE ?", // "regex": "REGEXP BINARY ?", // "iregex": "REGEXP ?", "gt": "> ?", diff --git a/pkg/client/orm/orm_test.go b/pkg/client/orm/orm_test.go index 6a480d8c..bd92f46f 100644 --- a/pkg/client/orm/orm_test.go +++ b/pkg/client/orm/orm_test.go @@ -877,6 +877,17 @@ func TestOperators(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 1)) + if IsMysql { + // Now only mysql support `strictexact` + num, err = qs.Filter("user_name__strictexact", "Slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + num, err = qs.Filter("user_name__strictexact", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + } + num, err = qs.Filter("user_name__contains", "e").Count() throwFail(t, err) throwFail(t, AssertIs(num, 2)) From a1782cc22dde81e8d8afbb5391fe908699a9b656 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 19 Sep 2020 23:04:40 +0800 Subject: [PATCH 165/207] Add tests for log module --- .../logs/{accesslog.go => access_log.go} | 19 +++-- pkg/infrastructure/logs/access_log_test.go | 38 ++++++++++ pkg/infrastructure/logs/conn_test.go | 18 +++++ pkg/infrastructure/logs/console.go | 6 +- pkg/infrastructure/logs/console_test.go | 18 +++++ pkg/infrastructure/logs/file_test.go | 17 +++++ pkg/infrastructure/logs/formatter_test.go | 69 ++++++++++++++++++- pkg/infrastructure/logs/jianliao.go | 4 +- pkg/infrastructure/logs/jianliao_test.go | 65 +++++++++++++++++ pkg/infrastructure/logs/log.go | 34 --------- pkg/infrastructure/logs/log_msg.go | 55 +++++++++++++++ pkg/infrastructure/logs/log_msg_test.go | 44 ++++++++++++ pkg/infrastructure/logs/log_test.go | 27 ++++++++ 13 files changed, 370 insertions(+), 44 deletions(-) rename pkg/infrastructure/logs/{accesslog.go => access_log.go} (95%) create mode 100644 pkg/infrastructure/logs/access_log_test.go create mode 100644 pkg/infrastructure/logs/jianliao_test.go create mode 100644 pkg/infrastructure/logs/log_msg.go create mode 100644 pkg/infrastructure/logs/log_msg_test.go create mode 100644 pkg/infrastructure/logs/log_test.go diff --git a/pkg/infrastructure/logs/accesslog.go b/pkg/infrastructure/logs/access_log.go similarity index 95% rename from pkg/infrastructure/logs/accesslog.go rename to pkg/infrastructure/logs/access_log.go index 1be711d8..10455fe9 100644 --- a/pkg/infrastructure/logs/accesslog.go +++ b/pkg/infrastructure/logs/access_log.go @@ -63,7 +63,17 @@ func disableEscapeHTML(i interface{}) { // AccessLog - Format and print access log. func AccessLog(r *AccessLogRecord, format string) { - var msg string + msg := r.format(format) + lm := &LogMsg{ + Msg: strings.TrimSpace(msg), + When: time.Now(), + Level: levelLoggerImpl, + } + beeLogger.writeMsg(lm) +} + +func (r *AccessLogRecord) format(format string) string { + msg := "" switch format { case apacheFormat: timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05") @@ -79,10 +89,5 @@ func AccessLog(r *AccessLogRecord, format string) { msg = string(jsonData) } } - lm := &LogMsg{ - Msg: strings.TrimSpace(msg), - When: time.Now(), - Level: levelLoggerImpl, - } - beeLogger.writeMsg(lm) + return msg } diff --git a/pkg/infrastructure/logs/access_log_test.go b/pkg/infrastructure/logs/access_log_test.go new file mode 100644 index 00000000..f78a00a0 --- /dev/null +++ b/pkg/infrastructure/logs/access_log_test.go @@ -0,0 +1,38 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestAccessLog_format(t *testing.T) { + alc := &AccessLogRecord{ + RequestTime: time.Date(2020, 9, 19, 21, 21, 21, 11, time.UTC), + } + + res := alc.format(apacheFormat) + println(res) + assert.Equal(t, " - - [19/Sep/2020 09:21:21] \" 0 0\" 0.000000 ", res) + + res = alc.format(jsonFormat) + assert.Equal(t, + "{\"remote_addr\":\"\",\"request_time\":\"2020-09-19T21:21:21.000000011Z\",\"request_method\":\"\",\"request\":\"\",\"server_protocol\":\"\",\"host\":\"\",\"status\":0,\"body_bytes_sent\":0,\"elapsed_time\":0,\"http_referrer\":\"\",\"http_user_agent\":\"\",\"remote_user\":\"\"}\n", res) + + AccessLog(alc, jsonFormat) +} diff --git a/pkg/infrastructure/logs/conn_test.go b/pkg/infrastructure/logs/conn_test.go index bb377d41..ca9ea1c7 100644 --- a/pkg/infrastructure/logs/conn_test.go +++ b/pkg/infrastructure/logs/conn_test.go @@ -18,6 +18,9 @@ import ( "net" "os" "testing" + "time" + + "github.com/stretchr/testify/assert" ) // ConnTCPListener takes a TCP listener and accepts n TCP connections @@ -45,6 +48,7 @@ func TestConn(t *testing.T) { log.Informational("informational") } +// need to rewrite this test, it's not stable func TestReconnect(t *testing.T) { // Setup connection listener newConns := make(chan net.Conn) @@ -77,3 +81,17 @@ func TestReconnect(t *testing.T) { t.Error("Did not reconnect") } } + +func TestConnWriter_Format(t *testing.T) { + lg := &LogMsg{ + Level: LevelDebug, + Msg: "Hello, world", + When: time.Date(2020, 9, 19, 20, 12, 37, 9, time.UTC), + FilePath: "/user/home/main.go", + LineNumber: 13, + Prefix: "Cus", + } + cw := NewConn().(*connWriter) + res := cw.Format(lg) + assert.Equal(t, "[D] Cus Hello, world", res) +} diff --git a/pkg/infrastructure/logs/console.go b/pkg/infrastructure/logs/console.go index f99ef11b..66e2c7ea 100644 --- a/pkg/infrastructure/logs/console.go +++ b/pkg/infrastructure/logs/console.go @@ -59,7 +59,7 @@ type consoleWriter struct { func (c *consoleWriter) Format(lm *LogMsg) string { msg := lm.OldStyleFormat() if c.Colorful { - msg = strings.Replace(lm.Msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) + msg = strings.Replace(msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) } h, _, _ := formatTimeHeader(lm.When) bytes := append(append(h, msg...), '\n') @@ -72,6 +72,10 @@ func (c *consoleWriter) SetFormatter(f LogFormatter) { // NewConsole creates ConsoleWriter returning as LoggerInterface. func NewConsole() Logger { + return newConsole() +} + +func newConsole() *consoleWriter { cw := &consoleWriter{ lg: newLogWriter(ansicolor.NewAnsiColorWriter(os.Stdout)), Level: LevelDebug, diff --git a/pkg/infrastructure/logs/console_test.go b/pkg/infrastructure/logs/console_test.go index 4bc45f57..e345ba40 100644 --- a/pkg/infrastructure/logs/console_test.go +++ b/pkg/infrastructure/logs/console_test.go @@ -17,6 +17,8 @@ package logs import ( "testing" "time" + + "github.com/stretchr/testify/assert" ) // Try each log level in decreasing order of priority. @@ -62,3 +64,19 @@ func TestConsoleAsync(t *testing.T) { time.Sleep(1 * time.Millisecond) } } + +func TestFormat(t *testing.T) { + log := newConsole() + lm := &LogMsg{ + Level: LevelDebug, + Msg: "Hello, world", + When: time.Date(2020, 9, 19, 20, 12, 37, 9, time.UTC), + FilePath: "/user/home/main.go", + LineNumber: 13, + Prefix: "Cus", + } + res := log.Format(lm) + assert.Equal(t, "2020/09/19 20:12:37.000 \x1b[1;44m[D]\x1b[0m Cus Hello, world\n", res) + err := log.WriteMsg(lm) + assert.Nil(t, err) +} diff --git a/pkg/infrastructure/logs/file_test.go b/pkg/infrastructure/logs/file_test.go index 494d0a9e..6612ebe6 100644 --- a/pkg/infrastructure/logs/file_test.go +++ b/pkg/infrastructure/logs/file_test.go @@ -22,6 +22,8 @@ import ( "strconv" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestFilePerm(t *testing.T) { @@ -428,3 +430,18 @@ func BenchmarkFileOnGoroutine(b *testing.B) { } os.Remove("test4.log") } + +func TestFileLogWriter_Format(t *testing.T) { + lg := &LogMsg{ + Level: LevelDebug, + Msg: "Hello, world", + When: time.Date(2020, 9, 19, 20, 12, 37, 9, time.UTC), + FilePath: "/user/home/main.go", + LineNumber: 13, + Prefix: "Cus", + } + + fw := newFileWriter().(*fileLogWriter) + res := fw.Format(lg) + assert.Equal(t, "2020/09/19 20:12:37.000 [D] Cus Hello, world\n", res) +} diff --git a/pkg/infrastructure/logs/formatter_test.go b/pkg/infrastructure/logs/formatter_test.go index f9320b8d..7ba9237b 100644 --- a/pkg/infrastructure/logs/formatter_test.go +++ b/pkg/infrastructure/logs/formatter_test.go @@ -1,11 +1,78 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package logs import ( + "encoding/json" + "errors" "strconv" "testing" "time" + + "github.com/stretchr/testify/assert" ) +type CustomFormatter struct{} + +func (c *CustomFormatter) Format(lm *LogMsg) string { + return "hello, msg: " + lm.Msg +} + +type TestLogger struct { + Formatter string `json:"formatter"` + Expected string + formatter LogFormatter +} + +func (t *TestLogger) Init(config string) error { + er := json.Unmarshal([]byte(config), t) + t.formatter, _ = GetFormatter(t.Formatter) + return er +} + +func (t *TestLogger) WriteMsg(lm *LogMsg) error { + msg := t.formatter.Format(lm) + if msg != t.Expected { + return errors.New("not equal") + } + return nil +} + +func (t *TestLogger) Destroy() { + panic("implement me") +} + +func (t *TestLogger) Flush() { + panic("implement me") +} + +func (t *TestLogger) SetFormatter(f LogFormatter) { + panic("implement me") +} + +func TestCustomFormatter(t *testing.T) { + RegisterFormatter("custom", &CustomFormatter{}) + tl := &TestLogger{ + Expected: "hello, msg: world", + } + assert.Nil(t, tl.Init(`{"formatter": "custom"}`)) + assert.Nil(t, tl.WriteMsg(&LogMsg{ + Msg: "world", + })) +} + func TestPatternLogFormatter(t *testing.T) { tes := &PatternLogFormatter{ Pattern: "%F:%n|%w%t>> %m", @@ -25,4 +92,4 @@ func TestPatternLogFormatter(t *testing.T) { if got != want { t.Errorf("want %s, got %s", want, got) } -} +} \ No newline at end of file diff --git a/pkg/infrastructure/logs/jianliao.go b/pkg/infrastructure/logs/jianliao.go index 9757a7d5..c82a0957 100644 --- a/pkg/infrastructure/logs/jianliao.go +++ b/pkg/infrastructure/logs/jianliao.go @@ -44,7 +44,9 @@ func (s *JLWriter) Init(config string) error { } func (s *JLWriter) Format(lm *LogMsg) string { - return lm.OldStyleFormat() + msg := lm.OldStyleFormat() + msg = fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), msg) + return msg } func (s *JLWriter) SetFormatter(f LogFormatter) { diff --git a/pkg/infrastructure/logs/jianliao_test.go b/pkg/infrastructure/logs/jianliao_test.go new file mode 100644 index 00000000..eb0ac9e1 --- /dev/null +++ b/pkg/infrastructure/logs/jianliao_test.go @@ -0,0 +1,65 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type TestHttpHandler struct { +} + +func (t *TestHttpHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + writer.Write([]byte("coming")) +} + +func TestJLWriter_WriteMsg(t *testing.T) { + // start sever + + http.Handle("/", &TestHttpHandler{}) + go http.ListenAndServe(":12124", nil) + + jl := newJLWriter() + jl.Init(`{ +"webhookurl":"http://localhost:12124/hello", +"redirecturl":"nil", +"imageurl":"a" +}`) + err := jl.WriteMsg(&LogMsg{ + Msg: "world", + }) + + jl.Flush() + jl.Destroy() + assert.Nil(t, err) +} + +func TestJLWriter_Format(t *testing.T) { + lg := &LogMsg{ + Level: LevelDebug, + Msg: "Hello, world", + When: time.Date(2020, 9, 19, 20, 12, 37, 9, time.UTC), + FilePath: "/user/home/main.go", + LineNumber: 13, + Prefix: "Cus", + } + jl := newJLWriter().(*JLWriter) + res := jl.Format(lg) + assert.Equal(t, "2020-09-19 20:12:37 [D] Cus Hello, world", res) +} diff --git a/pkg/infrastructure/logs/log.go b/pkg/infrastructure/logs/log.go index 480cecab..cec8d51d 100644 --- a/pkg/infrastructure/logs/log.go +++ b/pkg/infrastructure/logs/log.go @@ -37,7 +37,6 @@ import ( "fmt" "log" "os" - "path" "runtime" "strings" "sync" @@ -135,18 +134,6 @@ type nameLogger struct { name string } -type LogMsg struct { - Level int - Msg string - When time.Time - FilePath string - LineNumber int - Args []interface{} - Prefix string - enableFullFilePath bool - enableFuncCallDepth bool -} - var logMsgPool *sync.Pool // NewLogger returns a new BeeLogger. @@ -187,27 +174,6 @@ func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { return bl } -// OldStyleFormat you should never invoke this -func (lm *LogMsg) OldStyleFormat() string { - msg := lm.Msg - - if len(lm.Args) > 0 { - lm.Msg = fmt.Sprintf(lm.Msg, lm.Args...) - } - - msg = lm.Prefix + " " + msg - - if lm.enableFuncCallDepth { - if !lm.enableFullFilePath { - _, lm.FilePath = path.Split(lm.FilePath) - } - msg = fmt.Sprintf("[%s:%d] %s", lm.FilePath, lm.LineNumber, msg) - } - - msg = levelPrefix[lm.Level] + " " + msg - return msg -} - // SetLogger provides a given logger adapter into BeeLogger with config string. // config must in in JSON format like {"interval":360}} func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { diff --git a/pkg/infrastructure/logs/log_msg.go b/pkg/infrastructure/logs/log_msg.go new file mode 100644 index 00000000..f96fa72f --- /dev/null +++ b/pkg/infrastructure/logs/log_msg.go @@ -0,0 +1,55 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "fmt" + "path" + "time" +) + +type LogMsg struct { + Level int + Msg string + When time.Time + FilePath string + LineNumber int + Args []interface{} + Prefix string + enableFullFilePath bool + enableFuncCallDepth bool +} + +// OldStyleFormat you should never invoke this +func (lm *LogMsg) OldStyleFormat() string { + msg := lm.Msg + + if len(lm.Args) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, lm.Args...) + } + + msg = lm.Prefix + " " + msg + + if lm.enableFuncCallDepth { + filePath := lm.FilePath + if !lm.enableFullFilePath { + _, filePath = path.Split(filePath) + } + msg = fmt.Sprintf("[%s:%d] %s", filePath, lm.LineNumber, msg) + } + + msg = levelPrefix[lm.Level] + " " + msg + return msg +} diff --git a/pkg/infrastructure/logs/log_msg_test.go b/pkg/infrastructure/logs/log_msg_test.go new file mode 100644 index 00000000..f213ed42 --- /dev/null +++ b/pkg/infrastructure/logs/log_msg_test.go @@ -0,0 +1,44 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestLogMsg_OldStyleFormat(t *testing.T) { + lg := &LogMsg{ + Level: LevelDebug, + Msg: "Hello, world", + When: time.Date(2020, 9, 19, 20, 12, 37, 9, time.UTC), + FilePath: "/user/home/main.go", + LineNumber: 13, + Prefix: "Cus", + } + res := lg.OldStyleFormat() + assert.Equal(t, "[D] Cus Hello, world", res) + + lg.enableFuncCallDepth = true + res = lg.OldStyleFormat() + assert.Equal(t, "[D] [main.go:13] Cus Hello, world", res) + + lg.enableFullFilePath = true + + res = lg.OldStyleFormat() + assert.Equal(t, "[D] [/user/home/main.go:13] Cus Hello, world", res) +} diff --git a/pkg/infrastructure/logs/log_test.go b/pkg/infrastructure/logs/log_test.go new file mode 100644 index 00000000..66f59108 --- /dev/null +++ b/pkg/infrastructure/logs/log_test.go @@ -0,0 +1,27 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBeeLogger_DelLogger(t *testing.T) { + prefix := "My-Cus" + l := GetLogger(prefix) + assert.NotNil(t, l) +} From a3ece98cec008ed993e03f4cd75801a016aa83f0 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 19 Sep 2020 23:49:41 +0800 Subject: [PATCH 166/207] Add IndexNaming interface so users can custom the index name when they use es as the logger --- pkg/infrastructure/logs/es/es.go | 7 ++-- pkg/infrastructure/logs/es/index.go | 39 +++++++++++++++++++++++ pkg/infrastructure/logs/es/index_test.go | 34 ++++++++++++++++++++ pkg/infrastructure/logs/formatter_test.go | 2 +- 4 files changed, 79 insertions(+), 3 deletions(-) create mode 100644 pkg/infrastructure/logs/es/index.go create mode 100644 pkg/infrastructure/logs/es/index_test.go diff --git a/pkg/infrastructure/logs/es/es.go b/pkg/infrastructure/logs/es/es.go index 438a6da6..c4090eab 100644 --- a/pkg/infrastructure/logs/es/es.go +++ b/pkg/infrastructure/logs/es/es.go @@ -18,7 +18,8 @@ import ( // NewES returns a LoggerInterface func NewES() logs.Logger { cw := &esLogger{ - Level: logs.LevelDebug, + Level: logs.LevelDebug, + indexNaming: indexNaming, } return cw } @@ -35,6 +36,8 @@ type esLogger struct { Level int `json:"level"` formatter logs.LogFormatter Formatter string `json:"formatter"` + + indexNaming IndexNaming } func (el *esLogger) Format(lm *logs.LogMsg) string { @@ -96,7 +99,7 @@ func (el *esLogger) WriteMsg(lm *logs.LogMsg) error { msg := el.formatter.Format(lm) req := esapi.IndexRequest{ - Index: fmt.Sprintf("%04d.%02d.%02d", lm.When.Year(), lm.When.Month(), lm.When.Day()), + Index: indexNaming.IndexName(lm), DocumentType: "logs", Body: strings.NewReader(msg), } diff --git a/pkg/infrastructure/logs/es/index.go b/pkg/infrastructure/logs/es/index.go new file mode 100644 index 00000000..9796987e --- /dev/null +++ b/pkg/infrastructure/logs/es/index.go @@ -0,0 +1,39 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package es + +import ( + "fmt" + + "github.com/astaxie/beego/pkg/infrastructure/logs" +) + +// IndexNaming generate the index name +type IndexNaming interface { + IndexName(lm *logs.LogMsg) string +} + +var indexNaming IndexNaming = &defaultIndexNaming{} + +// SetIndexNaming will register global IndexNaming +func SetIndexNaming(i IndexNaming) { + indexNaming = i +} + +type defaultIndexNaming struct{} + +func (d *defaultIndexNaming) IndexName(lm *logs.LogMsg) string { + return fmt.Sprintf("%04d.%02d.%02d", lm.When.Year(), lm.When.Month(), lm.When.Day()) +} diff --git a/pkg/infrastructure/logs/es/index_test.go b/pkg/infrastructure/logs/es/index_test.go new file mode 100644 index 00000000..4cdf9b02 --- /dev/null +++ b/pkg/infrastructure/logs/es/index_test.go @@ -0,0 +1,34 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package es + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/infrastructure/logs" +) + +func TestDefaultIndexNaming_IndexName(t *testing.T) { + tm := time.Date(2020, 9, 12, 1, 34, 45, 234, time.UTC) + lm := &logs.LogMsg{ + When: tm, + } + + res := (&defaultIndexNaming{}).IndexName(lm) + assert.Equal(t, "2020.09.12", res) +} diff --git a/pkg/infrastructure/logs/formatter_test.go b/pkg/infrastructure/logs/formatter_test.go index 7ba9237b..a97765ac 100644 --- a/pkg/infrastructure/logs/formatter_test.go +++ b/pkg/infrastructure/logs/formatter_test.go @@ -92,4 +92,4 @@ func TestPatternLogFormatter(t *testing.T) { if got != want { t.Errorf("want %s, got %s", want, got) } -} \ No newline at end of file +} From 7c8136710c02df19c9b750aa8101974c16cbd4fa Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 19 Sep 2020 23:54:33 +0800 Subject: [PATCH 167/207] Add stale.yml --- .github/workflows/stale.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/workflows/stale.yml diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 00000000..3a4d2e9a --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,19 @@ +name: Mark stale issues and pull requests + +on: + schedule: + - cron: "30 1 * * *" + +jobs: + stale: + + runs-on: ubuntu-latest + + steps: + - uses: actions/stale@v1 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: 'This issue is inactive for a long time.' + stale-pr-message: 'This PR is inactive for a long time' + stale-issue-label: 'inactive-issue' + stale-pr-label: 'inactive-pr' \ No newline at end of file From 961f300c144669d5c0985008ca3b7c05cd46dd05 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 20 Sep 2020 13:15:38 +0800 Subject: [PATCH 168/207] Fix JL tests --- pkg/infrastructure/logs/formatter_test.go | 2 +- pkg/infrastructure/logs/jianliao_test.go | 29 ----------------------- 2 files changed, 1 insertion(+), 30 deletions(-) diff --git a/pkg/infrastructure/logs/formatter_test.go b/pkg/infrastructure/logs/formatter_test.go index 7ba9237b..a97765ac 100644 --- a/pkg/infrastructure/logs/formatter_test.go +++ b/pkg/infrastructure/logs/formatter_test.go @@ -92,4 +92,4 @@ func TestPatternLogFormatter(t *testing.T) { if got != want { t.Errorf("want %s, got %s", want, got) } -} \ No newline at end of file +} diff --git a/pkg/infrastructure/logs/jianliao_test.go b/pkg/infrastructure/logs/jianliao_test.go index eb0ac9e1..a1b2d076 100644 --- a/pkg/infrastructure/logs/jianliao_test.go +++ b/pkg/infrastructure/logs/jianliao_test.go @@ -15,41 +15,12 @@ package logs import ( - "net/http" "testing" "time" "github.com/stretchr/testify/assert" ) -type TestHttpHandler struct { -} - -func (t *TestHttpHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - writer.Write([]byte("coming")) -} - -func TestJLWriter_WriteMsg(t *testing.T) { - // start sever - - http.Handle("/", &TestHttpHandler{}) - go http.ListenAndServe(":12124", nil) - - jl := newJLWriter() - jl.Init(`{ -"webhookurl":"http://localhost:12124/hello", -"redirecturl":"nil", -"imageurl":"a" -}`) - err := jl.WriteMsg(&LogMsg{ - Msg: "world", - }) - - jl.Flush() - jl.Destroy() - assert.Nil(t, err) -} - func TestJLWriter_Format(t *testing.T) { lg := &LogMsg{ Level: LevelDebug, From bd1cfefec726233c782a80d41904a66494aeecf9 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 8 Sep 2020 21:54:35 +0800 Subject: [PATCH 169/207] rft: Move build info to pkg --- pkg/adapter/beego.go | 4 +++- pkg/adapter/metric/prometheus.go | 15 ++++++++------- pkg/{server/web => }/build_info.go | 7 ++++++- pkg/server/web/beego.go | 3 --- pkg/server/web/config.go | 3 ++- pkg/server/web/error.go | 5 +++-- pkg/server/web/filter/prometheus/filter.go | 15 ++++++++------- 7 files changed, 30 insertions(+), 22 deletions(-) rename pkg/{server/web => }/build_info.go (88%) diff --git a/pkg/adapter/beego.go b/pkg/adapter/beego.go index efd2d4ea..eb7be3f6 100644 --- a/pkg/adapter/beego.go +++ b/pkg/adapter/beego.go @@ -15,12 +15,14 @@ package adapter import ( + "github.com/astaxie/beego/pkg" "github.com/astaxie/beego/pkg/server/web" ) const ( + // VERSION represent beego web framework version. - VERSION = web.VERSION + VERSION = pkg.VERSION // DEV is for develop DEV = web.DEV diff --git a/pkg/adapter/metric/prometheus.go b/pkg/adapter/metric/prometheus.go index 1d3488c6..6af2c26c 100644 --- a/pkg/adapter/metric/prometheus.go +++ b/pkg/adapter/metric/prometheus.go @@ -23,6 +23,7 @@ import ( "github.com/prometheus/client_golang/prometheus" + "github.com/astaxie/beego/pkg" "github.com/astaxie/beego/pkg/infrastructure/logs" "github.com/astaxie/beego/pkg/server/web" ) @@ -58,13 +59,13 @@ func registerBuildInfo() { Help: "The building information", ConstLabels: map[string]string{ "appname": web.BConfig.AppName, - "build_version": web.BuildVersion, - "build_revision": web.BuildGitRevision, - "build_status": web.BuildStatus, - "build_tag": web.BuildTag, - "build_time": strings.Replace(web.BuildTime, "--", " ", 1), - "go_version": web.GoVersion, - "git_branch": web.GitBranch, + "build_version": pkg.BuildVersion, + "build_revision": pkg.BuildGitRevision, + "build_status": pkg.BuildStatus, + "build_tag": pkg.BuildTag, + "build_time": strings.Replace(pkg.BuildTime, "--", " ", 1), + "go_version": pkg.GoVersion, + "git_branch": pkg.GitBranch, "start_time": time.Now().Format("2006-01-02 15:04:05"), }, }, []string{}) diff --git a/pkg/server/web/build_info.go b/pkg/build_info.go similarity index 88% rename from pkg/server/web/build_info.go rename to pkg/build_info.go index 53351c11..778856c6 100644 --- a/pkg/server/web/build_info.go +++ b/pkg/build_info.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package web +package pkg var ( BuildVersion string @@ -25,3 +25,8 @@ var ( GitBranch string ) + +const ( + // VERSION represent beego web framework version. + VERSION = "1.12.2" +) diff --git a/pkg/server/web/beego.go b/pkg/server/web/beego.go index 76e7b85e..42bfc8be 100644 --- a/pkg/server/web/beego.go +++ b/pkg/server/web/beego.go @@ -22,9 +22,6 @@ import ( ) const ( - // VERSION represent beego web framework version. - VERSION = "1.12.2" - // DEV is for develop DEV = "dev" // PROD is for production diff --git a/pkg/server/web/config.go b/pkg/server/web/config.go index 6e69a2fb..309f9b87 100644 --- a/pkg/server/web/config.go +++ b/pkg/server/web/config.go @@ -24,6 +24,7 @@ import ( "runtime" "strings" + "github.com/astaxie/beego/pkg" "github.com/astaxie/beego/pkg/infrastructure/config" "github.com/astaxie/beego/pkg/infrastructure/logs" "github.com/astaxie/beego/pkg/infrastructure/session" @@ -208,7 +209,7 @@ func newBConfig() *Config { AppName: "beego", RunMode: PROD, RouterCaseSensitive: true, - ServerName: "beegoServer:" + VERSION, + ServerName: "beegoServer:" + pkg.VERSION, RecoverPanic: true, RecoverFunc: recoverPanic, CopyRequestBody: false, diff --git a/pkg/server/web/error.go b/pkg/server/web/error.go index b62fb70d..69e7ff02 100644 --- a/pkg/server/web/error.go +++ b/pkg/server/web/error.go @@ -23,6 +23,7 @@ import ( "strconv" "strings" + "github.com/astaxie/beego/pkg" "github.com/astaxie/beego/pkg/infrastructure/utils" "github.com/astaxie/beego/pkg/server/web/context" @@ -91,7 +92,7 @@ func showErr(err interface{}, ctx *context.Context, stack string) { "RequestURL": ctx.Input.URI(), "RemoteAddr": ctx.Input.IP(), "Stack": stack, - "BeegoVersion": VERSION, + "BeegoVersion": pkg.VERSION, "GoVersion": runtime.Version(), } t.Execute(ctx.ResponseWriter, data) @@ -378,7 +379,7 @@ func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errCont t, _ := template.New("beegoerrortemp").Parse(errtpl) data := M{ "Title": http.StatusText(errCode), - "BeegoVersion": VERSION, + "BeegoVersion": pkg.VERSION, "Content": template.HTML(errContent), } t.Execute(rw, data) diff --git a/pkg/server/web/filter/prometheus/filter.go b/pkg/server/web/filter/prometheus/filter.go index f4231c73..eb5b0b78 100644 --- a/pkg/server/web/filter/prometheus/filter.go +++ b/pkg/server/web/filter/prometheus/filter.go @@ -21,6 +21,7 @@ import ( "github.com/prometheus/client_golang/prometheus" + "github.com/astaxie/beego/pkg" "github.com/astaxie/beego/pkg/server/web" "github.com/astaxie/beego/pkg/server/web/context" ) @@ -63,13 +64,13 @@ func registerBuildInfo() { Help: "The building information", ConstLabels: map[string]string{ "appname": web.BConfig.AppName, - "build_version": web.BuildVersion, - "build_revision": web.BuildGitRevision, - "build_status": web.BuildStatus, - "build_tag": web.BuildTag, - "build_time": strings.Replace(web.BuildTime, "--", " ", 1), - "go_version": web.GoVersion, - "git_branch": web.GitBranch, + "build_version": pkg.BuildVersion, + "build_revision": pkg.BuildGitRevision, + "build_status": pkg.BuildStatus, + "build_tag": pkg.BuildTag, + "build_time": strings.Replace(pkg.BuildTime, "--", " ", 1), + "go_version": pkg.GoVersion, + "git_branch": pkg.GitBranch, "start_time": time.Now().Format("2006-01-02 15:04:05"), }, }, []string{}) From d455805a0a4590894f290419d8f03c6f062091f0 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 9 Sep 2020 22:55:18 +0800 Subject: [PATCH 170/207] Multiple server refactor --- pkg/adapter/app.go | 12 +- pkg/server/web/beego.go | 68 ++--- pkg/server/web/config.go | 26 +- pkg/server/web/error.go | 4 +- pkg/server/web/router.go | 79 ++---- pkg/server/web/router_test.go | 2 +- pkg/server/web/{app.go => server.go} | 379 +++++++++++++++++++-------- pkg/server/web/server_test.go | 31 +++ pkg/server/web/template.go | 8 +- 9 files changed, 382 insertions(+), 227 deletions(-) rename pkg/server/web/{app.go => server.go} (54%) create mode 100644 pkg/server/web/server_test.go diff --git a/pkg/adapter/app.go b/pkg/adapter/app.go index c1046c79..10ffa96a 100644 --- a/pkg/adapter/app.go +++ b/pkg/adapter/app.go @@ -33,11 +33,11 @@ func init() { } // App defines beego application with a new PatternServeMux. -type App web.App +type App web.HttpServer // NewApp returns a new beego application. func NewApp() *App { - return (*App)(web.NewApp()) + return (*App)(web.NewHttpSever()) } // MiddleWare function for http.Handler @@ -46,7 +46,7 @@ type MiddleWare web.MiddleWare // Run beego application. func (app *App) Run(mws ...MiddleWare) { newMws := oldMiddlewareToNew(mws) - (*web.App)(app).Run(newMws...) + (*web.HttpServer)(app).Run("", newMws...) } func oldMiddlewareToNew(mws []MiddleWare) []web.MiddleWare { @@ -58,7 +58,7 @@ func oldMiddlewareToNew(mws []MiddleWare) []web.MiddleWare { } // Router adds a patterned controller handler to BeeApp. -// it's an alias method of App.Router. +// it's an alias method of HttpServer.Router. // usage: // simple router // beego.Router("/admin", &admin.UserController{}) @@ -138,7 +138,7 @@ func RESTRouter(rootpath string, c ControllerInterface) *App { } // AutoRouter adds defined controller handler to BeeApp. -// it's same to App.AutoRouter. +// it's same to HttpServer.AutoRouter. // if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, // visit the url /main/list to exec List function or /main/page to exec Page function. func AutoRouter(c ControllerInterface) *App { @@ -146,7 +146,7 @@ func AutoRouter(c ControllerInterface) *App { } // AutoPrefix adds controller handler to BeeApp with prefix. -// it's same to App.AutoRouterWithPrefix. +// it's same to HttpServer.AutoRouterWithPrefix. // if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, // visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. func AutoPrefix(prefix string, c ControllerInterface) *App { diff --git a/pkg/server/web/beego.go b/pkg/server/web/beego.go index 42bfc8be..14e51a94 100644 --- a/pkg/server/web/beego.go +++ b/pkg/server/web/beego.go @@ -17,8 +17,7 @@ package web import ( "os" "path/filepath" - "strconv" - "strings" + "sync" ) const ( @@ -35,7 +34,7 @@ type M map[string]interface{} type hookfunc func() error var ( - hooks = make([]hookfunc, 0) //hook function slice to store the hookfunc + hooks = make([]hookfunc, 0) // hook function slice to store the hookfunc ) // AddAPPStartHook is used to register the hookfunc @@ -52,56 +51,39 @@ func AddAPPStartHook(hf ...hookfunc) { // beego.Run("127.0.0.1:8089") func Run(params ...string) { - initBeforeHTTPRun() - if len(params) > 0 && params[0] != "" { - strs := strings.Split(params[0], ":") - if len(strs) > 0 && strs[0] != "" { - BConfig.Listen.HTTPAddr = strs[0] - } - if len(strs) > 1 && strs[1] != "" { - BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) - } - - BConfig.Listen.Domains = params + BeeApp.Run(params[0]) } - - BeeApp.Run() + BeeApp.Run("") } // RunWithMiddleWares Run beego application with middlewares. func RunWithMiddleWares(addr string, mws ...MiddleWare) { - initBeforeHTTPRun() - - strs := strings.Split(addr, ":") - if len(strs) > 0 && strs[0] != "" { - BConfig.Listen.HTTPAddr = strs[0] - BConfig.Listen.Domains = []string{strs[0]} - } - if len(strs) > 1 && strs[1] != "" { - BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) - } - - BeeApp.Run(mws...) + BeeApp.Run(addr, mws...) } -func initBeforeHTTPRun() { - //init hooks - AddAPPStartHook( - registerMime, - registerDefaultErrorHandler, - registerSession, - registerTemplate, - registerAdmin, - registerGzip, - registerCommentRouter, - ) +var initHttpOnce sync.Once - for _, hk := range hooks { - if err := hk(); err != nil { - panic(err) +// TODO move to module init function +func initBeforeHTTPRun() { + initHttpOnce.Do(func() { + // init hooks + AddAPPStartHook( + registerMime, + registerDefaultErrorHandler, + registerSession, + registerTemplate, + registerAdmin, + registerGzip, + registerCommentRouter, + ) + + for _, hk := range hooks { + if err := hk(); err != nil { + panic(err) + } } - } + }) } // TestBeegoInit is for test package init diff --git a/pkg/server/web/config.go b/pkg/server/web/config.go index 309f9b87..add81b8c 100644 --- a/pkg/server/web/config.go +++ b/pkg/server/web/config.go @@ -34,6 +34,7 @@ import ( ) // Config is the main struct for BConfig +// TODO after supporting multiple servers, remove common config to somewhere else type Config struct { AppName string // Application name RunMode string // Running Mode: dev | prod @@ -168,15 +169,15 @@ func init() { } } -func recoverPanic(ctx *context.Context) { +func (cfg *Config) defaultRecoverPanic(ctx *context.Context) { if err := recover(); err != nil { if err == ErrAbort { return } - if !BConfig.RecoverPanic { + if !cfg.RecoverPanic { panic(err) } - if BConfig.EnableErrorsShow { + if cfg.EnableErrorsShow { if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { exception(fmt.Sprint(err), ctx) return @@ -193,7 +194,7 @@ func recoverPanic(ctx *context.Context) { logs.Critical(fmt.Sprintf("%s:%d", file, line)) stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) } - if BConfig.RunMode == DEV && BConfig.EnableErrorsRender { + if cfg.RunMode == DEV && cfg.EnableErrorsRender { showErr(err, ctx, stack) } if ctx.Output.Status != 0 { @@ -205,18 +206,18 @@ func recoverPanic(ctx *context.Context) { } func newBConfig() *Config { - return &Config{ + res := &Config{ AppName: "beego", RunMode: PROD, RouterCaseSensitive: true, ServerName: "beegoServer:" + pkg.VERSION, RecoverPanic: true, - RecoverFunc: recoverPanic, - CopyRequestBody: false, - EnableGzip: false, - MaxMemory: 1 << 26, // 64MB - EnableErrorsShow: true, - EnableErrorsRender: true, + + CopyRequestBody: false, + EnableGzip: false, + MaxMemory: 1 << 26, // 64MB + EnableErrorsShow: true, + EnableErrorsRender: true, Listen: Listen{ Graceful: false, ServerTimeOut: 0, @@ -279,6 +280,9 @@ func newBConfig() *Config { Outputs: map[string]string{"console": ""}, }, } + + res.RecoverFunc = res.defaultRecoverPanic + return res } // now only support ini, next will support json. diff --git a/pkg/server/web/error.go b/pkg/server/web/error.go index 69e7ff02..a005c110 100644 --- a/pkg/server/web/error.go +++ b/pkg/server/web/error.go @@ -389,7 +389,7 @@ func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errCont // usage: // beego.ErrorHandler("404",NotFound) // beego.ErrorHandler("500",InternalServerError) -func ErrorHandler(code string, h http.HandlerFunc) *App { +func ErrorHandler(code string, h http.HandlerFunc) *HttpServer { ErrorMaps[code] = &errorInfo{ errorType: errorTypeHandler, handler: h, @@ -401,7 +401,7 @@ func ErrorHandler(code string, h http.HandlerFunc) *App { // ErrorController registers ControllerInterface to each http err code string. // usage: // beego.ErrorController(&controllers.ErrorController{}) -func ErrorController(c ControllerInterface) *App { +func ErrorController(c ControllerInterface) *HttpServer { reflectVal := reflect.ValueOf(c) rt := reflectVal.Type() ct := reflect.Indirect(reflectVal).Type() diff --git a/pkg/server/web/router.go b/pkg/server/web/router.go index 3dd19a6f..1a0183bd 100644 --- a/pkg/server/web/router.go +++ b/pkg/server/web/router.go @@ -135,10 +135,18 @@ type ControllerRegister struct { // the filter created by FilterChain chainRoot *FilterRouter + + cfg *Config } // NewControllerRegister returns a new ControllerRegister. +// Usually you should not use this method +// please use NewControllerRegisterWithCfg func NewControllerRegister() *ControllerRegister { + return NewControllerRegisterWithCfg(BeeApp.Cfg) +} + +func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { res := &ControllerRegister{ routers: make(map[string]*Tree), policies: make(map[string]*Tree), @@ -240,7 +248,7 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt } func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) { - if !BConfig.RouterCaseSensitive { + if !p.cfg.RouterCaseSensitive { pattern = strings.ToLower(pattern) } if t, ok := p.routers[method]; ok { @@ -453,7 +461,7 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) // 1. setting the returnOnOutput value (false allows multiple filters to execute) // 2. determining whether or not params need to be reset. func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) error { - opts = append(opts, WithCaseSensitive(BConfig.RouterCaseSensitive)) + opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) mr := newFilterRouter(pattern, filter, opts...) return p.insertFilterRouter(pos, mr) } @@ -472,7 +480,7 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) { root := p.chainRoot filterFunc := chain(root.filterFunc) - opts = append(opts, WithCaseSensitive(BConfig.RouterCaseSensitive)) + opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) p.chainRoot = newFilterRouter(pattern, filterFunc, opts...) p.chainRoot.next = root @@ -669,14 +677,14 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { isRunnable bool ) - if BConfig.RecoverFunc != nil { - defer BConfig.RecoverFunc(ctx) + if p.cfg.RecoverFunc != nil { + defer p.cfg.RecoverFunc(ctx) } - ctx.Output.EnableGzip = BConfig.EnableGzip + ctx.Output.EnableGzip = p.cfg.EnableGzip - if BConfig.RunMode == DEV { - ctx.Output.Header("Server", BConfig.ServerName) + if p.cfg.RunMode == DEV { + ctx.Output.Header("Server", p.cfg.ServerName) } urlPath := p.getUrlPath(ctx) @@ -700,20 +708,20 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { } if r.Method != http.MethodGet && r.Method != http.MethodHead { - if BConfig.CopyRequestBody && !ctx.Input.IsUpload() { + if p.cfg.CopyRequestBody && !ctx.Input.IsUpload() { // connection will close if the incoming data are larger (RFC 7231, 6.5.11) - if r.ContentLength > BConfig.MaxMemory { + if r.ContentLength > p.cfg.MaxMemory { logs.Error(errors.New("payload too large")) exception("413", ctx) goto Admin } - ctx.Input.CopyBody(BConfig.MaxMemory) + ctx.Input.CopyBody(p.cfg.MaxMemory) } - ctx.Input.ParseFormOrMulitForm(BConfig.MaxMemory) + ctx.Input.ParseFormOrMulitForm(p.cfg.MaxMemory) } // session init - if BConfig.WebConfig.Session.SessionOn { + if p.cfg.WebConfig.Session.SessionOn { var err error ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) if err != nil { @@ -819,7 +827,7 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { execController.Prepare() // if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf - if BConfig.WebConfig.EnableXSRF { + if p.cfg.WebConfig.EnableXSRF { execController.XSRFToken() if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut || (r.Method == http.MethodPost && (ctx.Input.Query("_method") == http.MethodDelete || ctx.Input.Query("_method") == http.MethodPut)) { @@ -864,7 +872,7 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { // render template if !ctx.ResponseWriter.Started && ctx.Output.Status == 0 { - if BConfig.WebConfig.AutoRender { + if p.cfg.WebConfig.AutoRender { if err := execController.Render(); err != nil { logs.Error(err) } @@ -897,7 +905,7 @@ Admin: timeDur := time.Since(startTime) ctx.ResponseWriter.Elapsed = timeDur - if BConfig.Listen.EnableAdmin { + if p.cfg.Listen.EnableAdmin { pattern := "" if routerInfo != nil { pattern = routerInfo.pattern @@ -912,7 +920,7 @@ Admin: } } - if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs { + if p.cfg.RunMode == DEV && !p.cfg.Log.AccessLogs { match := map[bool]string{true: "match", false: "nomatch"} devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", ctx.Input.IP(), @@ -935,7 +943,7 @@ Admin: func (p *ControllerRegister) getUrlPath(ctx *beecontext.Context) string { urlPath := ctx.Request.URL.Path - if !BConfig.RouterCaseSensitive { + if !p.cfg.RouterCaseSensitive { urlPath = strings.ToLower(urlPath) } return urlPath @@ -958,7 +966,7 @@ func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, ex // FindRouter Find Router info for URL func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { var urlPath = context.Input.URL() - if !BConfig.RouterCaseSensitive { + if !p.cfg.RouterCaseSensitive { urlPath = strings.ToLower(urlPath) } httpMethod := context.Input.Method() @@ -984,36 +992,5 @@ func toURL(params map[string]string) string { // LogAccess logging info HTTP Access func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { - // Skip logging if AccessLogs config is false - if !BConfig.Log.AccessLogs { - return - } - // Skip logging static requests unless EnableStaticLogs config is true - if !BConfig.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) { - return - } - var ( - requestTime time.Time - elapsedTime time.Duration - r = ctx.Request - ) - if startTime != nil { - requestTime = *startTime - elapsedTime = time.Since(*startTime) - } - record := &logs.AccessLogRecord{ - RemoteAddr: ctx.Input.IP(), - RequestTime: requestTime, - RequestMethod: r.Method, - Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto), - ServerProtocol: r.Proto, - Host: r.Host, - Status: statusCode, - ElapsedTime: elapsedTime, - HTTPReferrer: r.Header.Get("Referer"), - HTTPUserAgent: r.Header.Get("User-Agent"), - RemoteUser: r.Header.Get("Remote-User"), - BodyBytesSent: r.ContentLength, - } - logs.AccessLog(record, BConfig.Log.AccessLogsFormat) + BeeApp.LogAccess(ctx, startTime, statusCode) } diff --git a/pkg/server/web/router_test.go b/pkg/server/web/router_test.go index 2863da3a..da8efbcb 100644 --- a/pkg/server/web/router_test.go +++ b/pkg/server/web/router_test.go @@ -381,7 +381,7 @@ func TestRouterHandlerAll(t *testing.T) { } // -// Benchmarks NewApp: +// Benchmarks NewHttpSever: // func beegoFilterFunc(ctx *context.Context) { diff --git a/pkg/server/web/app.go b/pkg/server/web/server.go similarity index 54% rename from pkg/server/web/app.go rename to pkg/server/web/server.go index 7511c7fe..c3ab1696 100644 --- a/pkg/server/web/app.go +++ b/pkg/server/web/server.go @@ -24,12 +24,14 @@ import ( "net/http/fcgi" "os" "path" + "strconv" "strings" "time" "golang.org/x/crypto/acme/autocert" "github.com/astaxie/beego/pkg/infrastructure/logs" + beecontext "github.com/astaxie/beego/pkg/server/web/context" "github.com/astaxie/beego/pkg/infrastructure/utils" "github.com/astaxie/beego/pkg/server/web/grace" @@ -37,24 +39,39 @@ import ( var ( // BeeApp is an application instance - BeeApp *App + // If you are using single server, you could use this + // But if you need multiple servers, do not use this + BeeApp *HttpServer ) func init() { // create beego application - BeeApp = NewApp() + BeeApp = NewHttpSever() } -// App defines beego application with a new PatternServeMux. -type App struct { +// HttpServer defines beego application with a new PatternServeMux. +type HttpServer struct { Handlers *ControllerRegister Server *http.Server + Cfg *Config } -// NewApp returns a new beego application. -func NewApp() *App { - cr := NewControllerRegister() - app := &App{Handlers: cr, Server: &http.Server{}} +// NewHttpSever returns a new beego application. +// this method will use the BConfig as the configure to create HttpServer +// Be careful that when you update BConfig, the server's Cfg will not be updated +func NewHttpSever() *HttpServer { + return NewHttpServerWithCfg(*BConfig) +} + +// NewHttpServerWithCfg will create an sever with specific cfg +func NewHttpServerWithCfg(cfg Config) *HttpServer { + cfgPtr := &cfg + cr := NewControllerRegisterWithCfg(cfgPtr) + app := &HttpServer{ + Handlers: cr, + Server: &http.Server{}, + Cfg: cfgPtr, + } return app } @@ -62,11 +79,16 @@ func NewApp() *App { type MiddleWare func(http.Handler) http.Handler // Run beego application. -func (app *App) Run(mws ...MiddleWare) { - addr := BConfig.Listen.HTTPAddr +func (app *HttpServer) Run(addr string, mws ...MiddleWare) { - if BConfig.Listen.HTTPPort != 0 { - addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPAddr, BConfig.Listen.HTTPPort) + initBeforeHTTPRun() + + app.initAddr(addr) + + addr = app.Cfg.Listen.HTTPAddr + + if app.Cfg.Listen.HTTPPort != 0 { + addr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPAddr, app.Cfg.Listen.HTTPPort) } var ( @@ -76,8 +98,8 @@ func (app *App) Run(mws ...MiddleWare) { ) // run cgi server - if BConfig.Listen.EnableFcgi { - if BConfig.Listen.EnableStdIo { + if app.Cfg.Listen.EnableFcgi { + if app.Cfg.Listen.EnableStdIo { if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O logs.Info("Use FCGI via standard I/O") } else { @@ -85,7 +107,7 @@ func (app *App) Run(mws ...MiddleWare) { } return } - if BConfig.Listen.HTTPPort == 0 { + if app.Cfg.Listen.HTTPPort == 0 { // remove the Socket file before start if utils.FileExists(addr) { os.Remove(addr) @@ -110,40 +132,42 @@ func (app *App) Run(mws ...MiddleWare) { } app.Server.Handler = mws[i](app.Server.Handler) } - app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second - app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second + app.Server.ReadTimeout = time.Duration(app.Cfg.Listen.ServerTimeOut) * time.Second + app.Server.WriteTimeout = time.Duration(app.Cfg.Listen.ServerTimeOut) * time.Second app.Server.ErrorLog = logs.GetLogger("HTTP") // run graceful mode - if BConfig.Listen.Graceful { - httpsAddr := BConfig.Listen.HTTPSAddr + if app.Cfg.Listen.Graceful { + httpsAddr := app.Cfg.Listen.HTTPSAddr app.Server.Addr = httpsAddr - if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { + if app.Cfg.Listen.EnableHTTPS || app.Cfg.Listen.EnableMutualHTTPS { go func() { time.Sleep(1000 * time.Microsecond) - if BConfig.Listen.HTTPSPort != 0 { - httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) + if app.Cfg.Listen.HTTPSPort != 0 { + httpsAddr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPSAddr, app.Cfg.Listen.HTTPSPort) app.Server.Addr = httpsAddr } server := grace.NewServer(httpsAddr, app.Server.Handler) server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.WriteTimeout = app.Server.WriteTimeout - if BConfig.Listen.EnableMutualHTTPS { - if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil { + if app.Cfg.Listen.EnableMutualHTTPS { + if err := server.ListenAndServeMutualTLS(app.Cfg.Listen.HTTPSCertFile, + app.Cfg.Listen.HTTPSKeyFile, + app.Cfg.Listen.TrustCaFile); err != nil { logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) time.Sleep(100 * time.Microsecond) } } else { - if BConfig.Listen.AutoTLS { + if app.Cfg.Listen.AutoTLS { m := autocert.Manager{ Prompt: autocert.AcceptTOS, - HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), - Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), + HostPolicy: autocert.HostWhitelist(app.Cfg.Listen.Domains...), + Cache: autocert.DirCache(app.Cfg.Listen.TLSCacheDir), } app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} - BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" + app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile = "", "" } - if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { + if err := server.ListenAndServeTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil { logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) time.Sleep(100 * time.Microsecond) } @@ -151,12 +175,12 @@ func (app *App) Run(mws ...MiddleWare) { endRunning <- true }() } - if BConfig.Listen.EnableHTTP { + if app.Cfg.Listen.EnableHTTP { go func() { server := grace.NewServer(addr, app.Server.Handler) server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.WriteTimeout = app.Server.WriteTimeout - if BConfig.Listen.ListenTCP4 { + if app.Cfg.Listen.ListenTCP4 { server.Network = "tcp4" } if err := server.ListenAndServe(); err != nil { @@ -171,27 +195,27 @@ func (app *App) Run(mws ...MiddleWare) { } // run normal mode - if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { + if app.Cfg.Listen.EnableHTTPS || app.Cfg.Listen.EnableMutualHTTPS { go func() { time.Sleep(1000 * time.Microsecond) - if BConfig.Listen.HTTPSPort != 0 { - app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) - } else if BConfig.Listen.EnableHTTP { + if app.Cfg.Listen.HTTPSPort != 0 { + app.Server.Addr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPSAddr, app.Cfg.Listen.HTTPSPort) + } else if app.Cfg.Listen.EnableHTTP { logs.Info("Start https server error, conflict with http. Please reset https port") return } logs.Info("https server Running on https://%s", app.Server.Addr) - if BConfig.Listen.AutoTLS { + if app.Cfg.Listen.AutoTLS { m := autocert.Manager{ Prompt: autocert.AcceptTOS, - HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), - Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), + HostPolicy: autocert.HostWhitelist(app.Cfg.Listen.Domains...), + Cache: autocert.DirCache(app.Cfg.Listen.TLSCacheDir), } app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} - BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" - } else if BConfig.Listen.EnableMutualHTTPS { + app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile = "", "" + } else if app.Cfg.Listen.EnableMutualHTTPS { pool := x509.NewCertPool() - data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile) + data, err := ioutil.ReadFile(app.Cfg.Listen.TrustCaFile) if err != nil { logs.Info("MutualHTTPS should provide TrustCaFile") return @@ -199,10 +223,10 @@ func (app *App) Run(mws ...MiddleWare) { pool.AppendCertsFromPEM(data) app.Server.TLSConfig = &tls.Config{ ClientCAs: pool, - ClientAuth: tls.ClientAuthType(BConfig.Listen.ClientAuth), + ClientAuth: tls.ClientAuthType(app.Cfg.Listen.ClientAuth), } } - if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { + if err := app.Server.ListenAndServeTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil { logs.Critical("ListenAndServeTLS: ", err) time.Sleep(100 * time.Microsecond) endRunning <- true @@ -210,11 +234,11 @@ func (app *App) Run(mws ...MiddleWare) { }() } - if BConfig.Listen.EnableHTTP { + if app.Cfg.Listen.EnableHTTP { go func() { app.Server.Addr = addr logs.Info("http server Running on http://%s", app.Server.Addr) - if BConfig.Listen.ListenTCP4 { + if app.Cfg.Listen.ListenTCP4 { ln, err := net.Listen("tcp4", app.Server.Addr) if err != nil { logs.Critical("ListenAndServe: ", err) @@ -240,8 +264,17 @@ func (app *App) Run(mws ...MiddleWare) { <-endRunning } +func (app *HttpServer) Start() { + +} + +// Router see HttpServer.Router +func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer { + return BeeApp.Router(rootpath, c, mappingMethods...) +} + // Router adds a patterned controller handler to BeeApp. -// it's an alias method of App.Router. +// it's an alias method of HttpServer.Router. // usage: // simple router // beego.Router("/admin", &admin.UserController{}) @@ -256,9 +289,14 @@ func (app *App) Run(mws ...MiddleWare) { // beego.Router("/api/create",&RestController{},"post:CreateFood") // beego.Router("/api/update",&RestController{},"put:UpdateFood") // beego.Router("/api/delete",&RestController{},"delete:DeleteFood") -func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { - BeeApp.Handlers.Add(rootpath, c, mappingMethods...) - return BeeApp +func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer { + app.Handlers.Add(rootPath, c, mappingMethods...) + return app +} + +// UnregisterFixedRoute see HttpServer.UnregisterFixedRoute +func UnregisterFixedRoute(fixedRoute string, method string) *HttpServer { + return BeeApp.UnregisterFixedRoute(fixedRoute, method) } // UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful @@ -270,31 +308,31 @@ func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *A // Usage (replace "GET" with "*" for all methods): // beego.UnregisterFixedRoute("/yourpreviouspath", "GET") // beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage") -func UnregisterFixedRoute(fixedRoute string, method string) *App { +func (app *HttpServer) UnregisterFixedRoute(fixedRoute string, method string) *HttpServer { subPaths := splitPath(fixedRoute) if method == "" || method == "*" { for m := range HTTPMETHOD { - if _, ok := BeeApp.Handlers.routers[m]; !ok { + if _, ok := app.Handlers.routers[m]; !ok { continue } - if BeeApp.Handlers.routers[m].prefix == strings.Trim(fixedRoute, "/ ") { - findAndRemoveSingleTree(BeeApp.Handlers.routers[m]) + if app.Handlers.routers[m].prefix == strings.Trim(fixedRoute, "/ ") { + findAndRemoveSingleTree(app.Handlers.routers[m]) continue } - findAndRemoveTree(subPaths, BeeApp.Handlers.routers[m], m) + findAndRemoveTree(subPaths, app.Handlers.routers[m], m) } - return BeeApp + return app } // Single HTTP method um := strings.ToUpper(method) - if _, ok := BeeApp.Handlers.routers[um]; ok { - if BeeApp.Handlers.routers[um].prefix == strings.Trim(fixedRoute, "/ ") { - findAndRemoveSingleTree(BeeApp.Handlers.routers[um]) - return BeeApp + if _, ok := app.Handlers.routers[um]; ok { + if app.Handlers.routers[um].prefix == strings.Trim(fixedRoute, "/ ") { + findAndRemoveSingleTree(app.Handlers.routers[um]) + return app } - findAndRemoveTree(subPaths, BeeApp.Handlers.routers[um], um) + findAndRemoveTree(subPaths, app.Handlers.routers[um], um) } - return BeeApp + return app } func findAndRemoveTree(paths []string, entryPointTree *Tree, method string) { @@ -339,6 +377,11 @@ func findAndRemoveSingleTree(entryPointTree *Tree) { } } +// Include see HttpServer.Include +func Include(cList ...ControllerInterface) *HttpServer { + return BeeApp.Include(cList...) +} + // Include will generate router file in the router/xxx.go from the controller's comments // usage: // beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) @@ -366,36 +409,56 @@ func findAndRemoveSingleTree(entryPointTree *Tree) { // the comments @router url methodlist // url support all the function Router's pattern // methodlist [get post head put delete options *] -func Include(cList ...ControllerInterface) *App { - BeeApp.Handlers.Include(cList...) - return BeeApp +func (app *HttpServer) Include(cList ...ControllerInterface) *HttpServer { + app.Handlers.Include(cList...) + return app +} + +// RESTRouter see HttpServer.RESTRouter +func RESTRouter(rootpath string, c ControllerInterface) *HttpServer { + return BeeApp.RESTRouter(rootpath, c) } // RESTRouter adds a restful controller handler to BeeApp. // its' controller implements beego.ControllerInterface and // defines a param "pattern/:objectId" to visit each resource. -func RESTRouter(rootpath string, c ControllerInterface) *App { - Router(rootpath, c) - Router(path.Join(rootpath, ":objectId"), c) - return BeeApp +func (app *HttpServer) RESTRouter(rootpath string, c ControllerInterface) *HttpServer { + app.Router(rootpath, c) + app.Router(path.Join(rootpath, ":objectId"), c) + return app +} + +// AutoRouter see HttpServer.AutoRouter +func AutoRouter(c ControllerInterface) *HttpServer { + return BeeApp.AutoRouter(c) } // AutoRouter adds defined controller handler to BeeApp. -// it's same to App.AutoRouter. +// it's same to HttpServer.AutoRouter. // if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, // visit the url /main/list to exec List function or /main/page to exec Page function. -func AutoRouter(c ControllerInterface) *App { - BeeApp.Handlers.AddAuto(c) - return BeeApp +func (app *HttpServer) AutoRouter(c ControllerInterface) *HttpServer { + app.Handlers.AddAuto(c) + return app +} + +// AutoPrefix see HttpServer.AutoPrefix +func AutoPrefix(prefix string, c ControllerInterface) *HttpServer { + return BeeApp.AutoPrefix(prefix, c) } // AutoPrefix adds controller handler to BeeApp with prefix. -// it's same to App.AutoRouterWithPrefix. +// it's same to HttpServer.AutoRouterWithPrefix. // if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, // visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. -func AutoPrefix(prefix string, c ControllerInterface) *App { - BeeApp.Handlers.AddAutoPrefix(prefix, c) - return BeeApp +func (app *HttpServer) AutoPrefix(prefix string, c ControllerInterface) *HttpServer { + app.Handlers.AddAutoPrefix(prefix, c) + return app +} + +// Get see HttpServer.Get +func Get(rootpath string, f FilterFunc) *HttpServer { + return BeeApp.Get(rootpath, f) } // Get used to register router for Get method @@ -403,9 +466,14 @@ func AutoPrefix(prefix string, c ControllerInterface) *App { // beego.Get("/", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func Get(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Get(rootpath, f) - return BeeApp +func (app *HttpServer) Get(rootpath string, f FilterFunc) *HttpServer { + app.Handlers.Get(rootpath, f) + return app +} + +// Post see HttpServer.Post +func Post(rootpath string, f FilterFunc) *HttpServer { + return BeeApp.Post(rootpath, f) } // Post used to register router for Post method @@ -413,9 +481,14 @@ func Get(rootpath string, f FilterFunc) *App { // beego.Post("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func Post(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Post(rootpath, f) - return BeeApp +func (app *HttpServer) Post(rootpath string, f FilterFunc) *HttpServer { + app.Handlers.Post(rootpath, f) + return app +} + +// Delete see HttpServer.Delete +func Delete(rootpath string, f FilterFunc) *HttpServer { + return BeeApp.Delete(rootpath, f) } // Delete used to register router for Delete method @@ -423,9 +496,14 @@ func Post(rootpath string, f FilterFunc) *App { // beego.Delete("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func Delete(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Delete(rootpath, f) - return BeeApp +func (app *HttpServer) Delete(rootpath string, f FilterFunc) *HttpServer { + app.Handlers.Delete(rootpath, f) + return app +} + +// Put see HttpServer.Put +func Put(rootpath string, f FilterFunc) *HttpServer { + return BeeApp.Put(rootpath, f) } // Put used to register router for Put method @@ -433,9 +511,14 @@ func Delete(rootpath string, f FilterFunc) *App { // beego.Put("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func Put(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Put(rootpath, f) - return BeeApp +func (app *HttpServer) Put(rootpath string, f FilterFunc) *HttpServer { + app.Handlers.Put(rootpath, f) + return app +} + +// Head see HttpServer.Head +func Head(rootpath string, f FilterFunc) *HttpServer { + return BeeApp.Head(rootpath, f) } // Head used to register router for Head method @@ -443,8 +526,14 @@ func Put(rootpath string, f FilterFunc) *App { // beego.Head("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func Head(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Head(rootpath, f) +func (app *HttpServer) Head(rootpath string, f FilterFunc) *HttpServer { + app.Handlers.Head(rootpath, f) + return app +} + +// Options see HttpServer.Options +func Options(rootpath string, f FilterFunc) *HttpServer { + BeeApp.Handlers.Options(rootpath, f) return BeeApp } @@ -453,9 +542,14 @@ func Head(rootpath string, f FilterFunc) *App { // beego.Options("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func Options(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Options(rootpath, f) - return BeeApp +func (app *HttpServer) Options(rootpath string, f FilterFunc) *HttpServer { + app.Handlers.Options(rootpath, f) + return app +} + +// Patch see HttpServer.Patch +func Patch(rootpath string, f FilterFunc) *HttpServer { + return BeeApp.Patch(rootpath, f) } // Patch used to register router for Patch method @@ -463,9 +557,14 @@ func Options(rootpath string, f FilterFunc) *App { // beego.Patch("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func Patch(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Patch(rootpath, f) - return BeeApp +func (app *HttpServer) Patch(rootpath string, f FilterFunc) *HttpServer { + app.Handlers.Patch(rootpath, f) + return app +} + +// Any see HttpServer.Any +func Any(rootpath string, f FilterFunc) *HttpServer { + return BeeApp.Any(rootpath, f) } // Any used to register router for all methods @@ -473,9 +572,14 @@ func Patch(rootpath string, f FilterFunc) *App { // beego.Any("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func Any(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Any(rootpath, f) - return BeeApp +func (app *HttpServer) Any(rootpath string, f FilterFunc) *HttpServer { + app.Handlers.Any(rootpath, f) + return app +} + +// Handler see HttpServer.Handler +func Handler(rootpath string, h http.Handler, options ...interface{}) *HttpServer { + return BeeApp.Handler(rootpath, h, options...) } // Handler used to register a Handler router @@ -483,24 +587,81 @@ func Any(rootpath string, f FilterFunc) *App { // beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { // fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) // })) -func Handler(rootpath string, h http.Handler, options ...interface{}) *App { - BeeApp.Handlers.Handler(rootpath, h, options...) - return BeeApp +func (app *HttpServer) Handler(rootpath string, h http.Handler, options ...interface{}) *HttpServer { + app.Handlers.Handler(rootpath, h, options...) + return app +} + +// InserFilter see HttpServer.InsertFilter +func InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) *HttpServer { + return BeeApp.InsertFilter(pattern, pos, filter, opts...) } // InsertFilter adds a FilterFunc with pattern condition and action constant. // The pos means action constant including // beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) -func InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) *App { - BeeApp.Handlers.InsertFilter(pattern, pos, filter, opts...) - return BeeApp +func (app *HttpServer) InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) *HttpServer { + app.Handlers.InsertFilter(pattern, pos, filter, opts...) + return app +} + +// InsertFilterChain see HttpServer.InsertFilterChain +func InsertFilterChain(pattern string, filterChain FilterChain, opts ...FilterOpt) *HttpServer { + return BeeApp.InsertFilterChain(pattern, filterChain, opts...) } // InsertFilterChain adds a FilterFunc built by filterChain. // This filter will be executed before all filters. -// the filter's behavior is like stack -func InsertFilterChain(pattern string, filterChain FilterChain, opts ...FilterOpt) *App { - BeeApp.Handlers.InsertFilterChain(pattern, filterChain, opts...) - return BeeApp +// the filter's behavior like stack's behavior +// and the last filter is serving the http request +func (app *HttpServer) InsertFilterChain(pattern string, filterChain FilterChain, opts ...FilterOpt) *HttpServer { + app.Handlers.InsertFilterChain(pattern, filterChain, opts...) + return app +} + +func (app *HttpServer) initAddr(addr string) { + strs := strings.Split(addr, ":") + if len(strs) > 0 && strs[0] != "" { + app.Cfg.Listen.HTTPAddr = strs[0] + app.Cfg.Listen.Domains = []string{strs[0]} + } + if len(strs) > 1 && strs[1] != "" { + app.Cfg.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) + } +} + +func (app *HttpServer) LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { + // Skip logging if AccessLogs config is false + if !app.Cfg.Log.AccessLogs { + return + } + // Skip logging static requests unless EnableStaticLogs config is true + if !app.Cfg.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) { + return + } + var ( + requestTime time.Time + elapsedTime time.Duration + r = ctx.Request + ) + if startTime != nil { + requestTime = *startTime + elapsedTime = time.Since(*startTime) + } + record := &logs.AccessLogRecord{ + RemoteAddr: ctx.Input.IP(), + RequestTime: requestTime, + RequestMethod: r.Method, + Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto), + ServerProtocol: r.Proto, + Host: r.Host, + Status: statusCode, + ElapsedTime: elapsedTime, + HTTPReferrer: r.Header.Get("Referer"), + HTTPUserAgent: r.Header.Get("User-Agent"), + RemoteUser: r.Header.Get("Remote-User"), + BodyBytesSent: r.ContentLength, + } + logs.AccessLog(record, app.Cfg.Log.AccessLogsFormat) } diff --git a/pkg/server/web/server_test.go b/pkg/server/web/server_test.go new file mode 100644 index 00000000..45ab2d4f --- /dev/null +++ b/pkg/server/web/server_test.go @@ -0,0 +1,31 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewHttpServerWithCfg(t *testing.T) { + // we should make sure that update server's config won't change + BConfig.AppName = "Before" + svr := NewHttpServerWithCfg(*BConfig) + svr.Cfg.AppName = "hello" + assert.NotEqual(t, "hello", BConfig.AppName) + assert.Equal(t, "Before", BConfig.AppName) + +} diff --git a/pkg/server/web/template.go b/pkg/server/web/template.go index a4b8db99..1192a3f2 100644 --- a/pkg/server/web/template.go +++ b/pkg/server/web/template.go @@ -368,14 +368,14 @@ func SetTemplateFSFunc(fnt templateFSFunc) { } // SetViewsPath sets view directory path in beego application. -func SetViewsPath(path string) *App { +func SetViewsPath(path string) *HttpServer { BConfig.WebConfig.ViewsPath = path return BeeApp } // SetStaticPath sets static directory path and proper url pattern in beego application. // if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public". -func SetStaticPath(url string, path string) *App { +func SetStaticPath(url string, path string) *HttpServer { if !strings.HasPrefix(url, "/") { url = "/" + url } @@ -387,7 +387,7 @@ func SetStaticPath(url string, path string) *App { } // DelStaticPath removes the static folder setting in this url pattern in beego application. -func DelStaticPath(url string) *App { +func DelStaticPath(url string) *HttpServer { if !strings.HasPrefix(url, "/") { url = "/" + url } @@ -399,7 +399,7 @@ func DelStaticPath(url string) *App { } // AddTemplateEngine add a new templatePreProcessor which support extension -func AddTemplateEngine(extension string, fn templatePreProcessor) *App { +func AddTemplateEngine(extension string, fn templatePreProcessor) *HttpServer { AddTemplateExt(extension) beeTemplateEngines[extension] = fn return BeeApp From 2473e6941758a0a427c85cca45d92a53d9b287f4 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Thu, 10 Sep 2020 22:17:15 +0800 Subject: [PATCH 171/207] Rewrite admin service by using multiple server feature --- pkg/adapter/admin.go | 7 +- pkg/server/web/admin.go | 386 ++--------------------------- pkg/server/web/admin_controller.go | 305 +++++++++++++++++++++++ pkg/server/web/admin_test.go | 4 +- pkg/server/web/hooks.go | 8 - pkg/server/web/server.go | 89 +++++++ 6 files changed, 421 insertions(+), 378 deletions(-) create mode 100644 pkg/server/web/admin_controller.go diff --git a/pkg/adapter/admin.go b/pkg/adapter/admin.go index 87e7259b..3127416f 100644 --- a/pkg/adapter/admin.go +++ b/pkg/adapter/admin.go @@ -17,6 +17,7 @@ package adapter import ( "time" + _ "github.com/astaxie/beego/pkg/infrastructure/governor" "github.com/astaxie/beego/pkg/server/web" ) @@ -38,11 +39,7 @@ import ( // beego.FilterMonitorFunc = MyFilterMonitor. var FilterMonitorFunc func(string, string, time.Duration, string, int) bool -func init() { - FilterMonitorFunc = web.FilterMonitorFunc -} - // PrintTree prints all registered routers. func PrintTree() M { - return (M)(web.PrintTree()) + return (M)(web.BeeApp.PrintTree()) } diff --git a/pkg/server/web/admin.go b/pkg/server/web/admin.go index f54ac9e5..148ab806 100644 --- a/pkg/server/web/admin.go +++ b/pkg/server/web/admin.go @@ -15,24 +15,12 @@ package web import ( - "bytes" - context2 "context" - "encoding/json" "fmt" "net/http" - "os" "reflect" - "strconv" - "text/template" "time" - "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/astaxie/beego/pkg/infrastructure/logs" - - "github.com/astaxie/beego/pkg/infrastructure/governor" - "github.com/astaxie/beego/pkg/infrastructure/utils" - "github.com/astaxie/beego/pkg/server/web/grace" "github.com/astaxie/beego/pkg/task" ) @@ -58,127 +46,23 @@ var beeAdminApp *adminApp var FilterMonitorFunc func(string, string, time.Duration, string, int) bool func init() { + c := &adminController{ + servers: make([]*HttpServer, 0, 2), + } beeAdminApp = &adminApp{ - routers: make(map[string]http.HandlerFunc), + HttpServer: NewHttpServerWithCfg(*BConfig), } // keep in mind that all data should be html escaped to avoid XSS attack - beeAdminApp.Route("/", adminIndex) - beeAdminApp.Route("/qps", qpsIndex) - beeAdminApp.Route("/prof", profIndex) - beeAdminApp.Route("/healthcheck", healthcheck) - beeAdminApp.Route("/task", taskStatus) - beeAdminApp.Route("/listconf", listConf) - beeAdminApp.Route("/metrics", promhttp.Handler().ServeHTTP) + beeAdminApp.Router("/", c, "get:AdminIndex") + beeAdminApp.Router("/qps", c, "get:QpsIndex") + beeAdminApp.Router("/prof", c, "get:ProfIndex") + beeAdminApp.Router("/healthcheck", c, "get:Healthcheck") + beeAdminApp.Router("/task", c, "get:TaskStatus") + beeAdminApp.Router("/listconf", c, "get:ListConf") + beeAdminApp.Router("/metrics", c, "get:PrometheusMetrics") FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true } -} -// AdminIndex is the default http.Handler for admin module. -// it matches url pattern "/". -func adminIndex(rw http.ResponseWriter, _ *http.Request) { - writeTemplate(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) -} - -// QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter. -// it's registered with url pattern "/qps" in admin module. -func qpsIndex(rw http.ResponseWriter, _ *http.Request) { - data := make(map[interface{}]interface{}) - data["Content"] = StatisticsMap.GetMap() - - // do html escape before display path, avoid xss - if content, ok := (data["Content"]).(M); ok { - if resultLists, ok := (content["Data"]).([][]string); ok { - for i := range resultLists { - if len(resultLists[i]) > 0 { - resultLists[i][0] = template.HTMLEscapeString(resultLists[i][0]) - } - } - } - } - - writeTemplate(rw, data, qpsTpl, defaultScriptsTpl) -} - -// ListConf is the http.Handler of displaying all beego configuration values as key/value pair. -// it's registered with url pattern "/listconf" in admin module. -func listConf(rw http.ResponseWriter, r *http.Request) { - r.ParseForm() - command := r.Form.Get("command") - if command == "" { - rw.Write([]byte("command not support")) - return - } - - data := make(map[interface{}]interface{}) - switch command { - case "conf": - m := make(M) - list("BConfig", BConfig, m) - m["AppConfigPath"] = template.HTMLEscapeString(appConfigPath) - m["AppConfigProvider"] = template.HTMLEscapeString(appConfigProvider) - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - tmpl = template.Must(tmpl.Parse(configTpl)) - tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) - - data["Content"] = m - - tmpl.Execute(rw, data) - - case "router": - content := PrintTree() - content["Fields"] = []string{ - "Router Pattern", - "Methods", - "Controller", - } - data["Content"] = content - data["Title"] = "Routers" - writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) - case "filter": - var ( - content = M{ - "Fields": []string{ - "Router Pattern", - "Filter Function", - }, - } - filterTypes = []string{} - filterTypeData = make(M) - ) - - if BeeApp.Handlers.enableFilter { - var filterType string - for k, fr := range map[int]string{ - BeforeStatic: "Before Static", - BeforeRouter: "Before Router", - BeforeExec: "Before Exec", - AfterExec: "After Exec", - FinishRouter: "Finish Router"} { - if bf := BeeApp.Handlers.filters[k]; len(bf) > 0 { - filterType = fr - filterTypes = append(filterTypes, filterType) - resultList := new([][]string) - for _, f := range bf { - var result = []string{ - // void xss - template.HTMLEscapeString(f.pattern), - template.HTMLEscapeString(utils.GetFuncName(f.filterFunc)), - } - *resultList = append(*resultList, result) - } - filterTypeData[filterType] = resultList - } - } - } - - content["Data"] = filterTypeData - content["Methods"] = filterTypes - - data["Content"] = content - data["Title"] = "Filters" - writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) - default: - rw.Write([]byte("command not support")) - } + beeAdminApp.Run() } func list(root string, p interface{}, m M) { @@ -203,239 +87,19 @@ func list(root string, p interface{}, m M) { } } -// PrintTree prints all registered routers. -func PrintTree() M { - var ( - content = M{} - methods = []string{} - methodsData = make(M) - ) - for method, t := range BeeApp.Handlers.routers { - - resultList := new([][]string) - - printTree(resultList, t) - - methods = append(methods, template.HTMLEscapeString(method)) - methodsData[template.HTMLEscapeString(method)] = resultList - } - - content["Data"] = methodsData - content["Methods"] = methods - return content -} - -func printTree(resultList *[][]string, t *Tree) { - for _, tr := range t.fixrouters { - printTree(resultList, tr) - } - if t.wildcard != nil { - printTree(resultList, t.wildcard) - } - for _, l := range t.leaves { - if v, ok := l.runObject.(*ControllerInfo); ok { - if v.routerType == routerTypeBeego { - var result = []string{ - template.HTMLEscapeString(v.pattern), - template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), - template.HTMLEscapeString(v.controllerType.String()), - } - *resultList = append(*resultList, result) - } else if v.routerType == routerTypeRESTFul { - var result = []string{ - template.HTMLEscapeString(v.pattern), - template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), - "", - } - *resultList = append(*resultList, result) - } else if v.routerType == routerTypeHandler { - var result = []string{ - template.HTMLEscapeString(v.pattern), - "", - "", - } - *resultList = append(*resultList, result) - } - } - } -} - -// ProfIndex is a http.Handler for showing profile command. -// it's in url pattern "/prof" in admin module. -func profIndex(rw http.ResponseWriter, r *http.Request) { - r.ParseForm() - command := r.Form.Get("command") - if command == "" { - return - } - - var ( - format = r.Form.Get("format") - data = make(map[interface{}]interface{}) - result bytes.Buffer - ) - governor.ProcessInput(command, &result) - data["Content"] = template.HTMLEscapeString(result.String()) - - if format == "json" && command == "gc summary" { - dataJSON, err := json.Marshal(data) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - writeJSON(rw, dataJSON) - return - } - - data["Title"] = template.HTMLEscapeString(command) - defaultTpl := defaultScriptsTpl - if command == "gc summary" { - defaultTpl = gcAjaxTpl - } - writeTemplate(rw, data, profillingTpl, defaultTpl) -} - -// Healthcheck is a http.Handler calling health checking and showing the result. -// it's in "/healthcheck" pattern in admin module. -func healthcheck(rw http.ResponseWriter, r *http.Request) { - var ( - result []string - data = make(map[interface{}]interface{}) - resultList = new([][]string) - content = M{ - "Fields": []string{"Name", "Message", "Status"}, - } - ) - - for name, h := range governor.AdminCheckList { - if err := h.Check(); err != nil { - result = []string{ - "error", - template.HTMLEscapeString(name), - template.HTMLEscapeString(err.Error()), - } - } else { - result = []string{ - "success", - template.HTMLEscapeString(name), - "OK", - } - } - *resultList = append(*resultList, result) - } - - queryParams := r.URL.Query() - jsonFlag := queryParams.Get("json") - shouldReturnJSON, _ := strconv.ParseBool(jsonFlag) - - if shouldReturnJSON { - response := buildHealthCheckResponseList(resultList) - jsonResponse, err := json.Marshal(response) - - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - } else { - writeJSON(rw, jsonResponse) - } - return - } - - content["Data"] = resultList - data["Content"] = content - data["Title"] = "Health Check" - - writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) -} - -func buildHealthCheckResponseList(healthCheckResults *[][]string) []map[string]interface{} { - response := make([]map[string]interface{}, len(*healthCheckResults)) - - for i, healthCheckResult := range *healthCheckResults { - currentResultMap := make(map[string]interface{}) - - currentResultMap["name"] = healthCheckResult[0] - currentResultMap["message"] = healthCheckResult[1] - currentResultMap["status"] = healthCheckResult[2] - - response[i] = currentResultMap - } - - return response - -} - func writeJSON(rw http.ResponseWriter, jsonData []byte) { rw.Header().Set("Content-Type", "application/json") rw.Write(jsonData) } -// TaskStatus is a http.Handler with running task status (task name, status and the last execution). -// it's in "/task" pattern in admin module. -func taskStatus(rw http.ResponseWriter, req *http.Request) { - data := make(map[interface{}]interface{}) - - // Run Task - req.ParseForm() - taskname := req.Form.Get("taskname") - if taskname != "" { - if t, ok := task.AdminTaskList[taskname]; ok { - if err := t.Run(nil); err != nil { - data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", err))} - } - data["Message"] = []string{"success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus(nil)))} - } else { - data["Message"] = []string{"warning", template.HTMLEscapeString(fmt.Sprintf("there's no task which named: %s", taskname))} - } - } - - // List Tasks - content := make(M) - resultList := new([][]string) - var fields = []string{ - "Task Name", - "Task Spec", - "Task Status", - "Last Time", - "", - } - for tname, tk := range task.AdminTaskList { - result := []string{ - template.HTMLEscapeString(tname), - template.HTMLEscapeString(tk.GetSpec(nil)), - template.HTMLEscapeString(tk.GetStatus(nil)), - template.HTMLEscapeString(tk.GetPrev(context2.Background()).String()), - } - *resultList = append(*resultList, result) - } - - content["Fields"] = fields - content["Data"] = resultList - data["Content"] = content - data["Title"] = "Tasks" - writeTemplate(rw, data, tasksTpl, defaultScriptsTpl) -} - -func writeTemplate(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - for _, tpl := range tpls { - tmpl = template.Must(tmpl.Parse(tpl)) - } - tmpl.Execute(rw, data) -} - // adminApp is an http.HandlerFunc map used as beeAdminApp. type adminApp struct { - routers map[string]http.HandlerFunc + *HttpServer } // Route adds http.HandlerFunc to adminApp with url pattern. -func (admin *adminApp) Route(pattern string, f http.HandlerFunc) { - admin.routers[pattern] = f -} - -// Run adminApp http server. -// Its addr is defined in configuration file as adminhttpaddr and adminhttpport. func (admin *adminApp) Run() { + if len(task.AdminTaskList) > 0 { task.StartTask() } @@ -444,18 +108,14 @@ func (admin *adminApp) Run() { if BConfig.Listen.AdminPort != 0 { addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort) } - for p, f := range admin.routers { - http.Handle(p, f) - } - logs.Info("Admin server Running on %s", addr) - var err error - if BConfig.Listen.Graceful { - err = grace.ListenAndServe(addr, nil) - } else { - err = http.ListenAndServe(addr, nil) - } - if err != nil { - logs.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) - } + logs.Info("Admin server Running on %s", addr) + admin.HttpServer.Run(addr) +} + +func registerAdmin() error { + if BConfig.Listen.EnableAdmin { + go beeAdminApp.Run() + } + return nil } diff --git a/pkg/server/web/admin_controller.go b/pkg/server/web/admin_controller.go new file mode 100644 index 00000000..c53c54cf --- /dev/null +++ b/pkg/server/web/admin_controller.go @@ -0,0 +1,305 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "bytes" + context2 "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "text/template" + + "github.com/prometheus/client_golang/prometheus/promhttp" + + "github.com/astaxie/beego/pkg/infrastructure/governor" + "github.com/astaxie/beego/pkg/task" +) + +type adminController struct { + Controller + servers []*HttpServer +} + +func (a *adminController) registerHttpServer(svr *HttpServer) { + a.servers = append(a.servers, svr) +} + +// ProfIndex is a http.Handler for showing profile command. +// it's in url pattern "/prof" in admin module. +func (a *adminController) ProfIndex() { + rw, r := a.Ctx.ResponseWriter, a.Ctx.Request + r.ParseForm() + command := r.Form.Get("command") + if command == "" { + return + } + + var ( + format = r.Form.Get("format") + data = make(map[interface{}]interface{}) + result bytes.Buffer + ) + governor.ProcessInput(command, &result) + data["Content"] = template.HTMLEscapeString(result.String()) + + if format == "json" && command == "gc summary" { + dataJSON, err := json.Marshal(data) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(rw, dataJSON) + return + } + + data["Title"] = template.HTMLEscapeString(command) + defaultTpl := defaultScriptsTpl + if command == "gc summary" { + defaultTpl = gcAjaxTpl + } + writeTemplate(rw, data, profillingTpl, defaultTpl) +} + +func (a *adminController) PrometheusMetrics() { + promhttp.Handler().ServeHTTP(a.Ctx.ResponseWriter, a.Ctx.Request) +} + +// TaskStatus is a http.Handler with running task status (task name, status and the last execution). +// it's in "/task" pattern in admin module. +func (a *adminController) TaskStatus() { + + rw, req := a.Ctx.ResponseWriter, a.Ctx.Request + + data := make(map[interface{}]interface{}) + + // Run Task + req.ParseForm() + taskname := req.Form.Get("taskname") + if taskname != "" { + if t, ok := task.AdminTaskList[taskname]; ok { + if err := t.Run(nil); err != nil { + data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", err))} + } + data["Message"] = []string{"success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus(nil)))} + } else { + data["Message"] = []string{"warning", template.HTMLEscapeString(fmt.Sprintf("there's no task which named: %s", taskname))} + } + } + + // List Tasks + content := make(M) + resultList := new([][]string) + var fields = []string{ + "Task Name", + "Task Spec", + "Task Status", + "Last Time", + "", + } + for tname, tk := range task.AdminTaskList { + result := []string{ + template.HTMLEscapeString(tname), + template.HTMLEscapeString(tk.GetSpec(nil)), + template.HTMLEscapeString(tk.GetStatus(nil)), + template.HTMLEscapeString(tk.GetPrev(context2.Background()).String()), + } + *resultList = append(*resultList, result) + } + + content["Fields"] = fields + content["Data"] = resultList + data["Content"] = content + data["Title"] = "Tasks" + writeTemplate(rw, data, tasksTpl, defaultScriptsTpl) +} + +func (a *adminController) AdminIndex() { + // AdminIndex is the default http.Handler for admin module. + // it matches url pattern "/". + writeTemplate(a.Ctx.ResponseWriter, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) +} + +// Healthcheck is a http.Handler calling health checking and showing the result. +// it's in "/healthcheck" pattern in admin module. +func (a *adminController) Healthcheck() { + heathCheck(a.Ctx.ResponseWriter, a.Ctx.Request) +} + +func heathCheck(rw http.ResponseWriter, r *http.Request) { + var ( + result []string + data = make(map[interface{}]interface{}) + resultList = new([][]string) + content = M{ + "Fields": []string{"Name", "Message", "Status"}, + } + ) + + for name, h := range governor.AdminCheckList { + if err := h.Check(); err != nil { + result = []string{ + "error", + template.HTMLEscapeString(name), + template.HTMLEscapeString(err.Error()), + } + } else { + result = []string{ + "success", + template.HTMLEscapeString(name), + "OK", + } + } + *resultList = append(*resultList, result) + } + + queryParams := r.URL.Query() + jsonFlag := queryParams.Get("json") + shouldReturnJSON, _ := strconv.ParseBool(jsonFlag) + + if shouldReturnJSON { + response := buildHealthCheckResponseList(resultList) + jsonResponse, err := json.Marshal(response) + + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } else { + writeJSON(rw, jsonResponse) + } + return + } + + content["Data"] = resultList + data["Content"] = content + data["Title"] = "Health Check" + + writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) +} + +// QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter. +// it's registered with url pattern "/qps" in admin module. +func (a *adminController) QpsIndex() { + data := make(map[interface{}]interface{}) + data["Content"] = StatisticsMap.GetMap() + + // do html escape before display path, avoid xss + if content, ok := (data["Content"]).(M); ok { + if resultLists, ok := (content["Data"]).([][]string); ok { + for i := range resultLists { + if len(resultLists[i]) > 0 { + resultLists[i][0] = template.HTMLEscapeString(resultLists[i][0]) + } + } + } + } + writeTemplate(a.Ctx.ResponseWriter, data, qpsTpl, defaultScriptsTpl) +} + +// ListConf is the http.Handler of displaying all beego configuration values as key/value pair. +// it's registered with url pattern "/listconf" in admin module. +func (a *adminController) ListConf() { + rw := a.Ctx.ResponseWriter + r := a.Ctx.Request + r.ParseForm() + command := r.Form.Get("command") + if command == "" { + rw.Write([]byte("command not support")) + return + } + + data := make(map[interface{}]interface{}) + switch command { + case "conf": + m := make(M) + list("BConfig", BConfig, m) + m["appConfigPath"] = template.HTMLEscapeString(appConfigPath) + m["appConfigProvider"] = template.HTMLEscapeString(appConfigProvider) + tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) + tmpl = template.Must(tmpl.Parse(configTpl)) + tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) + + data["Content"] = m + + tmpl.Execute(rw, data) + + case "router": + content := BeeApp.PrintTree() + content["Fields"] = []string{ + "Router Pattern", + "Methods", + "Controller", + } + data["Content"] = content + data["Title"] = "Routers" + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) + case "filter": + var ( + content = M{ + "Fields": []string{ + "Router Pattern", + "Filter Function", + }, + } + ) + + filterTypeData := BeeApp.reportFilter() + + filterTypes := make([]string, 0, len(filterTypeData)) + for k, _ := range filterTypeData { + filterTypes = append(filterTypes, k) + } + + content["Data"] = filterTypeData + content["Methods"] = filterTypes + + data["Content"] = content + data["Title"] = "Filters" + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) + default: + rw.Write([]byte("command not support")) + } +} + +func writeTemplate(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { + tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) + for _, tpl := range tpls { + tmpl = template.Must(tmpl.Parse(tpl)) + } + tmpl.Execute(rw, data) +} + +func buildHealthCheckResponseList(healthCheckResults *[][]string) []map[string]interface{} { + response := make([]map[string]interface{}, len(*healthCheckResults)) + + for i, healthCheckResult := range *healthCheckResults { + currentResultMap := make(map[string]interface{}) + + currentResultMap["name"] = healthCheckResult[0] + currentResultMap["message"] = healthCheckResult[1] + currentResultMap["status"] = healthCheckResult[2] + + response[i] = currentResultMap + } + + return response + +} + +// PrintTree print all routers +// Deprecated using BeeApp directly +func PrintTree() M { + return BeeApp.PrintTree() +} diff --git a/pkg/server/web/admin_test.go b/pkg/server/web/admin_test.go index acc67aeb..d04ac319 100644 --- a/pkg/server/web/admin_test.go +++ b/pkg/server/web/admin_test.go @@ -136,7 +136,7 @@ func TestHealthCheckHandlerDefault(t *testing.T) { w := httptest.NewRecorder() - handler := http.HandlerFunc(healthcheck) + handler := http.HandlerFunc(heathCheck) handler.ServeHTTP(w, req) @@ -197,7 +197,7 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { w := httptest.NewRecorder() - handler := http.HandlerFunc(healthcheck) + handler := http.HandlerFunc(heathCheck) handler.ServeHTTP(w, req) if status := w.Code; status != http.StatusOK { diff --git a/pkg/server/web/hooks.go b/pkg/server/web/hooks.go index 080b2006..2f0cb159 100644 --- a/pkg/server/web/hooks.go +++ b/pkg/server/web/hooks.go @@ -9,7 +9,6 @@ import ( "github.com/astaxie/beego/pkg/infrastructure/logs" "github.com/astaxie/beego/pkg/infrastructure/session" - "github.com/astaxie/beego/pkg/server/web/context" ) @@ -87,13 +86,6 @@ func registerTemplate() error { return nil } -func registerAdmin() error { - if BConfig.Listen.EnableAdmin { - go beeAdminApp.Run() - } - return nil -} - func registerGzip() error { if BConfig.EnableGzip { context.InitGzip( diff --git a/pkg/server/web/server.go b/pkg/server/web/server.go index c3ab1696..2e91e33c 100644 --- a/pkg/server/web/server.go +++ b/pkg/server/web/server.go @@ -26,6 +26,7 @@ import ( "path" "strconv" "strings" + "text/template" "time" "golang.org/x/crypto/acme/autocert" @@ -72,6 +73,7 @@ func NewHttpServerWithCfg(cfg Config) *HttpServer { Server: &http.Server{}, Cfg: cfgPtr, } + return app } @@ -665,3 +667,90 @@ func (app *HttpServer) LogAccess(ctx *beecontext.Context, startTime *time.Time, } logs.AccessLog(record, app.Cfg.Log.AccessLogsFormat) } + +// PrintTree prints all registered routers. +func (app *HttpServer) PrintTree() M { + var ( + content = M{} + methods = []string{} + methodsData = make(M) + ) + for method, t := range app.Handlers.routers { + + resultList := new([][]string) + + printTree(resultList, t) + + methods = append(methods, template.HTMLEscapeString(method)) + methodsData[template.HTMLEscapeString(method)] = resultList + } + + content["Data"] = methodsData + content["Methods"] = methods + return content +} + +func printTree(resultList *[][]string, t *Tree) { + for _, tr := range t.fixrouters { + printTree(resultList, tr) + } + if t.wildcard != nil { + printTree(resultList, t.wildcard) + } + for _, l := range t.leaves { + if v, ok := l.runObject.(*ControllerInfo); ok { + if v.routerType == routerTypeBeego { + var result = []string{ + template.HTMLEscapeString(v.pattern), + template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), + template.HTMLEscapeString(v.controllerType.String()), + } + *resultList = append(*resultList, result) + } else if v.routerType == routerTypeRESTFul { + var result = []string{ + template.HTMLEscapeString(v.pattern), + template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), + "", + } + *resultList = append(*resultList, result) + } else if v.routerType == routerTypeHandler { + var result = []string{ + template.HTMLEscapeString(v.pattern), + "", + "", + } + *resultList = append(*resultList, result) + } + } + } +} + +func (app *HttpServer) reportFilter() M { + filterTypeData := make(M) + // filterTypes := []string{} + if app.Handlers.enableFilter { + // var filterType string + for k, fr := range map[int]string{ + BeforeStatic: "Before Static", + BeforeRouter: "Before Router", + BeforeExec: "Before Exec", + AfterExec: "After Exec", + FinishRouter: "Finish Router", + } { + if bf := app.Handlers.filters[k]; len(bf) > 0 { + resultList := new([][]string) + for _, f := range bf { + var result = []string{ + // void xss + template.HTMLEscapeString(f.pattern), + template.HTMLEscapeString(utils.GetFuncName(f.filterFunc)), + } + *resultList = append(*resultList, result) + } + filterTypeData[fr] = resultList + } + } + } + + return filterTypeData +} From e6a257f9870921b32c5bde0e041ab9f2bc68c8c7 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 20 Sep 2020 15:36:08 +0800 Subject: [PATCH 172/207] Fix BUG --- pkg/server/web/admin.go | 33 ++++++++++++++++++--------------- pkg/server/web/config.go | 6 +++--- pkg/server/web/router.go | 3 ++- pkg/server/web/server.go | 4 ---- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/pkg/server/web/admin.go b/pkg/server/web/admin.go index 148ab806..46c0f738 100644 --- a/pkg/server/web/admin.go +++ b/pkg/server/web/admin.go @@ -46,23 +46,9 @@ var beeAdminApp *adminApp var FilterMonitorFunc func(string, string, time.Duration, string, int) bool func init() { - c := &adminController{ - servers: make([]*HttpServer, 0, 2), - } - beeAdminApp = &adminApp{ - HttpServer: NewHttpServerWithCfg(*BConfig), - } - // keep in mind that all data should be html escaped to avoid XSS attack - beeAdminApp.Router("/", c, "get:AdminIndex") - beeAdminApp.Router("/qps", c, "get:QpsIndex") - beeAdminApp.Router("/prof", c, "get:ProfIndex") - beeAdminApp.Router("/healthcheck", c, "get:Healthcheck") - beeAdminApp.Router("/task", c, "get:TaskStatus") - beeAdminApp.Router("/listconf", c, "get:ListConf") - beeAdminApp.Router("/metrics", c, "get:PrometheusMetrics") + FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true } - beeAdminApp.Run() } func list(root string, p interface{}, m M) { @@ -110,11 +96,28 @@ func (admin *adminApp) Run() { } logs.Info("Admin server Running on %s", addr) + admin.HttpServer.Run(addr) } func registerAdmin() error { if BConfig.Listen.EnableAdmin { + + c := &adminController{ + servers: make([]*HttpServer, 0, 2), + } + beeAdminApp = &adminApp{ + HttpServer: NewHttpServerWithCfg(*BConfig), + } + // keep in mind that all data should be html escaped to avoid XSS attack + beeAdminApp.Router("/", c, "get:AdminIndex") + beeAdminApp.Router("/qps", c, "get:QpsIndex") + beeAdminApp.Router("/prof", c, "get:ProfIndex") + beeAdminApp.Router("/healthcheck", c, "get:Healthcheck") + beeAdminApp.Router("/task", c, "get:TaskStatus") + beeAdminApp.Router("/listconf", c, "get:ListConf") + beeAdminApp.Router("/metrics", c, "get:PrometheusMetrics") + go beeAdminApp.Run() } return nil diff --git a/pkg/server/web/config.go b/pkg/server/web/config.go index add81b8c..bc46b20e 100644 --- a/pkg/server/web/config.go +++ b/pkg/server/web/config.go @@ -41,7 +41,7 @@ type Config struct { RouterCaseSensitive bool ServerName string RecoverPanic bool - RecoverFunc func(*context.Context) + RecoverFunc func(*context.Context, *Config) CopyRequestBody bool EnableGzip bool MaxMemory int64 @@ -169,7 +169,7 @@ func init() { } } -func (cfg *Config) defaultRecoverPanic(ctx *context.Context) { +func defaultRecoverPanic(ctx *context.Context, cfg *Config) { if err := recover(); err != nil { if err == ErrAbort { return @@ -281,7 +281,7 @@ func newBConfig() *Config { }, } - res.RecoverFunc = res.defaultRecoverPanic + res.RecoverFunc = defaultRecoverPanic return res } diff --git a/pkg/server/web/router.go b/pkg/server/web/router.go index 1a0183bd..a9d1b0cf 100644 --- a/pkg/server/web/router.go +++ b/pkg/server/web/router.go @@ -155,6 +155,7 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { return beecontext.NewContext() }, }, + cfg: cfg, } res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) return res @@ -678,7 +679,7 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { ) if p.cfg.RecoverFunc != nil { - defer p.cfg.RecoverFunc(ctx) + defer p.cfg.RecoverFunc(ctx, p.cfg) } ctx.Output.EnableGzip = p.cfg.EnableGzip diff --git a/pkg/server/web/server.go b/pkg/server/web/server.go index 2e91e33c..7bd9023d 100644 --- a/pkg/server/web/server.go +++ b/pkg/server/web/server.go @@ -266,10 +266,6 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { <-endRunning } -func (app *HttpServer) Start() { - -} - // Router see HttpServer.Router func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer { return BeeApp.Router(rootpath, c, mappingMethods...) From 2846043f2a742262057539d2cd7a9f176377a42b Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 20 Sep 2020 14:27:30 +0000 Subject: [PATCH 173/207] Fix UT --- pkg/server/web/router_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkg/server/web/router_test.go b/pkg/server/web/router_test.go index da8efbcb..a59cde8b 100644 --- a/pkg/server/web/router_test.go +++ b/pkg/server/web/router_test.go @@ -731,6 +731,8 @@ func TestRouterEntityTooLargeCopyBody(t *testing.T) { BConfig.CopyRequestBody = true BConfig.MaxMemory = 20 + BeeApp.Cfg.CopyRequestBody = true + BeeApp.Cfg.MaxMemory = 20 b := bytes.NewBuffer([]byte("barbarbarbarbarbarbarbarbarbar")) r, _ := http.NewRequest("POST", "/user/123", b) w := httptest.NewRecorder() From 44127edefc069882605031afcd2e310b24548bc7 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 20 Sep 2020 14:18:13 +0000 Subject: [PATCH 174/207] design Command for governor module & decouple web module from task module --- pkg/infrastructure/governor/command.go | 87 +++++++++++++++++++ pkg/server/web/admin.go | 10 ++- pkg/server/web/admin_controller.go | 28 +++---- pkg/task/govenor_command.go | 92 ++++++++++++++++++++ pkg/task/governor_command_test.go | 111 +++++++++++++++++++++++++ pkg/task/task.go | 8 +- 6 files changed, 311 insertions(+), 25 deletions(-) create mode 100644 pkg/infrastructure/governor/command.go create mode 100644 pkg/task/govenor_command.go create mode 100644 pkg/task/governor_command_test.go diff --git a/pkg/infrastructure/governor/command.go b/pkg/infrastructure/governor/command.go new file mode 100644 index 00000000..75df5815 --- /dev/null +++ b/pkg/infrastructure/governor/command.go @@ -0,0 +1,87 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package governor + +import ( + "github.com/pkg/errors" +) + +// Command is an experimental interface +// We try to use this to decouple modules +// All other modules depends on this, and they register the command they support +// We may change the API in the future, so be careful about this. +type Command interface { + Execute(params ...interface{}) *Result +} + +var CommandNotFound = errors.New("Command not found") + +type Result struct { + // Status is the same as http.Status + Status int + Error error + Content interface{} +} + +func (r *Result) IsSuccess() bool { + return r.Status >= 200 && r.Status < 300 +} + +// CommandRegistry stores all commands +// name => command +type moduleCommands map[string]Command + +// Get returns command with the name +func (m moduleCommands) Get(name string) Command { + c, ok := m[name] + if ok { + return c + } + return &doNothingCommand{} +} + +// module name => moduleCommand +type commandRegistry map[string]moduleCommands + +// Get returns module's commands +func (c commandRegistry) Get(moduleName string) moduleCommands { + if mcs, ok := c[moduleName]; ok { + return mcs + } + res := make(moduleCommands) + c[moduleName] = res + return res +} + +var cmdRegistry = make(commandRegistry) + +// RegisterCommand is not thread-safe +// do not use it in concurrent case +func RegisterCommand(module string, commandName string, command Command) { + cmdRegistry.Get(module)[commandName] = command +} + +func GetCommand(module string, cmdName string) Command { + return cmdRegistry.Get(module).Get(cmdName) +} + +type doNothingCommand struct{} + +func (d *doNothingCommand) Execute(params ...interface{}) *Result { + return &Result{ + Status: 404, + Error: CommandNotFound, + } +} diff --git a/pkg/server/web/admin.go b/pkg/server/web/admin.go index 46c0f738..084190a9 100644 --- a/pkg/server/web/admin.go +++ b/pkg/server/web/admin.go @@ -21,7 +21,6 @@ import ( "time" "github.com/astaxie/beego/pkg/infrastructure/logs" - "github.com/astaxie/beego/pkg/task" ) // BeeAdminApp is the default adminApp used by admin module. @@ -86,9 +85,12 @@ type adminApp struct { // Route adds http.HandlerFunc to adminApp with url pattern. func (admin *adminApp) Run() { - if len(task.AdminTaskList) > 0 { - task.StartTask() - } + // if len(task.AdminTaskList) > 0 { + // task.StartTask() + // } + logs.Warning("now we don't start tasks here, if you use task module," + + " please invoke task.StartTask, or task will not be executed") + addr := BConfig.Listen.AdminAddr if BConfig.Listen.AdminPort != 0 { diff --git a/pkg/server/web/admin_controller.go b/pkg/server/web/admin_controller.go index c53c54cf..dc3a40b5 100644 --- a/pkg/server/web/admin_controller.go +++ b/pkg/server/web/admin_controller.go @@ -16,7 +16,6 @@ package web import ( "bytes" - context2 "context" "encoding/json" "fmt" "net/http" @@ -26,7 +25,6 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/astaxie/beego/pkg/infrastructure/governor" - "github.com/astaxie/beego/pkg/task" ) type adminController struct { @@ -90,19 +88,22 @@ func (a *adminController) TaskStatus() { req.ParseForm() taskname := req.Form.Get("taskname") if taskname != "" { - if t, ok := task.AdminTaskList[taskname]; ok { - if err := t.Run(nil); err != nil { - data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", err))} - } - data["Message"] = []string{"success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus(nil)))} + cmd := governor.GetCommand("task", "run") + res := cmd.Execute(taskname) + if res.IsSuccess() { + + data["Message"] = []string{"success", + template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", + taskname, res.Content.(string)))} + } else { - data["Message"] = []string{"warning", template.HTMLEscapeString(fmt.Sprintf("there's no task which named: %s", taskname))} + data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", res.Error))} } } // List Tasks content := make(M) - resultList := new([][]string) + resultList := governor.GetCommand("task", "list").Execute().Content.([][]string) var fields = []string{ "Task Name", "Task Spec", @@ -110,15 +111,6 @@ func (a *adminController) TaskStatus() { "Last Time", "", } - for tname, tk := range task.AdminTaskList { - result := []string{ - template.HTMLEscapeString(tname), - template.HTMLEscapeString(tk.GetSpec(nil)), - template.HTMLEscapeString(tk.GetStatus(nil)), - template.HTMLEscapeString(tk.GetPrev(context2.Background()).String()), - } - *resultList = append(*resultList, result) - } content["Fields"] = fields content["Data"] = resultList diff --git a/pkg/task/govenor_command.go b/pkg/task/govenor_command.go new file mode 100644 index 00000000..fff08374 --- /dev/null +++ b/pkg/task/govenor_command.go @@ -0,0 +1,92 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package task + +import ( + "context" + "fmt" + "html/template" + + "github.com/pkg/errors" + + "github.com/astaxie/beego/pkg/infrastructure/governor" +) + +type listTaskCommand struct { +} + +func (l *listTaskCommand) Execute(params ...interface{}) *governor.Result { + resultList := make([][]string, 0, len(AdminTaskList)) + for tname, tk := range AdminTaskList { + result := []string{ + template.HTMLEscapeString(tname), + template.HTMLEscapeString(tk.GetSpec(nil)), + template.HTMLEscapeString(tk.GetStatus(nil)), + template.HTMLEscapeString(tk.GetPrev(context.Background()).String()), + } + resultList = append(resultList, result) + } + + return &governor.Result{ + Status: 200, + Content: resultList, + } +} + +type runTaskCommand struct { +} + +func (r *runTaskCommand) Execute(params ...interface{}) *governor.Result { + if len(params) == 0 { + return &governor.Result{ + Status: 400, + Error: errors.New("task name not passed"), + } + } + + tn, ok := params[0].(string) + + if !ok { + return &governor.Result{ + Status: 400, + Error: errors.New("parameter is invalid"), + } + } + + if t, ok := AdminTaskList[tn]; ok { + err := t.Run(context.Background()) + if err != nil { + return &governor.Result{ + Status: 500, + Error: err, + } + } + return &governor.Result{ + Status: 200, + Content: t.GetStatus(context.Background()), + } + } else { + return &governor.Result{ + Status: 400, + Error: errors.New(fmt.Sprintf("task with name %s not found", tn)), + } + } + +} + +func registerCommands() { + governor.RegisterCommand("task", "list", &listTaskCommand{}) + governor.RegisterCommand("task", "run", &runTaskCommand{}) +} diff --git a/pkg/task/governor_command_test.go b/pkg/task/governor_command_test.go new file mode 100644 index 00000000..00ed37f2 --- /dev/null +++ b/pkg/task/governor_command_test.go @@ -0,0 +1,111 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package task + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type countTask struct { + cnt int + mockErr error +} + +func (c *countTask) GetSpec(ctx context.Context) string { + return "AAA" +} + +func (c *countTask) GetStatus(ctx context.Context) string { + return "SUCCESS" +} + +func (c *countTask) Run(ctx context.Context) error { + c.cnt++ + return c.mockErr +} + +func (c *countTask) SetNext(ctx context.Context, time time.Time) { +} + +func (c *countTask) GetNext(ctx context.Context) time.Time { + return time.Now() +} + +func (c *countTask) SetPrev(ctx context.Context, time time.Time) { +} + +func (c *countTask) GetPrev(ctx context.Context) time.Time { + return time.Now() +} + +func TestRunTaskCommand_Execute(t *testing.T) { + task := &countTask{} + AddTask("count", task) + + cmd := &runTaskCommand{} + + res := cmd.Execute() + assert.NotNil(t, res) + assert.NotNil(t, res.Error) + assert.Equal(t, "task name not passed", res.Error.Error()) + + res = cmd.Execute(10) + assert.NotNil(t, res) + assert.NotNil(t, res.Error) + assert.Equal(t, "parameter is invalid", res.Error.Error()) + + res = cmd.Execute("CCCC") + assert.NotNil(t, res) + assert.NotNil(t, res.Error) + assert.Equal(t, "task with name CCCC not found", res.Error.Error()) + + res = cmd.Execute("count") + assert.NotNil(t, res) + assert.True(t, res.IsSuccess()) + + task.mockErr = errors.New("mock error") + res = cmd.Execute("count") + assert.NotNil(t, res) + assert.NotNil(t, res.Error) + assert.Equal(t, "mock error", res.Error.Error()) +} + +func TestListTaskCommand_Execute(t *testing.T) { + task := &countTask{} + + cmd := &listTaskCommand{} + + res := cmd.Execute() + + assert.True(t, res.IsSuccess()) + + _, ok := res.Content.([][]string) + assert.True(t, ok) + + AddTask("count", task) + + res = cmd.Execute() + + assert.True(t, res.IsSuccess()) + + rl, ok := res.Content.([][]string) + assert.True(t, ok) + assert.Equal(t, 1, len(rl)) +} diff --git a/pkg/task/task.go b/pkg/task/task.go index bcadb956..e3a8bba4 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -189,9 +189,9 @@ func (t *Task) GetPrev(context.Context) time.Time { // SetCron some signals: // *: any time // ,:  separate signal -//   -:duration +//    -:duration // /n : do as n times of time duration -///////////////////////////////////////////////////////// +// /////////////////////////////////////////////////////// // 0/30 * * * * * every 30s // 0 43 21 * * * 21:43 // 0 15 05 * * *    05:15 @@ -401,10 +401,12 @@ func StartTask() { taskLock.Lock() defer taskLock.Unlock() if isstart { - //If already started, no need to start another goroutine. + // If already started, no need to start another goroutine. return } isstart = true + + registerCommands() go run() } From 03498529b9a6bccc6571811dcf8c4e4808fe0b02 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 22 Sep 2020 22:58:58 +0800 Subject: [PATCH 175/207] Decouple web module from cache module --- pkg/server/web/captcha/captcha.go | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/pkg/server/web/captcha/captcha.go b/pkg/server/web/captcha/captcha.go index 2ae1fb8f..2c60f23a 100644 --- a/pkg/server/web/captcha/captcha.go +++ b/pkg/server/web/captcha/captcha.go @@ -68,7 +68,6 @@ import ( "github.com/astaxie/beego/pkg/infrastructure/logs" - "github.com/astaxie/beego/pkg/client/cache" "github.com/astaxie/beego/pkg/infrastructure/utils" "github.com/astaxie/beego/pkg/server/web" "github.com/astaxie/beego/pkg/server/web/context" @@ -91,7 +90,7 @@ const ( // Captcha struct type Captcha struct { // beego cache store - store cache.Cache + store Storage // url prefix for captcha image URLPrefix string @@ -232,7 +231,7 @@ func (c *Captcha) Verify(id string, challenge string) (success bool) { } // NewCaptcha create a new captcha.Captcha -func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { +func NewCaptcha(urlPrefix string, store Storage) *Captcha { cpt := &Captcha{} cpt.store = store cpt.FieldIDName = fieldIDName @@ -258,7 +257,7 @@ func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { // NewWithFilter create a new captcha.Captcha and auto AddFilter for serve captacha image // and add a template func for output html -func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { +func NewWithFilter(urlPrefix string, store Storage) *Captcha { cpt := NewCaptcha(urlPrefix, store) // create filter for serve captcha image @@ -269,3 +268,12 @@ func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { return cpt } + +type Storage interface { + // Get a cached value by key. + Get(key string) interface{} + // Set a cached value with key and expire time. + Put(key string, val interface{}, timeout time.Duration) error + // Delete cached value by key. + Delete(key string) error +} From 463e96447a7c2db4a14ea23a622449e6d8ab65ca Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 27 Sep 2020 00:37:46 +0800 Subject: [PATCH 176/207] decouple httplib from web module --- pkg/client/httplib/filter/prometheus/filter.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pkg/client/httplib/filter/prometheus/filter.go b/pkg/client/httplib/filter/prometheus/filter.go index 917d4720..b4a418e0 100644 --- a/pkg/client/httplib/filter/prometheus/filter.go +++ b/pkg/client/httplib/filter/prometheus/filter.go @@ -23,11 +23,13 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/astaxie/beego/pkg/client/httplib" - "github.com/astaxie/beego/pkg/server/web" ) type FilterChainBuilder struct { summaryVec prometheus.ObserverVec + AppName string + ServerName string + RunMode string } func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter { @@ -36,9 +38,9 @@ func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filt Name: "beego", Subsystem: "remote_http_request", ConstLabels: map[string]string{ - "server": web.BConfig.ServerName, - "env": web.BConfig.RunMode, - "appname": web.BConfig.AppName, + "server": builder.ServerName, + "env": builder.RunMode, + "appname": builder.AppName, }, Help: "The statics info for remote http requests", }, []string{"proto", "scheme", "method", "host", "path", "status", "duration", "isError"}) From dd3f1ce9be8c6d33ad4ad379d4aadcdf000e1e15 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 27 Sep 2020 00:44:02 +0800 Subject: [PATCH 177/207] decouple httplib from config --- pkg/client/httplib/testing/client.go | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/pkg/client/httplib/testing/client.go b/pkg/client/httplib/testing/client.go index 00fa3059..107b28cc 100644 --- a/pkg/client/httplib/testing/client.go +++ b/pkg/client/httplib/testing/client.go @@ -16,8 +16,6 @@ package testing import ( "github.com/astaxie/beego/pkg/client/httplib" - - "github.com/astaxie/beego/pkg/infrastructure/config" ) var port = "" @@ -28,16 +26,13 @@ type TestHTTPRequest struct { httplib.BeegoHTTPRequest } +func SetTestingPort(p string) { + port = p +} + func getPort() string { if port == "" { - config, err := config.NewConfig("ini", "../conf/app.conf") - if err != nil { - return "8080" - } - port, err = config.String(nil, "httpport") - if err != nil { - return "8080" - } + port = "8080" return port } return port From c5d43e87fe89c7beebf202e646be436996b6696d Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sun, 4 Oct 2020 22:16:19 +0800 Subject: [PATCH 178/207] seperate orm alone --- pkg/client/orm/filter/prometheus/filter.go | 22 ++++++++----------- .../orm/filter/prometheus/filter_test.go | 13 ++++++----- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/pkg/client/orm/filter/prometheus/filter.go b/pkg/client/orm/filter/prometheus/filter.go index 175b26be..2d819ef7 100644 --- a/pkg/client/orm/filter/prometheus/filter.go +++ b/pkg/client/orm/filter/prometheus/filter.go @@ -23,7 +23,6 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/astaxie/beego/pkg/client/orm" - "github.com/astaxie/beego/pkg/server/web" ) // FilterChainBuilder is an extension point, @@ -35,27 +34,24 @@ import ( // actually we only records metrics of invoking "QueryTable" and "QueryTableWithCtx" type FilterChainBuilder struct { summaryVec prometheus.ObserverVec + AppName string + ServerName string + RunMode string } -func NewFilterChainBuilder() *FilterChainBuilder { - summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ +func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { + + builder.summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{ Name: "beego", Subsystem: "orm_operation", ConstLabels: map[string]string{ - "server": web.BConfig.ServerName, - "env": web.BConfig.RunMode, - "appname": web.BConfig.AppName, + "server": builder.ServerName, + "env": builder.RunMode, + "appname": builder.AppName, }, Help: "The statics info for orm operation", }, []string{"method", "name", "duration", "insideTx", "txName"}) - prometheus.MustRegister(summaryVec) - return &FilterChainBuilder{ - summaryVec: summaryVec, - } -} - -func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { return func(ctx context.Context, inv *orm.Invocation) []interface{} { startTime := time.Now() res := next(ctx, inv) diff --git a/pkg/client/orm/filter/prometheus/filter_test.go b/pkg/client/orm/filter/prometheus/filter_test.go index 1b55b989..0368d321 100644 --- a/pkg/client/orm/filter/prometheus/filter_test.go +++ b/pkg/client/orm/filter/prometheus/filter_test.go @@ -24,14 +24,15 @@ import ( "github.com/astaxie/beego/pkg/client/orm" ) -func TestFilterChainBuilder_FilterChain(t *testing.T) { - builder := NewFilterChainBuilder() - assert.NotNil(t, builder.summaryVec) - - filter := builder.FilterChain(func(ctx context.Context, inv *orm.Invocation) []interface{} { +func TestFilterChainBuilder_FilterChain1(t *testing.T) { + next := func(ctx context.Context, inv *orm.Invocation) []interface{} { inv.Method = "coming" return []interface{}{} - }) + } + builder := &FilterChainBuilder{} + filter := builder.FilterChain(next) + + assert.NotNil(t, builder.summaryVec) assert.NotNil(t, filter) inv := &orm.Invocation{} From 3364c609de5595f9e9e4e1ac02348bf31f9ad7cc Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 4 Oct 2020 22:11:28 +0800 Subject: [PATCH 179/207] Add context to cache API --- pkg/client/cache/cache.go | 17 ++--- pkg/client/cache/cache_test.go | 80 +++++++++++----------- pkg/client/cache/file.go | 54 +++++++++------ pkg/client/cache/memcache/memcache.go | 40 +++++------ pkg/client/cache/memcache/memcache_test.go | 46 +++++++------ pkg/client/cache/memory.go | 54 ++++++++------- pkg/client/cache/redis/redis.go | 32 ++++----- pkg/client/cache/redis/redis_test.go | 46 +++++++------ pkg/client/cache/ssdb/ssdb.go | 44 ++++++------ pkg/client/cache/ssdb/ssdb_test.go | 56 ++++++++------- pkg/server/web/captcha/captcha.go | 20 +++--- 11 files changed, 259 insertions(+), 230 deletions(-) diff --git a/pkg/client/cache/cache.go b/pkg/client/cache/cache.go index 049fb758..ddf246ab 100644 --- a/pkg/client/cache/cache.go +++ b/pkg/client/cache/cache.go @@ -32,6 +32,7 @@ package cache import ( + "context" "fmt" "time" ) @@ -48,21 +49,21 @@ import ( // count := c.Get("counter").(int) type Cache interface { // Get a cached value by key. - Get(key string) interface{} + Get(ctx context.Context, key string) (interface{}, error) // GetMulti is a batch version of Get. - GetMulti(keys []string) []interface{} + GetMulti(ctx context.Context, keys []string) ([]interface{}, error) // Set a cached value with key and expire time. - Put(key string, val interface{}, timeout time.Duration) error + Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error // Delete cached value by key. - Delete(key string) error + Delete(ctx context.Context, key string) error // Increment a cached int value by key, as a counter. - Incr(key string) error + Incr(ctx context.Context, key string) error // Decrement a cached int value by key, as a counter. - Decr(key string) error + Decr(ctx context.Context, key string) error // Check if a cached value exists or not. - IsExist(key string) bool + IsExist(ctx context.Context, key string) (bool, error) // Clear all cache. - ClearAll() error + ClearAll(ctx context.Context) error // Start gc routine based on config string settings. StartAndGC(config string) error } diff --git a/pkg/client/cache/cache_test.go b/pkg/client/cache/cache_test.go index 470c0a43..6066b72d 100644 --- a/pkg/client/cache/cache_test.go +++ b/pkg/client/cache/cache_test.go @@ -15,6 +15,7 @@ package cache import ( + "context" "os" "sync" "testing" @@ -26,19 +27,20 @@ func TestCacheIncr(t *testing.T) { if err != nil { t.Error("init err") } - //timeoutDuration := 10 * time.Second + // timeoutDuration := 10 * time.Second - bm.Put("edwardhey", 0, time.Second*20) + bm.Put(context.Background(), "edwardhey", 0, time.Second*20) wg := sync.WaitGroup{} wg.Add(10) for i := 0; i < 10; i++ { go func() { defer wg.Done() - bm.Incr("edwardhey") + bm.Incr(context.Background(), "edwardhey") }() } wg.Wait() - if bm.Get("edwardhey").(int) != 10 { + val, _ := bm.Get(context.Background(), "edwardhey") + if val.(int) != 10 { t.Error("Incr err") } } @@ -49,66 +51,66 @@ func TestCache(t *testing.T) { t.Error("init err") } timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { t.Error("set Error", err) } - if !bm.IsExist("astaxie") { + if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { t.Error("check err") } - if v := bm.Get("astaxie"); v.(int) != 1 { + if v, _ := bm.Get(context.Background(), "astaxie"); v.(int) != 1 { t.Error("get err") } time.Sleep(30 * time.Second) - if bm.IsExist("astaxie") { + if res, _ := bm.IsExist(context.Background(), "astaxie"); res { t.Error("check err") } - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { t.Error("set Error", err) } - if err = bm.Incr("astaxie"); err != nil { + if err = bm.Incr(context.Background(), "astaxie"); err != nil { t.Error("Incr Error", err) } - if v := bm.Get("astaxie"); v.(int) != 2 { + if v, _ := bm.Get(context.Background(), "astaxie"); v.(int) != 2 { t.Error("get err") } - if err = bm.Decr("astaxie"); err != nil { + if err = bm.Decr(context.Background(), "astaxie"); err != nil { t.Error("Decr Error", err) } - if v := bm.Get("astaxie"); v.(int) != 1 { + if v, _ := bm.Get(context.Background(), "astaxie"); v.(int) != 1 { t.Error("get err") } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { + bm.Delete(context.Background(), "astaxie") + if res, _ := bm.IsExist(context.Background(), "astaxie"); res { t.Error("delete err") } - //test GetMulti - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + // test GetMulti + if err = bm.Put(context.Background(), "astaxie", "author", timeoutDuration); err != nil { t.Error("set Error", err) } - if !bm.IsExist("astaxie") { + if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { t.Error("check err") } - if v := bm.Get("astaxie"); v.(string) != "author" { + if v, _ := bm.Get(context.Background(), "astaxie"); v.(string) != "author" { t.Error("get err") } - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + if err = bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration); err != nil { t.Error("set Error", err) } - if !bm.IsExist("astaxie1") { + if res, _ := bm.IsExist(context.Background(), "astaxie1"); !res { t.Error("check err") } - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) if len(vv) != 2 { t.Error("GetMulti ERROR") } @@ -126,57 +128,57 @@ func TestFileCache(t *testing.T) { t.Error("init err") } timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { t.Error("set Error", err) } - if !bm.IsExist("astaxie") { + if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { t.Error("check err") } - if v := bm.Get("astaxie"); v.(int) != 1 { + if v, _ := bm.Get(context.Background(), "astaxie"); v.(int) != 1 { t.Error("get err") } - if err = bm.Incr("astaxie"); err != nil { + if err = bm.Incr(context.Background(), "astaxie"); err != nil { t.Error("Incr Error", err) } - if v := bm.Get("astaxie"); v.(int) != 2 { + if v, _ := bm.Get(context.Background(), "astaxie"); v.(int) != 2 { t.Error("get err") } - if err = bm.Decr("astaxie"); err != nil { + if err = bm.Decr(context.Background(), "astaxie"); err != nil { t.Error("Decr Error", err) } - if v := bm.Get("astaxie"); v.(int) != 1 { + if v, _ := bm.Get(context.Background(), "astaxie"); v.(int) != 1 { t.Error("get err") } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { + bm.Delete(context.Background(), "astaxie") + if res, _ := bm.IsExist(context.Background(), "astaxie"); res { t.Error("delete err") } - //test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + // test string + if err = bm.Put(context.Background(), "astaxie", "author", timeoutDuration); err != nil { t.Error("set Error", err) } - if !bm.IsExist("astaxie") { + if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { t.Error("check err") } - if v := bm.Get("astaxie"); v.(string) != "author" { + if v, _ := bm.Get(context.Background(), "astaxie"); v.(string) != "author" { t.Error("get err") } - //test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + // test GetMulti + if err = bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration); err != nil { t.Error("set Error", err) } - if !bm.IsExist("astaxie1") { + if res, _ := bm.IsExist(context.Background(), "astaxie1"); !res { t.Error("check err") } - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) if len(vv) != 2 { t.Error("GetMulti ERROR") } diff --git a/pkg/client/cache/file.go b/pkg/client/cache/file.go index 0e5c44be..dc818258 100644 --- a/pkg/client/cache/file.go +++ b/pkg/client/cache/file.go @@ -16,6 +16,7 @@ package cache import ( "bytes" + "context" "crypto/md5" "encoding/gob" "encoding/hex" @@ -28,6 +29,8 @@ import ( "reflect" "strconv" "time" + + "github.com/pkg/errors" ) // FileCacheItem is basic unit of file cache adapter which @@ -120,33 +123,44 @@ func (fc *FileCache) getCacheFileName(key string) string { // Get value from file cache. // if nonexistent or expired return an empty string. -func (fc *FileCache) Get(key string) interface{} { +func (fc *FileCache) Get(ctx context.Context, key string) (interface{}, error) { fileData, err := FileGetContents(fc.getCacheFileName(key)) if err != nil { - return "" + return nil, err } + var to FileCacheItem - GobDecode(fileData, &to) - if to.Expired.Before(time.Now()) { - return "" + err = GobDecode(fileData, &to) + if err != nil { + return nil, err } - return to.Data + + if to.Expired.Before(time.Now()) { + return nil, errors.New("The key is expired") + } + return to.Data, nil } // GetMulti gets values from file cache. // if nonexistent or expired return an empty string. -func (fc *FileCache) GetMulti(keys []string) []interface{} { +func (fc *FileCache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { var rc []interface{} for _, key := range keys { - rc = append(rc, fc.Get(key)) + val, err := fc.Get(context.Background(), key) + if err != nil { + rc = append(rc, err) + } else { + rc = append(rc, val) + } + } - return rc + return rc, nil } // Put value into file cache. // timeout: how long this file should be kept in ms // if timeout equals fc.EmbedExpiry(default is 0), cache this item forever. -func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error { +func (fc *FileCache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { gob.Register(val) item := FileCacheItem{Data: val} @@ -164,7 +178,7 @@ func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) err } // Delete file cache value. -func (fc *FileCache) Delete(key string) error { +func (fc *FileCache) Delete(ctx context.Context, key string) error { filename := fc.getCacheFileName(key) if ok, _ := exists(filename); ok { return os.Remove(filename) @@ -174,39 +188,39 @@ func (fc *FileCache) Delete(key string) error { // Incr increases cached int value. // fc value is saved forever unless deleted. -func (fc *FileCache) Incr(key string) error { - data := fc.Get(key) +func (fc *FileCache) Incr(ctx context.Context, key string) error { + data, _ := fc.Get(context.Background(), key) var incr int if reflect.TypeOf(data).Name() != "int" { incr = 0 } else { incr = data.(int) + 1 } - fc.Put(key, incr, time.Duration(fc.EmbedExpiry)) + fc.Put(context.Background(), key, incr, time.Duration(fc.EmbedExpiry)) return nil } // Decr decreases cached int value. -func (fc *FileCache) Decr(key string) error { - data := fc.Get(key) +func (fc *FileCache) Decr(ctx context.Context, key string) error { + data, _ := fc.Get(context.Background(), key) var decr int if reflect.TypeOf(data).Name() != "int" || data.(int)-1 <= 0 { decr = 0 } else { decr = data.(int) - 1 } - fc.Put(key, decr, time.Duration(fc.EmbedExpiry)) + fc.Put(context.Background(), key, decr, time.Duration(fc.EmbedExpiry)) return nil } // IsExist checks if value exists. -func (fc *FileCache) IsExist(key string) bool { +func (fc *FileCache) IsExist(ctx context.Context, key string) (bool, error) { ret, _ := exists(fc.getCacheFileName(key)) - return ret + return ret, nil } // ClearAll cleans cached files (not implemented) -func (fc *FileCache) ClearAll() error { +func (fc *FileCache) ClearAll(context.Context) error { return nil } diff --git a/pkg/client/cache/memcache/memcache.go b/pkg/client/cache/memcache/memcache.go index 60712c2f..d3b7e767 100644 --- a/pkg/client/cache/memcache/memcache.go +++ b/pkg/client/cache/memcache/memcache.go @@ -30,6 +30,7 @@ package memcache import ( + "context" "encoding/json" "errors" "strings" @@ -52,28 +53,25 @@ func NewMemCache() cache.Cache { } // Get get value from memcache. -func (rc *Cache) Get(key string) interface{} { +func (rc *Cache) Get(ctx context.Context, key string) (interface{}, error) { if rc.conn == nil { if err := rc.connectInit(); err != nil { - return err + return nil, err } } if item, err := rc.conn.Get(key); err == nil { - return item.Value + return item.Value, nil + } else { + return nil, err } - return nil } // GetMulti gets a value from a key in memcache. -func (rc *Cache) GetMulti(keys []string) []interface{} { - size := len(keys) +func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { var rv []interface{} if rc.conn == nil { if err := rc.connectInit(); err != nil { - for i := 0; i < size; i++ { - rv = append(rv, err) - } - return rv + return rv, err } } mv, err := rc.conn.GetMulti(keys) @@ -81,16 +79,12 @@ func (rc *Cache) GetMulti(keys []string) []interface{} { for _, v := range mv { rv = append(rv, v.Value) } - return rv } - for i := 0; i < size; i++ { - rv = append(rv, err) - } - return rv + return rv, err } // Put puts a value into memcache. -func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { +func (rc *Cache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err @@ -108,7 +102,7 @@ func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { } // Delete deletes a value in memcache. -func (rc *Cache) Delete(key string) error { +func (rc *Cache) Delete(ctx context.Context, key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err @@ -118,7 +112,7 @@ func (rc *Cache) Delete(key string) error { } // Incr increases counter. -func (rc *Cache) Incr(key string) error { +func (rc *Cache) Incr(ctx context.Context, key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err @@ -129,7 +123,7 @@ func (rc *Cache) Incr(key string) error { } // Decr decreases counter. -func (rc *Cache) Decr(key string) error { +func (rc *Cache) Decr(ctx context.Context, key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err @@ -140,18 +134,18 @@ func (rc *Cache) Decr(key string) error { } // IsExist checks if a value exists in memcache. -func (rc *Cache) IsExist(key string) bool { +func (rc *Cache) IsExist(ctx context.Context, key string) (bool, error) { if rc.conn == nil { if err := rc.connectInit(); err != nil { - return false + return false, err } } _, err := rc.conn.Get(key) - return err == nil + return err == nil, err } // ClearAll clears all cache in memcache. -func (rc *Cache) ClearAll() error { +func (rc *Cache) ClearAll(context.Context) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err diff --git a/pkg/client/cache/memcache/memcache_test.go b/pkg/client/cache/memcache/memcache_test.go index df2ba37f..64679671 100644 --- a/pkg/client/cache/memcache/memcache_test.go +++ b/pkg/client/cache/memcache/memcache_test.go @@ -15,15 +15,15 @@ package memcache import ( + "context" "fmt" "os" - - _ "github.com/bradfitz/gomemcache/memcache" - "strconv" "testing" "time" + _ "github.com/bradfitz/gomemcache/memcache" + "github.com/astaxie/beego/pkg/client/cache" ) @@ -39,67 +39,71 @@ func TestMemcacheCache(t *testing.T) { t.Error("init err") } timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { + if err = bm.Put(context.Background(), "astaxie", "1", timeoutDuration); err != nil { t.Error("set Error", err) } - if !bm.IsExist("astaxie") { + if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { t.Error("check err") } time.Sleep(11 * time.Second) - if bm.IsExist("astaxie") { + if res, _ := bm.IsExist(context.Background(), "astaxie"); res { t.Error("check err") } - if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { + if err = bm.Put(context.Background(), "astaxie", "1", timeoutDuration); err != nil { t.Error("set Error", err) } - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { + val, _ := bm.Get(context.Background(), "astaxie") + if v, err := strconv.Atoi(string(val.([]byte))); err != nil || v != 1 { t.Error("get err") } - if err = bm.Incr("astaxie"); err != nil { + if err = bm.Incr(context.Background(), "astaxie"); err != nil { t.Error("Incr Error", err) } - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 2 { + val, _ = bm.Get(context.Background(), "astaxie") + if v, err := strconv.Atoi(string(val.([]byte))); err != nil || v != 2 { t.Error("get err") } - if err = bm.Decr("astaxie"); err != nil { + if err = bm.Decr(context.Background(), "astaxie"); err != nil { t.Error("Decr Error", err) } - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { + val, _ = bm.Get(context.Background(), "astaxie") + if v, err := strconv.Atoi(string(val.([]byte))); err != nil || v != 1 { t.Error("get err") } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { + bm.Delete(context.Background(), "astaxie") + if res, _ := bm.IsExist(context.Background(), "astaxie"); res { t.Error("delete err") } // test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + if err = bm.Put(context.Background(), "astaxie", "author", timeoutDuration); err != nil { t.Error("set Error", err) } - if !bm.IsExist("astaxie") { + if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { t.Error("check err") } - if v := bm.Get("astaxie").([]byte); string(v) != "author" { + val, _ = bm.Get(context.Background(), "astaxie") + if v := val.([]byte); string(v) != "author" { t.Error("get err") } // test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + if err = bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration); err != nil { t.Error("set Error", err) } - if !bm.IsExist("astaxie1") { + if res, _ := bm.IsExist(context.Background(), "astaxie1"); !res { t.Error("check err") } - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) if len(vv) != 2 { t.Error("GetMulti ERROR") } @@ -111,7 +115,7 @@ func TestMemcacheCache(t *testing.T) { } // test clear all - if err = bm.ClearAll(); err != nil { + if err = bm.ClearAll(context.Background()); err != nil { t.Error("clear all err") } } diff --git a/pkg/client/cache/memory.go b/pkg/client/cache/memory.go index c0e35c6c..6f87ec08 100644 --- a/pkg/client/cache/memory.go +++ b/pkg/client/cache/memory.go @@ -15,6 +15,7 @@ package cache import ( + "context" "encoding/json" "errors" "sync" @@ -58,50 +59,55 @@ func NewMemoryCache() Cache { // Get returns cache from memory. // If non-existent or expired, return nil. -func (bc *MemoryCache) Get(name string) interface{} { +func (bc *MemoryCache) Get(ctx context.Context, key string) (interface{}, error) { bc.RLock() defer bc.RUnlock() - if itm, ok := bc.items[name]; ok { + if itm, ok := bc.items[key]; ok { if itm.isExpire() { - return nil + return nil, errors.New("the key is expired") } - return itm.val + return itm.val, nil } - return nil + return nil, nil } // GetMulti gets caches from memory. // If non-existent or expired, return nil. -func (bc *MemoryCache) GetMulti(names []string) []interface{} { +func (bc *MemoryCache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { var rc []interface{} - for _, name := range names { - rc = append(rc, bc.Get(name)) + for _, name := range keys { + val, err := bc.Get(context.Background(), name) + if err != nil { + rc = append(rc, err) + } else { + rc = append(rc, val) + } } - return rc + return rc, nil } // Put puts cache into memory. // If lifespan is 0, it will never overwrite this value unless restarted -func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error { +func (bc *MemoryCache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { bc.Lock() defer bc.Unlock() - bc.items[name] = &MemoryItem{ - val: value, + bc.items[key] = &MemoryItem{ + val: val, createdTime: time.Now(), - lifespan: lifespan, + lifespan: timeout, } return nil } // Delete cache in memory. -func (bc *MemoryCache) Delete(name string) error { +func (bc *MemoryCache) Delete(ctx context.Context, key string) error { bc.Lock() defer bc.Unlock() - if _, ok := bc.items[name]; !ok { + if _, ok := bc.items[key]; !ok { return errors.New("key not exist") } - delete(bc.items, name) - if _, ok := bc.items[name]; ok { + delete(bc.items, key) + if _, ok := bc.items[key]; ok { return errors.New("delete key error") } return nil @@ -109,7 +115,7 @@ func (bc *MemoryCache) Delete(name string) error { // Incr increases cache counter in memory. // Supports int,int32,int64,uint,uint32,uint64. -func (bc *MemoryCache) Incr(key string) error { +func (bc *MemoryCache) Incr(ctx context.Context, key string) error { bc.Lock() defer bc.Unlock() itm, ok := bc.items[key] @@ -136,7 +142,7 @@ func (bc *MemoryCache) Incr(key string) error { } // Decr decreases counter in memory. -func (bc *MemoryCache) Decr(key string) error { +func (bc *MemoryCache) Decr(ctx context.Context, key string) error { bc.Lock() defer bc.Unlock() itm, ok := bc.items[key] @@ -175,17 +181,17 @@ func (bc *MemoryCache) Decr(key string) error { } // IsExist checks if cache exists in memory. -func (bc *MemoryCache) IsExist(name string) bool { +func (bc *MemoryCache) IsExist(ctx context.Context, key string) (bool, error) { bc.RLock() defer bc.RUnlock() - if v, ok := bc.items[name]; ok { - return !v.isExpire() + if v, ok := bc.items[key]; ok { + return !v.isExpire(), nil } - return false + return false, nil } // ClearAll deletes all cache in memory. -func (bc *MemoryCache) ClearAll() error { +func (bc *MemoryCache) ClearAll(context.Context) error { bc.Lock() defer bc.Unlock() bc.items = make(map[string]*MemoryItem) diff --git a/pkg/client/cache/redis/redis.go b/pkg/client/cache/redis/redis.go index 2cd20503..e2785297 100644 --- a/pkg/client/cache/redis/redis.go +++ b/pkg/client/cache/redis/redis.go @@ -30,6 +30,7 @@ package redis import ( + "context" "encoding/json" "errors" "fmt" @@ -83,63 +84,60 @@ func (rc *Cache) associate(originKey interface{}) string { } // Get cache from redis. -func (rc *Cache) Get(key string) interface{} { +func (rc *Cache) Get(ctx context.Context, key string) (interface{}, error) { if v, err := rc.do("GET", key); err == nil { - return v + return v, nil + } else { + return nil, err } - return nil } // GetMulti gets cache from redis. -func (rc *Cache) GetMulti(keys []string) []interface{} { +func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { c := rc.p.Get() defer c.Close() var args []interface{} for _, key := range keys { args = append(args, rc.associate(key)) } - values, err := redis.Values(c.Do("MGET", args...)) - if err != nil { - return nil - } - return values + return redis.Values(c.Do("MGET", args...)) } // Put puts cache into redis. -func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { +func (rc *Cache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { _, err := rc.do("SETEX", key, int64(timeout/time.Second), val) return err } // Delete deletes a key's cache in redis. -func (rc *Cache) Delete(key string) error { +func (rc *Cache) Delete(ctx context.Context, key string) error { _, err := rc.do("DEL", key) return err } // IsExist checks cache's existence in redis. -func (rc *Cache) IsExist(key string) bool { +func (rc *Cache) IsExist(ctx context.Context, key string) (bool, error) { v, err := redis.Bool(rc.do("EXISTS", key)) if err != nil { - return false + return false, err } - return v + return v, nil } // Incr increases a key's counter in redis. -func (rc *Cache) Incr(key string) error { +func (rc *Cache) Incr(ctx context.Context, key string) error { _, err := redis.Bool(rc.do("INCRBY", key, 1)) return err } // Decr decreases a key's counter in redis. -func (rc *Cache) Decr(key string) error { +func (rc *Cache) Decr(ctx context.Context, key string) error { _, err := redis.Bool(rc.do("INCRBY", key, -1)) return err } // ClearAll deletes all cache in the redis collection -func (rc *Cache) ClearAll() error { +func (rc *Cache) ClearAll(context.Context) error { cachedKeys, err := rc.Scan(rc.key + ":*") if err != nil { return err diff --git a/pkg/client/cache/redis/redis_test.go b/pkg/client/cache/redis/redis_test.go index dc0ca40f..f7308365 100644 --- a/pkg/client/cache/redis/redis_test.go +++ b/pkg/client/cache/redis/redis_test.go @@ -15,6 +15,7 @@ package redis import ( + "context" "fmt" "os" "testing" @@ -38,67 +39,70 @@ func TestRedisCache(t *testing.T) { t.Error("init err") } timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { t.Error("set Error", err) } - if !bm.IsExist("astaxie") { + if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { t.Error("check err") } time.Sleep(11 * time.Second) - if bm.IsExist("astaxie") { + if res, _ := bm.IsExist(context.Background(), "astaxie"); res { t.Error("check err") } - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { t.Error("set Error", err) } - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { + val, _ := bm.Get(context.Background(), "astaxie") + if v, _ := redis.Int(val, err); v != 1 { t.Error("get err") } - if err = bm.Incr("astaxie"); err != nil { + if err = bm.Incr(context.Background(), "astaxie"); err != nil { t.Error("Incr Error", err) } - - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 2 { + val, _ = bm.Get(context.Background(), "astaxie") + if v, _ := redis.Int(val, err); v != 2 { t.Error("get err") } - if err = bm.Decr("astaxie"); err != nil { + if err = bm.Decr(context.Background(), "astaxie"); err != nil { t.Error("Decr Error", err) } - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { + val, _ = bm.Get(context.Background(), "astaxie") + if v, _ := redis.Int(val, err); v != 1 { t.Error("get err") } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { + bm.Delete(context.Background(), "astaxie") + if res, _ := bm.IsExist(context.Background(), "astaxie"); res { t.Error("delete err") } // test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + if err = bm.Put(context.Background(), "astaxie", "author", timeoutDuration); err != nil { t.Error("set Error", err) } - if !bm.IsExist("astaxie") { + if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { t.Error("check err") } - if v, _ := redis.String(bm.Get("astaxie"), err); v != "author" { + val, _ = bm.Get(context.Background(), "astaxie") + if v, _ := redis.String(val, err); v != "author" { t.Error("get err") } // test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + if err = bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration); err != nil { t.Error("set Error", err) } - if !bm.IsExist("astaxie1") { + if res, _ := bm.IsExist(context.Background(), "astaxie1"); !res { t.Error("check err") } - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) if len(vv) != 2 { t.Error("GetMulti ERROR") } @@ -110,7 +114,7 @@ func TestRedisCache(t *testing.T) { } // test clear all - if err = bm.ClearAll(); err != nil { + if err = bm.ClearAll(context.Background()); err != nil { t.Error("clear all err") } } @@ -130,7 +134,7 @@ func TestCache_Scan(t *testing.T) { } // insert all for i := 0; i < 100; i++ { - if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { + if err = bm.Put(context.Background(), fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { t.Error("set Error", err) } } @@ -144,7 +148,7 @@ func TestCache_Scan(t *testing.T) { assert.Equal(t, 100, len(keys), "scan all error") // clear all - if err = bm.ClearAll(); err != nil { + if err = bm.ClearAll(context.Background()); err != nil { t.Error("clear all err") } diff --git a/pkg/client/cache/ssdb/ssdb.go b/pkg/client/cache/ssdb/ssdb.go index 10ff72b0..2e4f2815 100644 --- a/pkg/client/cache/ssdb/ssdb.go +++ b/pkg/client/cache/ssdb/ssdb.go @@ -1,6 +1,7 @@ package ssdb import ( + "context" "encoding/json" "errors" "strconv" @@ -24,29 +25,26 @@ func NewSsdbCache() cache.Cache { } // Get gets a key's value from memcache. -func (rc *Cache) Get(key string) interface{} { +func (rc *Cache) Get(ctx context.Context, key string) (interface{}, error) { if rc.conn == nil { if err := rc.connectInit(); err != nil { - return nil + return nil, nil } } value, err := rc.conn.Get(key) if err == nil { - return value + return value, nil } - return nil + return nil, nil } -// GetMulti gets one or keys values from memcache. -func (rc *Cache) GetMulti(keys []string) []interface{} { +// GetMulti gets one or keys values from ssdb. +func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { size := len(keys) var values []interface{} if rc.conn == nil { if err := rc.connectInit(); err != nil { - for i := 0; i < size; i++ { - values = append(values, err) - } - return values + return values, err } } res, err := rc.conn.Do("multi_get", keys) @@ -55,12 +53,12 @@ func (rc *Cache) GetMulti(keys []string) []interface{} { for i := 1; i < resSize; i += 2 { values = append(values, res[i+1]) } - return values + return values, nil } for i := 0; i < size; i++ { values = append(values, err) } - return values + return values, nil } // DelMulti deletes one or more keys from memcache @@ -76,13 +74,13 @@ func (rc *Cache) DelMulti(keys []string) error { // Put puts value into memcache. // value: must be of type string -func (rc *Cache) Put(key string, value interface{}, timeout time.Duration) error { +func (rc *Cache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err } } - v, ok := value.(string) + v, ok := val.(string) if !ok { return errors.New("value must string") } @@ -104,7 +102,7 @@ func (rc *Cache) Put(key string, value interface{}, timeout time.Duration) error } // Delete deletes a value in memcache. -func (rc *Cache) Delete(key string) error { +func (rc *Cache) Delete(ctx context.Context, key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err @@ -115,7 +113,7 @@ func (rc *Cache) Delete(key string) error { } // Incr increases a key's counter. -func (rc *Cache) Incr(key string) error { +func (rc *Cache) Incr(ctx context.Context, key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err @@ -126,7 +124,7 @@ func (rc *Cache) Incr(key string) error { } // Decr decrements a key's counter. -func (rc *Cache) Decr(key string) error { +func (rc *Cache) Decr(ctx context.Context, key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err @@ -137,25 +135,25 @@ func (rc *Cache) Decr(key string) error { } // IsExist checks if a key exists in memcache. -func (rc *Cache) IsExist(key string) bool { +func (rc *Cache) IsExist(ctx context.Context, key string) (bool, error) { if rc.conn == nil { if err := rc.connectInit(); err != nil { - return false + return false, err } } resp, err := rc.conn.Do("exists", key) if err != nil { - return false + return false, err } if len(resp) == 2 && resp[1] == "1" { - return true + return true, nil } - return false + return false, nil } // ClearAll clears all cached items in memcache. -func (rc *Cache) ClearAll() error { +func (rc *Cache) ClearAll(context.Context) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err diff --git a/pkg/client/cache/ssdb/ssdb_test.go b/pkg/client/cache/ssdb/ssdb_test.go index bd6ede4e..f675d1ab 100644 --- a/pkg/client/cache/ssdb/ssdb_test.go +++ b/pkg/client/cache/ssdb/ssdb_test.go @@ -1,6 +1,7 @@ package ssdb import ( + "context" "fmt" "os" "strconv" @@ -23,75 +24,78 @@ func TestSsdbcacheCache(t *testing.T) { } // test put and exist - if ssdb.IsExist("ssdb") { + if res, _ := ssdb.IsExist(context.Background(), "ssdb"); res { t.Error("check err") } timeoutDuration := 10 * time.Second - //timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent - if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { + // timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent + if err = ssdb.Put(context.Background(), "ssdb", "ssdb", timeoutDuration); err != nil { t.Error("set Error", err) } - if !ssdb.IsExist("ssdb") { + if res, _ := ssdb.IsExist(context.Background(), "ssdb"); !res { t.Error("check err") } // Get test done - if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { + if err = ssdb.Put(context.Background(), "ssdb", "ssdb", timeoutDuration); err != nil { t.Error("set Error", err) } - if v := ssdb.Get("ssdb"); v != "ssdb" { + if v, _ := ssdb.Get(context.Background(), "ssdb"); v != "ssdb" { t.Error("get Error") } - //inc/dec test done - if err = ssdb.Put("ssdb", "2", timeoutDuration); err != nil { + // inc/dec test done + if err = ssdb.Put(context.Background(), "ssdb", "2", timeoutDuration); err != nil { t.Error("set Error", err) } - if err = ssdb.Incr("ssdb"); err != nil { + if err = ssdb.Incr(context.Background(), "ssdb"); err != nil { t.Error("incr Error", err) } - if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { + val, _ := ssdb.Get(context.Background(), "ssdb") + if v, err := strconv.Atoi(val.(string)); err != nil || v != 3 { t.Error("get err") } - if err = ssdb.Decr("ssdb"); err != nil { + if err = ssdb.Decr(context.Background(), "ssdb"); err != nil { t.Error("decr error") } // test del - if err = ssdb.Put("ssdb", "3", timeoutDuration); err != nil { + if err = ssdb.Put(context.Background(), "ssdb", "3", timeoutDuration); err != nil { t.Error("set Error", err) } - if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { + + val, _ = ssdb.Get(context.Background(), "ssdb") + if v, err := strconv.Atoi(val.(string)); err != nil || v != 3 { t.Error("get err") } - if err := ssdb.Delete("ssdb"); err == nil { - if ssdb.IsExist("ssdb") { + if err := ssdb.Delete(context.Background(), "ssdb"); err == nil { + if e, _ := ssdb.IsExist(context.Background(), "ssdb"); e { t.Error("delete err") } } - //test string - if err = ssdb.Put("ssdb", "ssdb", -10*time.Second); err != nil { + // test string + if err = ssdb.Put(context.Background(), "ssdb", "ssdb", -10*time.Second); err != nil { t.Error("set Error", err) } - if !ssdb.IsExist("ssdb") { + if res, _ := ssdb.IsExist(context.Background(), "ssdb"); !res { t.Error("check err") } - if v := ssdb.Get("ssdb").(string); v != "ssdb" { + if v, _ := ssdb.Get(context.Background(), "ssdb"); v.(string) != "ssdb" { t.Error("get err") } - //test GetMulti done - if err = ssdb.Put("ssdb1", "ssdb1", -10*time.Second); err != nil { + // test GetMulti done + if err = ssdb.Put(context.Background(), "ssdb1", "ssdb1", -10*time.Second); err != nil { t.Error("set Error", err) } - if !ssdb.IsExist("ssdb1") { + if res, _ := ssdb.IsExist(context.Background(), "ssdb1"); !res { t.Error("check err") } - vv := ssdb.GetMulti([]string{"ssdb", "ssdb1"}) + vv, _ := ssdb.GetMulti(context.Background(), []string{"ssdb", "ssdb1"}) if len(vv) != 2 { t.Error("getmulti error") } @@ -103,10 +107,12 @@ func TestSsdbcacheCache(t *testing.T) { } // test clear all done - if err = ssdb.ClearAll(); err != nil { + if err = ssdb.ClearAll(context.Background()); err != nil { t.Error("clear all err") } - if ssdb.IsExist("ssdb") || ssdb.IsExist("ssdb1") { + e1, _ := ssdb.IsExist(context.Background(), "ssdb") + e2, _ := ssdb.IsExist(context.Background(), "ssdb1") + if e1 || e2 { t.Error("check err") } } diff --git a/pkg/server/web/captcha/captcha.go b/pkg/server/web/captcha/captcha.go index 2c60f23a..36bc0fcb 100644 --- a/pkg/server/web/captcha/captcha.go +++ b/pkg/server/web/captcha/captcha.go @@ -59,6 +59,7 @@ package captcha import ( + context2 "context" "fmt" "html/template" "net/http" @@ -137,14 +138,15 @@ func (c *Captcha) Handler(ctx *context.Context) { if len(ctx.Input.Query("reload")) > 0 { chars = c.genRandChars() - if err := c.store.Put(key, chars, c.Expiration); err != nil { + if err := c.store.Put(context2.Background(), key, chars, c.Expiration); err != nil { ctx.Output.SetStatus(500) ctx.WriteString("captcha reload error") logs.Error("Reload Create Captcha Error:", err) return } } else { - if v, ok := c.store.Get(key).([]byte); ok { + val, _ := c.store.Get(context2.Background(), key) + if v, ok := val.([]byte); ok { chars = v } else { ctx.Output.SetStatus(404) @@ -183,7 +185,7 @@ func (c *Captcha) CreateCaptcha() (string, error) { chars := c.genRandChars() // save to store - if err := c.store.Put(c.key(id), chars, c.Expiration); err != nil { + if err := c.store.Put(context2.Background(), c.key(id), chars, c.Expiration); err != nil { return "", err } @@ -205,8 +207,8 @@ func (c *Captcha) Verify(id string, challenge string) (success bool) { var chars []byte key := c.key(id) - - if v, ok := c.store.Get(key).([]byte); ok { + val, _ := c.store.Get(context2.Background(), key) + if v, ok := val.([]byte); ok { chars = v } else { return @@ -214,7 +216,7 @@ func (c *Captcha) Verify(id string, challenge string) (success bool) { defer func() { // finally remove it - c.store.Delete(key) + c.store.Delete(context2.Background(), key) }() if len(chars) != len(challenge) { @@ -271,9 +273,9 @@ func NewWithFilter(urlPrefix string, store Storage) *Captcha { type Storage interface { // Get a cached value by key. - Get(key string) interface{} + Get(ctx context2.Context, key string) (interface{}, error) // Set a cached value with key and expire time. - Put(key string, val interface{}, timeout time.Duration) error + Put(ctx context2.Context, key string, val interface{}, timeout time.Duration) error // Delete cached value by key. - Delete(key string) error + Delete(ctx context2.Context, key string) error } From 4dc694411f3a4537c03aa95f69105e546355b516 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Mon, 5 Oct 2020 00:16:58 +0800 Subject: [PATCH 180/207] fix deadlock in task module --- pkg/task/task.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pkg/task/task.go b/pkg/task/task.go index e3a8bba4..4835ad24 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -451,6 +451,11 @@ func run() { taskLock.Unlock() continue case <-stop: + taskLock.Lock() + if isstart { + isstart = false + } + taskLock.Unlock() return } } @@ -458,13 +463,7 @@ func run() { // StopTask stop all tasks func StopTask() { - taskLock.Lock() - defer taskLock.Unlock() - if isstart { - isstart = false - stop <- true - } - + stop <- true } // AddTask add task with name From f1cca45d8d2235daad5da3f25e74274887ccb290 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Mon, 5 Oct 2020 01:31:27 +0800 Subject: [PATCH 181/207] fix deadlock about changed sign --- pkg/task/task.go | 16 ++++++++++++++-- pkg/task/task_test.go | 19 ++++++++++++++++++- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/pkg/task/task.go b/pkg/task/task.go index 4835ad24..c8228fd2 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -468,21 +468,33 @@ func StopTask() { // AddTask add task with name func AddTask(taskname string, t Tasker) { + isChanged := false taskLock.Lock() - defer taskLock.Unlock() t.SetNext(nil, time.Now().Local()) AdminTaskList[taskname] = t if isstart { + isChanged = true + } + taskLock.Unlock() + + if isChanged { changed <- true } + } // DeleteTask delete task with name func DeleteTask(taskname string) { + isChanged := false + taskLock.Lock() - defer taskLock.Unlock() delete(AdminTaskList, taskname) if isstart { + isChanged = true + } + taskLock.Unlock() + + if isChanged { changed <- true } } diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go index 488729dc..f58e374b 100644 --- a/pkg/task/task_test.go +++ b/pkg/task/task_test.go @@ -36,7 +36,24 @@ func TestParse(t *testing.T) { } AddTask("taska", tk) StartTask() - time.Sleep(6 * time.Second) + time.Sleep(3 * time.Second) + StopTask() +} + +func TestModifyTaskListAfterRunning(t *testing.T) { + tk := NewTask("taska", "0/30 * * * * *", func(ctx context.Context) error { + fmt.Println("hello world") + return nil + }) + err := tk.Run(nil) + if err != nil { + t.Fatal(err) + } + AddTask("taska", tk) + StartTask() + DeleteTask("taska") + AddTask("taska1", tk) + time.Sleep(3 * time.Second) StopTask() } From 70cca5e2981dc1b6a16377a823b6f912e3884f9e Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Mon, 5 Oct 2020 10:13:29 +0800 Subject: [PATCH 182/207] make code testable in task module --- pkg/task/govenor_command.go | 6 +- pkg/task/task.go | 118 ++++++++++++++++++++++++------------ pkg/task/task_test.go | 29 +++++---- 3 files changed, 99 insertions(+), 54 deletions(-) diff --git a/pkg/task/govenor_command.go b/pkg/task/govenor_command.go index fff08374..be351dc1 100644 --- a/pkg/task/govenor_command.go +++ b/pkg/task/govenor_command.go @@ -28,8 +28,8 @@ type listTaskCommand struct { } func (l *listTaskCommand) Execute(params ...interface{}) *governor.Result { - resultList := make([][]string, 0, len(AdminTaskList)) - for tname, tk := range AdminTaskList { + resultList := make([][]string, 0, len(globalTaskManager.adminTaskList)) + for tname, tk := range globalTaskManager.adminTaskList { result := []string{ template.HTMLEscapeString(tname), template.HTMLEscapeString(tk.GetSpec(nil)), @@ -65,7 +65,7 @@ func (r *runTaskCommand) Execute(params ...interface{}) *governor.Result { } } - if t, ok := AdminTaskList[tn]; ok { + if t, ok := globalTaskManager.adminTaskList[tn]; ok { err := t.Run(context.Background()) if err != nil { return &governor.Result{ diff --git a/pkg/task/task.go b/pkg/task/task.go index c8228fd2..5faa2fd8 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -31,13 +31,28 @@ type bounds struct { names map[string]uint } -// The bounds for each field. -var ( - AdminTaskList map[string]Tasker +type taskManager struct { + adminTaskList map[string]Tasker taskLock sync.RWMutex stop chan bool changed chan bool isstart bool +} + +func newTaskManager()*taskManager{ + return &taskManager{ + adminTaskList: make(map[string]Tasker), + taskLock: sync.RWMutex{}, + stop: make(chan bool), + changed: make(chan bool), + isstart: false, + } +} + +// The bounds for each field. +var ( + globalTaskManager *taskManager + seconds = bounds{0, 59, nil} minutes = bounds{0, 59, nil} hours = bounds{0, 23, nil} @@ -398,32 +413,53 @@ func dayMatches(s *Schedule, t time.Time) bool { // StartTask start all tasks func StartTask() { - taskLock.Lock() - defer taskLock.Unlock() - if isstart { + globalTaskManager.StartTask() +} + +// StopTask stop all tasks +func StopTask() { + globalTaskManager.StopTask() +} + +// AddTask add task with name +func AddTask(taskName string, t Tasker) { + globalTaskManager.AddTask(taskName, t) +} + +// DeleteTask delete task with name +func DeleteTask(taskName string) { + globalTaskManager.DeleteTask(taskName) +} + + +// StartTask start all tasks +func (m *taskManager) StartTask() { + m.taskLock.Lock() + defer m.taskLock.Unlock() + if m.isstart { // If already started, no need to start another goroutine. return } - isstart = true + m.isstart = true registerCommands() - go run() + go m.run() } -func run() { +func(m *taskManager) run() { now := time.Now().Local() - for _, t := range AdminTaskList { + for _, t := range m.adminTaskList { t.SetNext(nil, now) } for { // we only use RLock here because NewMapSorter copy the reference, do not change any thing - taskLock.RLock() - sortList := NewMapSorter(AdminTaskList) - taskLock.RUnlock() + m.taskLock.RLock() + sortList := NewMapSorter(m.adminTaskList) + m.taskLock.RUnlock() sortList.Sort() var effective time.Time - if len(AdminTaskList) == 0 || sortList.Vals[0].GetNext(context.Background()).IsZero() { + if len(m.adminTaskList) == 0 || sortList.Vals[0].GetNext(context.Background()).IsZero() { // If there are no entries yet, just sleep - it still handles new entries // and stop requests. effective = now.AddDate(10, 0, 0) @@ -442,60 +478,64 @@ func run() { e.SetNext(nil, effective) } continue - case <-changed: + case <-m.changed: now = time.Now().Local() - taskLock.Lock() - for _, t := range AdminTaskList { + m.taskLock.Lock() + for _, t := range m.adminTaskList { t.SetNext(nil, now) } - taskLock.Unlock() + m.taskLock.Unlock() continue - case <-stop: - taskLock.Lock() - if isstart { - isstart = false + case <-m.stop: + m.taskLock.Lock() + if m.isstart { + m.isstart = false } - taskLock.Unlock() + m.taskLock.Unlock() return } } } // StopTask stop all tasks -func StopTask() { - stop <- true +func(m *taskManager) StopTask() { + go func() { + m.stop <- true + }() } // AddTask add task with name -func AddTask(taskname string, t Tasker) { +func (m *taskManager)AddTask(taskname string, t Tasker) { isChanged := false - taskLock.Lock() + m.taskLock.Lock() t.SetNext(nil, time.Now().Local()) - AdminTaskList[taskname] = t - if isstart { + m.adminTaskList[taskname] = t + if m.isstart { isChanged = true } - taskLock.Unlock() + m.taskLock.Unlock() if isChanged { - changed <- true + go func() { + m.changed <- true + }() } } // DeleteTask delete task with name -func DeleteTask(taskname string) { +func(m *taskManager) DeleteTask(taskname string) { isChanged := false - taskLock.Lock() - delete(AdminTaskList, taskname) - if isstart { + m.taskLock.Lock() + delete(m.adminTaskList, taskname) + if m.isstart { isChanged = true } - taskLock.Unlock() + m.taskLock.Unlock() if isChanged { - changed <- true + m.changed <- true } } @@ -648,7 +688,5 @@ func all(r bounds) uint64 { } func init() { - AdminTaskList = make(map[string]Tasker) - stop = make(chan bool) - changed = make(chan bool) + globalTaskManager = newTaskManager() } diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go index f58e374b..9a74ff24 100644 --- a/pkg/task/task_test.go +++ b/pkg/task/task_test.go @@ -41,7 +41,8 @@ func TestParse(t *testing.T) { } func TestModifyTaskListAfterRunning(t *testing.T) { - tk := NewTask("taska", "0/30 * * * * *", func(ctx context.Context) error { + m := newTaskManager() + tk := NewTask("taskb", "0/30 * * * * *", func(ctx context.Context) error { fmt.Println("hello world") return nil }) @@ -49,26 +50,32 @@ func TestModifyTaskListAfterRunning(t *testing.T) { if err != nil { t.Fatal(err) } - AddTask("taska", tk) - StartTask() - DeleteTask("taska") - AddTask("taska1", tk) + m.AddTask("taskb", tk) + m.StartTask() + go func() { + m.DeleteTask("taskb") + }() + go func() { + m.AddTask("taskb1", tk) + }() + time.Sleep(3 * time.Second) - StopTask() + m.StopTask() } func TestSpec(t *testing.T) { + m := newTaskManager() wg := &sync.WaitGroup{} wg.Add(2) tk1 := NewTask("tk1", "0 12 * * * *", func(ctx context.Context) error { fmt.Println("tk1"); return nil }) tk2 := NewTask("tk2", "0,10,20 * * * * *", func(ctx context.Context) error { fmt.Println("tk2"); wg.Done(); return nil }) tk3 := NewTask("tk3", "0 10 * * * *", func(ctx context.Context) error { fmt.Println("tk3"); wg.Done(); return nil }) - AddTask("tk1", tk1) - AddTask("tk2", tk2) - AddTask("tk3", tk3) - StartTask() - defer StopTask() + m.AddTask("tk1", tk1) + m.AddTask("tk2", tk2) + m.AddTask("tk3", tk3) + m.StartTask() + defer m.StopTask() select { case <-time.After(200 * time.Second): From b838683731bbe683cff94b7c99dc01e6c4b059be Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Mon, 5 Oct 2020 10:33:23 +0800 Subject: [PATCH 183/207] add api for testing --- pkg/adapter/toolbox/task.go | 5 +++++ pkg/adapter/toolbox/task_test.go | 4 ++++ pkg/task/task.go | 36 ++++++++++++++++++++++++-------- pkg/task/task_test.go | 10 ++++++--- 4 files changed, 43 insertions(+), 12 deletions(-) diff --git a/pkg/adapter/toolbox/task.go b/pkg/adapter/toolbox/task.go index 2a6d9aa6..5b2fa14c 100644 --- a/pkg/adapter/toolbox/task.go +++ b/pkg/adapter/toolbox/task.go @@ -212,6 +212,11 @@ func DeleteTask(taskname string) { task.DeleteTask(taskname) } +// ClearTask clear all tasks +func ClearTask() { + task.ClearTask() +} + // MapSorter sort map for tasker type MapSorter task.MapSorter diff --git a/pkg/adapter/toolbox/task_test.go b/pkg/adapter/toolbox/task_test.go index 596bc9c5..994c4976 100644 --- a/pkg/adapter/toolbox/task_test.go +++ b/pkg/adapter/toolbox/task_test.go @@ -22,6 +22,8 @@ import ( ) func TestParse(t *testing.T) { + defer ClearTask() + tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) err := tk.Run() if err != nil { @@ -34,6 +36,8 @@ func TestParse(t *testing.T) { } func TestSpec(t *testing.T) { + defer ClearTask() + wg := &sync.WaitGroup{} wg.Add(2) tk1 := NewTask("tk1", "0 12 * * * *", func() error { fmt.Println("tk1"); return nil }) diff --git a/pkg/task/task.go b/pkg/task/task.go index 5faa2fd8..a781e47a 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -36,7 +36,7 @@ type taskManager struct { taskLock sync.RWMutex stop chan bool changed chan bool - isstart bool + started bool } func newTaskManager()*taskManager{ @@ -45,7 +45,7 @@ func newTaskManager()*taskManager{ taskLock: sync.RWMutex{}, stop: make(chan bool), changed: make(chan bool), - isstart: false, + started: false, } } @@ -431,16 +431,21 @@ func DeleteTask(taskName string) { globalTaskManager.DeleteTask(taskName) } +// ClearTask clear all tasks +func ClearTask() { + globalTaskManager.ClearTask() +} + // StartTask start all tasks func (m *taskManager) StartTask() { m.taskLock.Lock() defer m.taskLock.Unlock() - if m.isstart { + if m.started { // If already started, no need to start another goroutine. return } - m.isstart = true + m.started = true registerCommands() go m.run() @@ -488,8 +493,8 @@ func(m *taskManager) run() { continue case <-m.stop: m.taskLock.Lock() - if m.isstart { - m.isstart = false + if m.started { + m.started = false } m.taskLock.Unlock() return @@ -510,7 +515,7 @@ func (m *taskManager)AddTask(taskname string, t Tasker) { m.taskLock.Lock() t.SetNext(nil, time.Now().Local()) m.adminTaskList[taskname] = t - if m.isstart { + if m.started { isChanged = true } m.taskLock.Unlock() @@ -529,16 +534,29 @@ func(m *taskManager) DeleteTask(taskname string) { m.taskLock.Lock() delete(m.adminTaskList, taskname) - if m.isstart { + if m.started { isChanged = true } m.taskLock.Unlock() if isChanged { - m.changed <- true + go func() { + m.changed <- true + }() } } +// ClearTask clear all tasks +func(m *taskManager) ClearTask() { + m.taskLock.Lock() + m.adminTaskList = make(map[string]Tasker) + m.taskLock.Unlock() + + go func() { + m.changed <- true + }() +} + // MapSorter sort map for tasker type MapSorter struct { Keys []string diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go index 9a74ff24..2cb807ce 100644 --- a/pkg/task/task_test.go +++ b/pkg/task/task_test.go @@ -26,6 +26,8 @@ import ( ) func TestParse(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() tk := NewTask("taska", "0/30 * * * * *", func(ctx context.Context) error { fmt.Println("hello world") return nil @@ -34,14 +36,15 @@ func TestParse(t *testing.T) { if err != nil { t.Fatal(err) } - AddTask("taska", tk) - StartTask() + m.AddTask("taska", tk) + m.StartTask() time.Sleep(3 * time.Second) - StopTask() + m.StopTask() } func TestModifyTaskListAfterRunning(t *testing.T) { m := newTaskManager() + defer m.ClearTask() tk := NewTask("taskb", "0/30 * * * * *", func(ctx context.Context) error { fmt.Println("hello world") return nil @@ -65,6 +68,7 @@ func TestModifyTaskListAfterRunning(t *testing.T) { func TestSpec(t *testing.T) { m := newTaskManager() + defer m.ClearTask() wg := &sync.WaitGroup{} wg.Add(2) tk1 := NewTask("tk1", "0 12 * * * *", func(ctx context.Context) error { fmt.Println("tk1"); return nil }) From c435d231ab687cf80afb5af3f03f6e71eea49ab2 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Mon, 5 Oct 2020 10:38:14 +0800 Subject: [PATCH 184/207] complete check --- pkg/task/task.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pkg/task/task.go b/pkg/task/task.go index a781e47a..e76706e3 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -548,13 +548,20 @@ func(m *taskManager) DeleteTask(taskname string) { // ClearTask clear all tasks func(m *taskManager) ClearTask() { + isChanged := false + m.taskLock.Lock() m.adminTaskList = make(map[string]Tasker) + if m.started { + isChanged = true + } m.taskLock.Unlock() - go func() { - m.changed <- true - }() + if isChanged { + go func() { + m.changed <- true + }() + } } // MapSorter sort map for tasker From f9bef68aa9a90657ac1da050f9acfddf6619b4dd Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 4 Oct 2020 23:05:37 +0800 Subject: [PATCH 185/207] Adapter: cache API --- pkg/adapter/cache/cache.go | 24 ++- pkg/adapter/cache/cache_adapter.go | 117 ++++++++++++ pkg/adapter/cache/cache_test.go | 191 ++++++++++++++++++++ pkg/adapter/cache/conv.go | 44 +++++ pkg/adapter/cache/conv_test.go | 143 +++++++++++++++ pkg/adapter/cache/file.go | 30 +++ pkg/adapter/cache/memcache/memcache.go | 44 +++++ pkg/adapter/cache/memcache/memcache_test.go | 114 ++++++++++++ pkg/adapter/cache/memory.go | 28 +++ pkg/adapter/cache/redis/redis.go | 49 +++++ pkg/adapter/cache/redis/redis_test.go | 135 ++++++++++++++ pkg/adapter/cache/ssdb/ssdb.go | 15 ++ pkg/adapter/cache/ssdb/ssdb_test.go | 111 ++++++++++++ pkg/adapter/utils/captcha/captcha.go | 4 +- 14 files changed, 1044 insertions(+), 5 deletions(-) create mode 100644 pkg/adapter/cache/cache_adapter.go create mode 100644 pkg/adapter/cache/cache_test.go create mode 100644 pkg/adapter/cache/conv.go create mode 100644 pkg/adapter/cache/conv_test.go create mode 100644 pkg/adapter/cache/file.go create mode 100644 pkg/adapter/cache/memcache/memcache.go create mode 100644 pkg/adapter/cache/memcache/memcache_test.go create mode 100644 pkg/adapter/cache/memory.go create mode 100644 pkg/adapter/cache/redis/redis.go create mode 100644 pkg/adapter/cache/redis/redis_test.go create mode 100644 pkg/adapter/cache/ssdb/ssdb.go create mode 100644 pkg/adapter/cache/ssdb/ssdb_test.go diff --git a/pkg/adapter/cache/cache.go b/pkg/adapter/cache/cache.go index 21bb9141..82585c4e 100644 --- a/pkg/adapter/cache/cache.go +++ b/pkg/adapter/cache/cache.go @@ -33,8 +33,7 @@ package cache import ( "fmt" - - "github.com/astaxie/beego/pkg/client/cache" + "time" ) // Cache interface contains all behaviors for cache adapter. @@ -47,7 +46,26 @@ import ( // c.Incr("counter") // now is 1 // c.Incr("counter") // now is 2 // count := c.Get("counter").(int) -type Cache cache.Cache +type Cache interface { + // get cached value by key. + Get(key string) interface{} + // GetMulti is a batch version of Get. + GetMulti(keys []string) []interface{} + // set cached value with key and expire time. + Put(key string, val interface{}, timeout time.Duration) error + // delete cached value by key. + Delete(key string) error + // increase cached int value by key, as a counter. + Incr(key string) error + // decrease cached int value by key, as a counter. + Decr(key string) error + // check if cached value exists or not. + IsExist(key string) bool + // clear all cache. + ClearAll() error + // start gc routine based on config string settings. + StartAndGC(config string) error +} // Instance is a function create a new Cache Instance type Instance func() Cache diff --git a/pkg/adapter/cache/cache_adapter.go b/pkg/adapter/cache/cache_adapter.go new file mode 100644 index 00000000..f1441ac8 --- /dev/null +++ b/pkg/adapter/cache/cache_adapter.go @@ -0,0 +1,117 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "time" + + "github.com/astaxie/beego/pkg/client/cache" +) + +type newToOldCacheAdapter struct { + delegate cache.Cache +} + +func (c *newToOldCacheAdapter) Get(key string) interface{} { + res, _ := c.delegate.Get(context.Background(), key) + return res +} + +func (c *newToOldCacheAdapter) GetMulti(keys []string) []interface{} { + res, _ := c.delegate.GetMulti(context.Background(), keys) + return res +} + +func (c *newToOldCacheAdapter) Put(key string, val interface{}, timeout time.Duration) error { + return c.delegate.Put(context.Background(), key, val, timeout) +} + +func (c *newToOldCacheAdapter) Delete(key string) error { + return c.delegate.Delete(context.Background(), key) +} + +func (c *newToOldCacheAdapter) Incr(key string) error { + return c.delegate.Incr(context.Background(), key) +} + +func (c *newToOldCacheAdapter) Decr(key string) error { + return c.delegate.Decr(context.Background(), key) +} + +func (c *newToOldCacheAdapter) IsExist(key string) bool { + res, err := c.delegate.IsExist(context.Background(), key) + return res && err == nil +} + +func (c *newToOldCacheAdapter) ClearAll() error { + return c.delegate.ClearAll(context.Background()) +} + +func (c *newToOldCacheAdapter) StartAndGC(config string) error { + return c.delegate.StartAndGC(config) +} + +func CreateNewToOldCacheAdapter(delegate cache.Cache) Cache { + return &newToOldCacheAdapter{ + delegate: delegate, + } +} + +type oldToNewCacheAdapter struct { + old Cache +} + +func (o *oldToNewCacheAdapter) Get(ctx context.Context, key string) (interface{}, error) { + return o.old.Get(key), nil +} + +func (o *oldToNewCacheAdapter) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { + return o.old.GetMulti(keys), nil +} + +func (o *oldToNewCacheAdapter) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { + return o.old.Put(key, val, timeout) +} + +func (o *oldToNewCacheAdapter) Delete(ctx context.Context, key string) error { + return o.old.Delete(key) +} + +func (o *oldToNewCacheAdapter) Incr(ctx context.Context, key string) error { + return o.old.Incr(key) +} + +func (o *oldToNewCacheAdapter) Decr(ctx context.Context, key string) error { + return o.old.Decr(key) +} + +func (o *oldToNewCacheAdapter) IsExist(ctx context.Context, key string) (bool, error) { + return o.old.IsExist(key), nil +} + +func (o *oldToNewCacheAdapter) ClearAll(ctx context.Context) error { + return o.old.ClearAll() +} + +func (o *oldToNewCacheAdapter) StartAndGC(config string) error { + return o.old.StartAndGC(config) +} + +func CreateOldToNewAdapter(old Cache) cache.Cache { + return &oldToNewCacheAdapter{ + old: old, + } +} diff --git a/pkg/adapter/cache/cache_test.go b/pkg/adapter/cache/cache_test.go new file mode 100644 index 00000000..470c0a43 --- /dev/null +++ b/pkg/adapter/cache/cache_test.go @@ -0,0 +1,191 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "os" + "sync" + "testing" + "time" +) + +func TestCacheIncr(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + //timeoutDuration := 10 * time.Second + + bm.Put("edwardhey", 0, time.Second*20) + wg := sync.WaitGroup{} + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + bm.Incr("edwardhey") + }() + } + wg.Wait() + if bm.Get("edwardhey").(int) != 10 { + t.Error("Incr err") + } +} + +func TestCache(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + time.Sleep(30 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test GetMulti + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } + + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if vv[0].(string) != "author" { + t.Error("GetMulti ERROR") + } + if vv[1].(string) != "author1" { + t.Error("GetMulti ERROR") + } +} + +func TestFileCache(t *testing.T) { + bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if vv[0].(string) != "author" { + t.Error("GetMulti ERROR") + } + if vv[1].(string) != "author1" { + t.Error("GetMulti ERROR") + } + + os.RemoveAll("cache") +} diff --git a/pkg/adapter/cache/conv.go b/pkg/adapter/cache/conv.go new file mode 100644 index 00000000..d46cc31c --- /dev/null +++ b/pkg/adapter/cache/conv.go @@ -0,0 +1,44 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "github.com/astaxie/beego/pkg/client/cache" +) + +// GetString convert interface to string. +func GetString(v interface{}) string { + return cache.GetString(v) +} + +// GetInt convert interface to int. +func GetInt(v interface{}) int { + return cache.GetInt(v) +} + +// GetInt64 convert interface to int64. +func GetInt64(v interface{}) int64 { + return cache.GetInt64(v) +} + +// GetFloat64 convert interface to float64. +func GetFloat64(v interface{}) float64 { + return cache.GetFloat64(v) +} + +// GetBool convert interface to bool. +func GetBool(v interface{}) bool { + return cache.GetBool(v) +} diff --git a/pkg/adapter/cache/conv_test.go b/pkg/adapter/cache/conv_test.go new file mode 100644 index 00000000..b90e224a --- /dev/null +++ b/pkg/adapter/cache/conv_test.go @@ -0,0 +1,143 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "testing" +) + +func TestGetString(t *testing.T) { + var t1 = "test1" + if "test1" != GetString(t1) { + t.Error("get string from string error") + } + var t2 = []byte("test2") + if "test2" != GetString(t2) { + t.Error("get string from byte array error") + } + var t3 = 1 + if "1" != GetString(t3) { + t.Error("get string from int error") + } + var t4 int64 = 1 + if "1" != GetString(t4) { + t.Error("get string from int64 error") + } + var t5 = 1.1 + if "1.1" != GetString(t5) { + t.Error("get string from float64 error") + } + + if "" != GetString(nil) { + t.Error("get string from nil error") + } +} + +func TestGetInt(t *testing.T) { + var t1 = 1 + if 1 != GetInt(t1) { + t.Error("get int from int error") + } + var t2 int32 = 32 + if 32 != GetInt(t2) { + t.Error("get int from int32 error") + } + var t3 int64 = 64 + if 64 != GetInt(t3) { + t.Error("get int from int64 error") + } + var t4 = "128" + if 128 != GetInt(t4) { + t.Error("get int from num string error") + } + if 0 != GetInt(nil) { + t.Error("get int from nil error") + } +} + +func TestGetInt64(t *testing.T) { + var i int64 = 1 + var t1 = 1 + if i != GetInt64(t1) { + t.Error("get int64 from int error") + } + var t2 int32 = 1 + if i != GetInt64(t2) { + t.Error("get int64 from int32 error") + } + var t3 int64 = 1 + if i != GetInt64(t3) { + t.Error("get int64 from int64 error") + } + var t4 = "1" + if i != GetInt64(t4) { + t.Error("get int64 from num string error") + } + if 0 != GetInt64(nil) { + t.Error("get int64 from nil") + } +} + +func TestGetFloat64(t *testing.T) { + var f = 1.11 + var t1 float32 = 1.11 + if f != GetFloat64(t1) { + t.Error("get float64 from float32 error") + } + var t2 = 1.11 + if f != GetFloat64(t2) { + t.Error("get float64 from float64 error") + } + var t3 = "1.11" + if f != GetFloat64(t3) { + t.Error("get float64 from string error") + } + + var f2 float64 = 1 + var t4 = 1 + if f2 != GetFloat64(t4) { + t.Error("get float64 from int error") + } + + if 0 != GetFloat64(nil) { + t.Error("get float64 from nil error") + } +} + +func TestGetBool(t *testing.T) { + var t1 = true + if !GetBool(t1) { + t.Error("get bool from bool error") + } + var t2 = "true" + if !GetBool(t2) { + t.Error("get bool from string error") + } + if GetBool(nil) { + t.Error("get bool from nil error") + } +} + +func byteArrayEquals(a []byte, b []byte) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} diff --git a/pkg/adapter/cache/file.go b/pkg/adapter/cache/file.go new file mode 100644 index 00000000..04598d27 --- /dev/null +++ b/pkg/adapter/cache/file.go @@ -0,0 +1,30 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "github.com/astaxie/beego/pkg/client/cache" +) + +// NewFileCache Create new file cache with no config. +// the level and expiry need set in method StartAndGC as config string. +func NewFileCache() Cache { + // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} + return CreateNewToOldCacheAdapter(cache.NewFileCache()) +} + +func init() { + Register("file", NewFileCache) +} diff --git a/pkg/adapter/cache/memcache/memcache.go b/pkg/adapter/cache/memcache/memcache.go new file mode 100644 index 00000000..f2acffca --- /dev/null +++ b/pkg/adapter/cache/memcache/memcache.go @@ -0,0 +1,44 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memcache for cache provider +// +// depend on github.com/bradfitz/gomemcache/memcache +// +// go install github.com/bradfitz/gomemcache/memcache +// +// Usage: +// import( +// _ "github.com/astaxie/beego/cache/memcache" +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`) +// +// more docs http://beego.me/docs/module/cache.md +package memcache + +import ( + "github.com/astaxie/beego/pkg/adapter/cache" + "github.com/astaxie/beego/pkg/client/cache/memcache" +) + +// NewMemCache create new memcache adapter. +func NewMemCache() cache.Cache { + return cache.CreateNewToOldCacheAdapter(memcache.NewMemCache()) +} + +func init() { + cache.Register("memcache", NewMemCache) +} diff --git a/pkg/adapter/cache/memcache/memcache_test.go b/pkg/adapter/cache/memcache/memcache_test.go new file mode 100644 index 00000000..e6e605a4 --- /dev/null +++ b/pkg/adapter/cache/memcache/memcache_test.go @@ -0,0 +1,114 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memcache + +import ( + "fmt" + "os" + "strconv" + "testing" + "time" + + "github.com/astaxie/beego/pkg/adapter/cache" +) + +func TestMemcacheCache(t *testing.T) { + + addr := os.Getenv("MEMCACHE_ADDR") + if addr == "" { + addr = "127.0.0.1:11211" + } + + bm, err := cache.NewCache("memcache", fmt.Sprintf(`{"conn": "%s"}`, addr)) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + time.Sleep(11 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + // test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie").([]byte); string(v) != "author" { + t.Error("get err") + } + + // test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" { + t.Error("GetMulti ERROR") + } + if string(vv[1].([]byte)) != "author1" && string(vv[1].([]byte)) != "author" { + t.Error("GetMulti ERROR") + } + + // test clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } +} diff --git a/pkg/adapter/cache/memory.go b/pkg/adapter/cache/memory.go new file mode 100644 index 00000000..2d734bc0 --- /dev/null +++ b/pkg/adapter/cache/memory.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "github.com/astaxie/beego/pkg/client/cache" +) + +// NewMemoryCache returns a new MemoryCache. +func NewMemoryCache() Cache { + return CreateNewToOldCacheAdapter(cache.NewMemoryCache()) +} + +func init() { + Register("memory", NewMemoryCache) +} diff --git a/pkg/adapter/cache/redis/redis.go b/pkg/adapter/cache/redis/redis.go new file mode 100644 index 00000000..3aeb8691 --- /dev/null +++ b/pkg/adapter/cache/redis/redis.go @@ -0,0 +1,49 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for cache provider +// +// depend on github.com/gomodule/redigo/redis +// +// go install github.com/gomodule/redigo/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/cache/redis" +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("redis", `{"conn":"127.0.0.1:11211"}`) +// +// more docs http://beego.me/docs/module/cache.md +package redis + +import ( + "github.com/astaxie/beego/pkg/adapter/cache" + redis2 "github.com/astaxie/beego/pkg/client/cache/redis" +) + +var ( + // DefaultKey the collection name of redis for cache adapter. + DefaultKey = "beecacheRedis" +) + +// NewRedisCache create new redis cache with default collection name. +func NewRedisCache() cache.Cache { + return cache.CreateNewToOldCacheAdapter(redis2.NewRedisCache()) +} + +func init() { + cache.Register("redis", NewRedisCache) +} diff --git a/pkg/adapter/cache/redis/redis_test.go b/pkg/adapter/cache/redis/redis_test.go new file mode 100644 index 00000000..165ad0a7 --- /dev/null +++ b/pkg/adapter/cache/redis/redis_test.go @@ -0,0 +1,135 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package redis + +import ( + "fmt" + "os" + "testing" + "time" + + "github.com/gomodule/redigo/redis" + + "github.com/astaxie/beego/pkg/adapter/cache" +) + +func TestRedisCache(t *testing.T) { + redisAddr := os.Getenv("REDIS_ADDR") + if redisAddr == "" { + redisAddr = "127.0.0.1:6379" + } + + bm, err := cache.NewCache("redis", fmt.Sprintf(`{"conn": "%s"}`, redisAddr)) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + time.Sleep(11 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v, _ := redis.String(bm.Get("astaxie"), err); v != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if v, _ := redis.String(vv[0], nil); v != "author" { + t.Error("GetMulti ERROR") + } + if v, _ := redis.String(vv[1], nil); v != "author1" { + t.Error("GetMulti ERROR") + } + + // test clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } +} + +func TestCache_Scan(t *testing.T) { + timeoutDuration := 10 * time.Second + // init + bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) + if err != nil { + t.Error("init err") + } + // insert all + for i := 0; i < 10000; i++ { + if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { + t.Error("set Error", err) + } + } + + // clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } + +} diff --git a/pkg/adapter/cache/ssdb/ssdb.go b/pkg/adapter/cache/ssdb/ssdb.go new file mode 100644 index 00000000..9a252b55 --- /dev/null +++ b/pkg/adapter/cache/ssdb/ssdb.go @@ -0,0 +1,15 @@ +package ssdb + +import ( + "github.com/astaxie/beego/pkg/adapter/cache" + ssdb2 "github.com/astaxie/beego/pkg/client/cache/ssdb" +) + +// NewSsdbCache create new ssdb adapter. +func NewSsdbCache() cache.Cache { + return cache.CreateNewToOldCacheAdapter(ssdb2.NewSsdbCache()) +} + +func init() { + cache.Register("ssdb", NewSsdbCache) +} diff --git a/pkg/adapter/cache/ssdb/ssdb_test.go b/pkg/adapter/cache/ssdb/ssdb_test.go new file mode 100644 index 00000000..0f9dabba --- /dev/null +++ b/pkg/adapter/cache/ssdb/ssdb_test.go @@ -0,0 +1,111 @@ +package ssdb + +import ( + "fmt" + "os" + "strconv" + "testing" + "time" + + "github.com/astaxie/beego/pkg/adapter/cache" +) + +func TestSsdbcacheCache(t *testing.T) { + ssdbAddr := os.Getenv("SSDB_ADDR") + if ssdbAddr == "" { + ssdbAddr = "127.0.0.1:8888" + } + + ssdb, err := cache.NewCache("ssdb", fmt.Sprintf(`{"conn": "%s"}`, ssdbAddr)) + if err != nil { + t.Error("init err") + } + + // test put and exist + if ssdb.IsExist("ssdb") { + t.Error("check err") + } + timeoutDuration := 10 * time.Second + //timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent + if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb") { + t.Error("check err") + } + + // Get test done + if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v := ssdb.Get("ssdb"); v != "ssdb" { + t.Error("get Error") + } + + //inc/dec test done + if err = ssdb.Put("ssdb", "2", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if err = ssdb.Incr("ssdb"); err != nil { + t.Error("incr Error", err) + } + + if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { + t.Error("get err") + } + + if err = ssdb.Decr("ssdb"); err != nil { + t.Error("decr error") + } + + // test del + if err = ssdb.Put("ssdb", "3", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { + t.Error("get err") + } + if err := ssdb.Delete("ssdb"); err == nil { + if ssdb.IsExist("ssdb") { + t.Error("delete err") + } + } + + //test string + if err = ssdb.Put("ssdb", "ssdb", -10*time.Second); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb") { + t.Error("check err") + } + if v := ssdb.Get("ssdb").(string); v != "ssdb" { + t.Error("get err") + } + + //test GetMulti done + if err = ssdb.Put("ssdb1", "ssdb1", -10*time.Second); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb1") { + t.Error("check err") + } + vv := ssdb.GetMulti([]string{"ssdb", "ssdb1"}) + if len(vv) != 2 { + t.Error("getmulti error") + } + if vv[0].(string) != "ssdb" { + t.Error("getmulti error") + } + if vv[1].(string) != "ssdb1" { + t.Error("getmulti error") + } + + // test clear all done + if err = ssdb.ClearAll(); err != nil { + t.Error("clear all err") + } + if ssdb.IsExist("ssdb") || ssdb.IsExist("ssdb1") { + t.Error("check err") + } +} diff --git a/pkg/adapter/utils/captcha/captcha.go b/pkg/adapter/utils/captcha/captcha.go index faadc8bf..aad3994b 100644 --- a/pkg/adapter/utils/captcha/captcha.go +++ b/pkg/adapter/utils/captcha/captcha.go @@ -114,11 +114,11 @@ func (c *Captcha) Verify(id string, challenge string) (success bool) { // NewCaptcha create a new captcha.Captcha func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { - return (*Captcha)(captcha.NewCaptcha(urlPrefix, store)) + return (*Captcha)(captcha.NewCaptcha(urlPrefix, cache.CreateOldToNewAdapter(store))) } // NewWithFilter create a new captcha.Captcha and auto AddFilter for serve captacha image // and add a template func for output html func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { - return (*Captcha)(captcha.NewWithFilter(urlPrefix, store)) + return (*Captcha)(captcha.NewWithFilter(urlPrefix, cache.CreateOldToNewAdapter(store))) } From 48e98482f79949d31687979c1b7f11d80de2bc9c Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 5 Oct 2020 18:13:26 +0800 Subject: [PATCH 186/207] rename infrastructure to core --- pkg/adapter/admin.go | 2 +- pkg/adapter/config.go | 2 +- pkg/adapter/config/adapter.go | 2 +- pkg/adapter/config/config.go | 2 +- pkg/adapter/config/env/env.go | 2 +- pkg/adapter/config/fake.go | 2 +- pkg/adapter/config/json.go | 2 +- pkg/adapter/config/xml/xml.go | 2 +- pkg/adapter/config/yaml/yaml.go | 2 +- pkg/adapter/log.go | 4 ++-- pkg/adapter/metric/prometheus.go | 2 +- pkg/adapter/orm/orm.go | 2 +- .../session/couchbase/sess_couchbase.go | 2 +- pkg/adapter/session/ledis/ledis_session.go | 2 +- pkg/adapter/session/memcache/sess_memcache.go | 2 +- pkg/adapter/session/mysql/sess_mysql.go | 2 +- .../session/postgres/sess_postgresql.go | 2 +- pkg/adapter/session/provider_adapter.go | 2 +- pkg/adapter/session/redis/sess_redis.go | 2 +- .../session/redis_cluster/redis_cluster.go | 2 +- .../redis_sentinel/sess_redis_sentinel.go | 2 +- pkg/adapter/session/sess_cookie.go | 2 +- pkg/adapter/session/sess_file.go | 2 +- pkg/adapter/session/sess_mem.go | 2 +- pkg/adapter/session/sess_utils.go | 2 +- pkg/adapter/session/session.go | 2 +- pkg/adapter/session/ssdb/sess_ssdb.go | 2 +- pkg/adapter/session/store_adapter.go | 2 +- pkg/adapter/toolbox/healthcheck.go | 2 +- pkg/adapter/toolbox/profile.go | 2 +- pkg/adapter/utils/caller.go | 2 +- pkg/adapter/utils/debug.go | 2 +- pkg/adapter/utils/file.go | 2 +- pkg/adapter/utils/mail.go | 2 +- pkg/adapter/utils/pagination/paginator.go | 2 +- pkg/adapter/utils/rand.go | 2 +- pkg/adapter/utils/safemap.go | 2 +- pkg/adapter/utils/slice.go | 2 +- pkg/adapter/utils/utils.go | 2 +- pkg/adapter/validation/util.go | 2 +- pkg/adapter/validation/validation.go | 2 +- pkg/adapter/validation/validators.go | 2 +- pkg/client/orm/do_nothing_orm.go | 2 +- .../orm/filter/bean/default_value_filter.go | 4 ++-- pkg/client/orm/filter_orm_decorator.go | 2 +- pkg/client/orm/filter_orm_decorator_test.go | 2 +- pkg/client/orm/hints/db_hints.go | 2 +- pkg/client/orm/migration/ddl.go | 2 +- pkg/client/orm/migration/migration.go | 2 +- pkg/client/orm/orm.go | 4 ++-- pkg/client/orm/types.go | 2 +- pkg/{infrastructure => core}/bean/context.go | 0 pkg/{infrastructure => core}/bean/doc.go | 0 pkg/{infrastructure => core}/bean/factory.go | 0 pkg/{infrastructure => core}/bean/metadata.go | 0 .../bean/tag_auto_wire_bean_factory.go | 2 +- .../bean/tag_auto_wire_bean_factory_test.go | 0 .../bean/time_type_adapter.go | 0 .../bean/time_type_adapter_test.go | 0 .../bean/type_adapter.go | 0 .../config/base_config_test.go | 0 pkg/{infrastructure => core}/config/config.go | 0 .../config/config_test.go | 0 .../config/env/env.go | 2 +- .../config/env/env_test.go | 0 .../config/etcd/config.go | 4 ++-- .../config/etcd/config_test.go | 0 pkg/{infrastructure => core}/config/fake.go | 0 pkg/{infrastructure => core}/config/ini.go | 0 .../config/ini_test.go | 0 .../config/json/json.go | 4 ++-- .../config/json/json_test.go | 2 +- .../config/xml/xml.go | 4 ++-- .../config/xml/xml_test.go | 2 +- .../config/yaml/yaml.go | 4 ++-- .../config/yaml/yaml_test.go | 2 +- .../governor/command.go | 0 .../governor/healthcheck.go | 0 .../governor/profile.go | 2 +- .../governor/profile_test.go | 0 pkg/{infrastructure => core}/logs/README.md | 0 .../logs/access_log.go | 0 .../logs/access_log_test.go | 0 .../logs/alils/alils.go | 2 +- .../logs/alils/config.go | 0 .../logs/alils/log.pb.go | 0 .../logs/alils/log_config.go | 0 .../logs/alils/log_project.go | 0 .../logs/alils/log_store.go | 0 .../logs/alils/machine_group.go | 0 .../logs/alils/request.go | 0 .../logs/alils/signature.go | 0 pkg/{infrastructure => core}/logs/conn.go | 0 .../logs/conn_test.go | 0 pkg/{infrastructure => core}/logs/console.go | 0 .../logs/console_test.go | 0 pkg/{infrastructure => core}/logs/es/es.go | 2 +- pkg/{infrastructure => core}/logs/es/index.go | 2 +- .../logs/es/index_test.go | 2 +- pkg/{infrastructure => core}/logs/file.go | 0 .../logs/file_test.go | 0 .../logs/formatter.go | 0 .../logs/formatter_test.go | 0 pkg/{infrastructure => core}/logs/jianliao.go | 0 .../logs/jianliao_test.go | 0 pkg/{infrastructure => core}/logs/log.go | 0 pkg/{infrastructure => core}/logs/log_msg.go | 0 .../logs/log_msg_test.go | 0 pkg/{infrastructure => core}/logs/log_test.go | 0 pkg/{infrastructure => core}/logs/logger.go | 0 .../logs/logger_test.go | 0 .../logs/multifile.go | 0 .../logs/multifile_test.go | 0 pkg/{infrastructure => core}/logs/slack.go | 0 pkg/{infrastructure => core}/logs/smtp.go | 0 .../logs/smtp_test.go | 0 .../session/README.md | 0 .../session/couchbase/sess_couchbase.go | 2 +- .../session/ledis/ledis_session.go | 2 +- .../session/memcache/sess_memcache.go | 2 +- .../session/mysql/sess_mysql.go | 2 +- .../session/postgres/sess_postgresql.go | 2 +- .../session/redis/sess_redis.go | 2 +- .../session/redis/sess_redis_test.go | 2 +- .../session/redis_cluster/redis_cluster.go | 2 +- .../redis_sentinel/sess_redis_sentinel.go | 2 +- .../sess_redis_sentinel_test.go | 2 +- .../session/sess_cookie.go | 0 .../session/sess_cookie_test.go | 0 .../session/sess_file.go | 0 .../session/sess_file_test.go | 0 .../session/sess_mem.go | 0 .../session/sess_mem_test.go | 0 .../session/sess_test.go | 0 .../session/sess_utils.go | 2 +- .../session/session.go | 0 .../session/ssdb/sess_ssdb.go | 2 +- pkg/{infrastructure => core}/utils/caller.go | 0 .../utils/caller_test.go | 0 pkg/{infrastructure => core}/utils/debug.go | 0 .../utils/debug_test.go | 0 pkg/{infrastructure => core}/utils/file.go | 0 .../utils/file_test.go | 0 pkg/{infrastructure => core}/utils/kv.go | 0 pkg/{infrastructure => core}/utils/kv_test.go | 0 pkg/{infrastructure => core}/utils/mail.go | 0 .../utils/mail_test.go | 0 .../utils/pagination/doc.go | 2 +- .../utils/pagination/paginator.go | 0 .../utils/pagination/utils.go | 0 pkg/{infrastructure => core}/utils/rand.go | 0 .../utils/rand_test.go | 0 pkg/{infrastructure => core}/utils/safemap.go | 0 .../utils/safemap_test.go | 0 pkg/{infrastructure => core}/utils/slice.go | 0 .../utils/slice_test.go | 0 .../utils/testdata/grepe.test | 0 pkg/{infrastructure => core}/utils/time.go | 0 pkg/{infrastructure => core}/utils/utils.go | 0 .../utils/utils_test.go | 0 .../validation/README.md | 0 .../validation/util.go | 0 .../validation/util_test.go | 0 .../validation/validation.go | 0 .../validation/validation_test.go | 0 .../validation/validators.go | 2 +- pkg/server/web/admin.go | 2 +- pkg/server/web/admin_controller.go | 2 +- pkg/server/web/admin_test.go | 2 +- pkg/server/web/captcha/captcha.go | 4 ++-- pkg/server/web/captcha/image_test.go | 2 +- pkg/server/web/config.go | 8 +++---- pkg/server/web/config_test.go | 2 +- pkg/server/web/context/context.go | 2 +- pkg/server/web/context/input.go | 2 +- pkg/server/web/context/param/conv.go | 2 +- pkg/server/web/controller.go | 2 +- pkg/server/web/error.go | 2 +- pkg/server/web/hooks.go | 4 ++-- pkg/server/web/pagination/controller.go | 2 +- pkg/server/web/parser.go | 4 ++-- pkg/server/web/router.go | 4 ++-- pkg/server/web/router_test.go | 2 +- pkg/server/web/server.go | 4 ++-- pkg/server/web/staticfile.go | 2 +- pkg/server/web/statistics.go | 2 +- pkg/server/web/template.go | 4 ++-- pkg/server/web/tree.go | 2 +- pkg/task/govenor_command.go | 2 +- pkg/task/task.go | 23 +++++++++---------- 190 files changed, 129 insertions(+), 130 deletions(-) rename pkg/{infrastructure => core}/bean/context.go (100%) rename pkg/{infrastructure => core}/bean/doc.go (100%) rename pkg/{infrastructure => core}/bean/factory.go (100%) rename pkg/{infrastructure => core}/bean/metadata.go (100%) rename pkg/{infrastructure => core}/bean/tag_auto_wire_bean_factory.go (99%) rename pkg/{infrastructure => core}/bean/tag_auto_wire_bean_factory_test.go (100%) rename pkg/{infrastructure => core}/bean/time_type_adapter.go (100%) rename pkg/{infrastructure => core}/bean/time_type_adapter_test.go (100%) rename pkg/{infrastructure => core}/bean/type_adapter.go (100%) rename pkg/{infrastructure => core}/config/base_config_test.go (100%) rename pkg/{infrastructure => core}/config/config.go (100%) rename pkg/{infrastructure => core}/config/config_test.go (100%) rename pkg/{infrastructure => core}/config/env/env.go (97%) rename pkg/{infrastructure => core}/config/env/env_test.go (100%) rename pkg/{infrastructure => core}/config/etcd/config.go (98%) rename pkg/{infrastructure => core}/config/etcd/config_test.go (100%) rename pkg/{infrastructure => core}/config/fake.go (100%) rename pkg/{infrastructure => core}/config/ini.go (100%) rename pkg/{infrastructure => core}/config/ini_test.go (100%) rename pkg/{infrastructure => core}/config/json/json.go (98%) rename pkg/{infrastructure => core}/config/json/json_test.go (99%) rename pkg/{infrastructure => core}/config/xml/xml.go (98%) rename pkg/{infrastructure => core}/config/xml/xml_test.go (98%) rename pkg/{infrastructure => core}/config/yaml/yaml.go (98%) rename pkg/{infrastructure => core}/config/yaml/yaml_test.go (98%) rename pkg/{infrastructure => core}/governor/command.go (100%) rename pkg/{infrastructure => core}/governor/healthcheck.go (100%) rename pkg/{infrastructure => core}/governor/profile.go (98%) rename pkg/{infrastructure => core}/governor/profile_test.go (100%) rename pkg/{infrastructure => core}/logs/README.md (100%) rename pkg/{infrastructure => core}/logs/access_log.go (100%) rename pkg/{infrastructure => core}/logs/access_log_test.go (100%) rename pkg/{infrastructure => core}/logs/alils/alils.go (98%) rename pkg/{infrastructure => core}/logs/alils/config.go (100%) rename pkg/{infrastructure => core}/logs/alils/log.pb.go (100%) rename pkg/{infrastructure => core}/logs/alils/log_config.go (100%) rename pkg/{infrastructure => core}/logs/alils/log_project.go (100%) rename pkg/{infrastructure => core}/logs/alils/log_store.go (100%) rename pkg/{infrastructure => core}/logs/alils/machine_group.go (100%) rename pkg/{infrastructure => core}/logs/alils/request.go (100%) rename pkg/{infrastructure => core}/logs/alils/signature.go (100%) rename pkg/{infrastructure => core}/logs/conn.go (100%) rename pkg/{infrastructure => core}/logs/conn_test.go (100%) rename pkg/{infrastructure => core}/logs/console.go (100%) rename pkg/{infrastructure => core}/logs/console_test.go (100%) rename pkg/{infrastructure => core}/logs/es/es.go (97%) rename pkg/{infrastructure => core}/logs/es/index.go (95%) rename pkg/{infrastructure => core}/logs/es/index_test.go (94%) rename pkg/{infrastructure => core}/logs/file.go (100%) rename pkg/{infrastructure => core}/logs/file_test.go (100%) rename pkg/{infrastructure => core}/logs/formatter.go (100%) rename pkg/{infrastructure => core}/logs/formatter_test.go (100%) rename pkg/{infrastructure => core}/logs/jianliao.go (100%) rename pkg/{infrastructure => core}/logs/jianliao_test.go (100%) rename pkg/{infrastructure => core}/logs/log.go (100%) rename pkg/{infrastructure => core}/logs/log_msg.go (100%) rename pkg/{infrastructure => core}/logs/log_msg_test.go (100%) rename pkg/{infrastructure => core}/logs/log_test.go (100%) rename pkg/{infrastructure => core}/logs/logger.go (100%) rename pkg/{infrastructure => core}/logs/logger_test.go (100%) rename pkg/{infrastructure => core}/logs/multifile.go (100%) rename pkg/{infrastructure => core}/logs/multifile_test.go (100%) rename pkg/{infrastructure => core}/logs/slack.go (100%) rename pkg/{infrastructure => core}/logs/smtp.go (100%) rename pkg/{infrastructure => core}/logs/smtp_test.go (100%) rename pkg/{infrastructure => core}/session/README.md (100%) rename pkg/{infrastructure => core}/session/couchbase/sess_couchbase.go (99%) rename pkg/{infrastructure => core}/session/ledis/ledis_session.go (98%) rename pkg/{infrastructure => core}/session/memcache/sess_memcache.go (99%) rename pkg/{infrastructure => core}/session/mysql/sess_mysql.go (99%) rename pkg/{infrastructure => core}/session/postgres/sess_postgresql.go (99%) rename pkg/{infrastructure => core}/session/redis/sess_redis.go (99%) rename pkg/{infrastructure => core}/session/redis/sess_redis_test.go (97%) rename pkg/{infrastructure => core}/session/redis_cluster/redis_cluster.go (99%) rename pkg/{infrastructure => core}/session/redis_sentinel/sess_redis_sentinel.go (99%) rename pkg/{infrastructure => core}/session/redis_sentinel/sess_redis_sentinel_test.go (97%) rename pkg/{infrastructure => core}/session/sess_cookie.go (100%) rename pkg/{infrastructure => core}/session/sess_cookie_test.go (100%) rename pkg/{infrastructure => core}/session/sess_file.go (100%) rename pkg/{infrastructure => core}/session/sess_file_test.go (100%) rename pkg/{infrastructure => core}/session/sess_mem.go (100%) rename pkg/{infrastructure => core}/session/sess_mem_test.go (100%) rename pkg/{infrastructure => core}/session/sess_test.go (100%) rename pkg/{infrastructure => core}/session/sess_utils.go (99%) rename pkg/{infrastructure => core}/session/session.go (100%) rename pkg/{infrastructure => core}/session/ssdb/sess_ssdb.go (98%) rename pkg/{infrastructure => core}/utils/caller.go (100%) rename pkg/{infrastructure => core}/utils/caller_test.go (100%) rename pkg/{infrastructure => core}/utils/debug.go (100%) rename pkg/{infrastructure => core}/utils/debug_test.go (100%) rename pkg/{infrastructure => core}/utils/file.go (100%) rename pkg/{infrastructure => core}/utils/file_test.go (100%) rename pkg/{infrastructure => core}/utils/kv.go (100%) rename pkg/{infrastructure => core}/utils/kv_test.go (100%) rename pkg/{infrastructure => core}/utils/mail.go (100%) rename pkg/{infrastructure => core}/utils/mail_test.go (100%) rename pkg/{infrastructure => core}/utils/pagination/doc.go (95%) rename pkg/{infrastructure => core}/utils/pagination/paginator.go (100%) rename pkg/{infrastructure => core}/utils/pagination/utils.go (100%) rename pkg/{infrastructure => core}/utils/rand.go (100%) rename pkg/{infrastructure => core}/utils/rand_test.go (100%) rename pkg/{infrastructure => core}/utils/safemap.go (100%) rename pkg/{infrastructure => core}/utils/safemap_test.go (100%) rename pkg/{infrastructure => core}/utils/slice.go (100%) rename pkg/{infrastructure => core}/utils/slice_test.go (100%) rename pkg/{infrastructure => core}/utils/testdata/grepe.test (100%) rename pkg/{infrastructure => core}/utils/time.go (100%) rename pkg/{infrastructure => core}/utils/utils.go (100%) rename pkg/{infrastructure => core}/utils/utils_test.go (100%) rename pkg/{infrastructure => core}/validation/README.md (100%) rename pkg/{infrastructure => core}/validation/util.go (100%) rename pkg/{infrastructure => core}/validation/util_test.go (100%) rename pkg/{infrastructure => core}/validation/validation.go (100%) rename pkg/{infrastructure => core}/validation/validation_test.go (100%) rename pkg/{infrastructure => core}/validation/validators.go (99%) diff --git a/pkg/adapter/admin.go b/pkg/adapter/admin.go index 3127416f..5ba78511 100644 --- a/pkg/adapter/admin.go +++ b/pkg/adapter/admin.go @@ -17,7 +17,7 @@ package adapter import ( "time" - _ "github.com/astaxie/beego/pkg/infrastructure/governor" + _ "github.com/astaxie/beego/pkg/core/governor" "github.com/astaxie/beego/pkg/server/web" ) diff --git a/pkg/adapter/config.go b/pkg/adapter/config.go index 1491722c..3975f5eb 100644 --- a/pkg/adapter/config.go +++ b/pkg/adapter/config.go @@ -18,7 +18,7 @@ import ( context2 "context" "github.com/astaxie/beego/pkg/adapter/session" - newCfg "github.com/astaxie/beego/pkg/infrastructure/config" + newCfg "github.com/astaxie/beego/pkg/core/config" "github.com/astaxie/beego/pkg/server/web" ) diff --git a/pkg/adapter/config/adapter.go b/pkg/adapter/config/adapter.go index f74b3ff9..8506228f 100644 --- a/pkg/adapter/config/adapter.go +++ b/pkg/adapter/config/adapter.go @@ -19,7 +19,7 @@ import ( "github.com/pkg/errors" - "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/core/config" ) type newToOldConfigerAdapter struct { diff --git a/pkg/adapter/config/config.go b/pkg/adapter/config/config.go index c870a15a..821379f4 100644 --- a/pkg/adapter/config/config.go +++ b/pkg/adapter/config/config.go @@ -41,7 +41,7 @@ package config import ( - "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/core/config" ) // Configer defines how to get and set value from configuration raw data. diff --git a/pkg/adapter/config/env/env.go b/pkg/adapter/config/env/env.go index 77d7b53c..bac80576 100644 --- a/pkg/adapter/config/env/env.go +++ b/pkg/adapter/config/env/env.go @@ -17,7 +17,7 @@ package env import ( - "github.com/astaxie/beego/pkg/infrastructure/config/env" + "github.com/astaxie/beego/pkg/core/config/env" ) // Get returns a value by key. diff --git a/pkg/adapter/config/fake.go b/pkg/adapter/config/fake.go index fac96b41..acbd52e5 100644 --- a/pkg/adapter/config/fake.go +++ b/pkg/adapter/config/fake.go @@ -15,7 +15,7 @@ package config import ( - "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/core/config" ) // NewFakeConfig return a fake Configer diff --git a/pkg/adapter/config/json.go b/pkg/adapter/config/json.go index d0fe4d09..69c87568 100644 --- a/pkg/adapter/config/json.go +++ b/pkg/adapter/config/json.go @@ -15,5 +15,5 @@ package config import ( - _ "github.com/astaxie/beego/pkg/infrastructure/config/json" + _ "github.com/astaxie/beego/pkg/core/config/json" ) diff --git a/pkg/adapter/config/xml/xml.go b/pkg/adapter/config/xml/xml.go index f96cdcd6..2744e335 100644 --- a/pkg/adapter/config/xml/xml.go +++ b/pkg/adapter/config/xml/xml.go @@ -30,5 +30,5 @@ package xml import ( - _ "github.com/astaxie/beego/pkg/infrastructure/config/xml" + _ "github.com/astaxie/beego/pkg/core/config/xml" ) diff --git a/pkg/adapter/config/yaml/yaml.go b/pkg/adapter/config/yaml/yaml.go index bc2398e9..c5325ccd 100644 --- a/pkg/adapter/config/yaml/yaml.go +++ b/pkg/adapter/config/yaml/yaml.go @@ -30,5 +30,5 @@ package yaml import ( - _ "github.com/astaxie/beego/pkg/infrastructure/config/yaml" + _ "github.com/astaxie/beego/pkg/core/config/yaml" ) diff --git a/pkg/adapter/log.go b/pkg/adapter/log.go index d9ff6e0c..0d7d94c0 100644 --- a/pkg/adapter/log.go +++ b/pkg/adapter/log.go @@ -17,9 +17,9 @@ package adapter import ( "strings" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" - webLog "github.com/astaxie/beego/pkg/infrastructure/logs" + webLog "github.com/astaxie/beego/pkg/core/logs" ) // Log levels to control the logging output. diff --git a/pkg/adapter/metric/prometheus.go b/pkg/adapter/metric/prometheus.go index 6af2c26c..df5db84f 100644 --- a/pkg/adapter/metric/prometheus.go +++ b/pkg/adapter/metric/prometheus.go @@ -24,7 +24,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" "github.com/astaxie/beego/pkg/server/web" ) diff --git a/pkg/adapter/orm/orm.go b/pkg/adapter/orm/orm.go index f8463ea2..61990256 100644 --- a/pkg/adapter/orm/orm.go +++ b/pkg/adapter/orm/orm.go @@ -60,7 +60,7 @@ import ( "github.com/astaxie/beego/pkg/client/orm" "github.com/astaxie/beego/pkg/client/orm/hints" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) // DebugQueries define the debug diff --git a/pkg/adapter/session/couchbase/sess_couchbase.go b/pkg/adapter/session/couchbase/sess_couchbase.go index bce09641..aa3bc724 100644 --- a/pkg/adapter/session/couchbase/sess_couchbase.go +++ b/pkg/adapter/session/couchbase/sess_couchbase.go @@ -37,7 +37,7 @@ import ( "net/http" "github.com/astaxie/beego/pkg/adapter/session" - beecb "github.com/astaxie/beego/pkg/infrastructure/session/couchbase" + beecb "github.com/astaxie/beego/pkg/core/session/couchbase" ) // SessionStore store each session diff --git a/pkg/adapter/session/ledis/ledis_session.go b/pkg/adapter/session/ledis/ledis_session.go index 96198837..db47b375 100644 --- a/pkg/adapter/session/ledis/ledis_session.go +++ b/pkg/adapter/session/ledis/ledis_session.go @@ -6,7 +6,7 @@ import ( "net/http" "github.com/astaxie/beego/pkg/adapter/session" - beeLedis "github.com/astaxie/beego/pkg/infrastructure/session/ledis" + beeLedis "github.com/astaxie/beego/pkg/core/session/ledis" ) // SessionStore ledis session store diff --git a/pkg/adapter/session/memcache/sess_memcache.go b/pkg/adapter/session/memcache/sess_memcache.go index 8afa79aa..9f39cf5c 100644 --- a/pkg/adapter/session/memcache/sess_memcache.go +++ b/pkg/adapter/session/memcache/sess_memcache.go @@ -38,7 +38,7 @@ import ( "github.com/astaxie/beego/pkg/adapter/session" - beemem "github.com/astaxie/beego/pkg/infrastructure/session/memcache" + beemem "github.com/astaxie/beego/pkg/core/session/memcache" ) // SessionStore memcache session store diff --git a/pkg/adapter/session/mysql/sess_mysql.go b/pkg/adapter/session/mysql/sess_mysql.go index 1850a380..550556c8 100644 --- a/pkg/adapter/session/mysql/sess_mysql.go +++ b/pkg/adapter/session/mysql/sess_mysql.go @@ -45,7 +45,7 @@ import ( "net/http" "github.com/astaxie/beego/pkg/adapter/session" - "github.com/astaxie/beego/pkg/infrastructure/session/mysql" + "github.com/astaxie/beego/pkg/core/session/mysql" // import mysql driver _ "github.com/go-sql-driver/mysql" diff --git a/pkg/adapter/session/postgres/sess_postgresql.go b/pkg/adapter/session/postgres/sess_postgresql.go index de1adbc4..76361533 100644 --- a/pkg/adapter/session/postgres/sess_postgresql.go +++ b/pkg/adapter/session/postgres/sess_postgresql.go @@ -58,7 +58,7 @@ import ( // import postgresql Driver _ "github.com/lib/pq" - "github.com/astaxie/beego/pkg/infrastructure/session/postgres" + "github.com/astaxie/beego/pkg/core/session/postgres" ) // SessionStore postgresql session store diff --git a/pkg/adapter/session/provider_adapter.go b/pkg/adapter/session/provider_adapter.go index 11177a4d..259e998c 100644 --- a/pkg/adapter/session/provider_adapter.go +++ b/pkg/adapter/session/provider_adapter.go @@ -17,7 +17,7 @@ package session import ( "context" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" ) type oldToNewProviderAdapter struct { diff --git a/pkg/adapter/session/redis/sess_redis.go b/pkg/adapter/session/redis/sess_redis.go index 6c521e50..d4a17b84 100644 --- a/pkg/adapter/session/redis/sess_redis.go +++ b/pkg/adapter/session/redis/sess_redis.go @@ -38,7 +38,7 @@ import ( "github.com/astaxie/beego/pkg/adapter/session" - beeRedis "github.com/astaxie/beego/pkg/infrastructure/session/redis" + beeRedis "github.com/astaxie/beego/pkg/core/session/redis" ) // MaxPoolSize redis max pool size diff --git a/pkg/adapter/session/redis_cluster/redis_cluster.go b/pkg/adapter/session/redis_cluster/redis_cluster.go index 03a805e4..325efa25 100644 --- a/pkg/adapter/session/redis_cluster/redis_cluster.go +++ b/pkg/adapter/session/redis_cluster/redis_cluster.go @@ -37,7 +37,7 @@ import ( "net/http" "github.com/astaxie/beego/pkg/adapter/session" - cluster "github.com/astaxie/beego/pkg/infrastructure/session/redis_cluster" + cluster "github.com/astaxie/beego/pkg/core/session/redis_cluster" ) // MaxPoolSize redis_cluster max pool size diff --git a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go index f5eb8a4f..0306400d 100644 --- a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go +++ b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go @@ -38,7 +38,7 @@ import ( "github.com/astaxie/beego/pkg/adapter/session" - sentinel "github.com/astaxie/beego/pkg/infrastructure/session/redis_sentinel" + sentinel "github.com/astaxie/beego/pkg/core/session/redis_sentinel" ) // DefaultPoolSize redis_sentinel default pool size diff --git a/pkg/adapter/session/sess_cookie.go b/pkg/adapter/session/sess_cookie.go index 32216040..f28b0620 100644 --- a/pkg/adapter/session/sess_cookie.go +++ b/pkg/adapter/session/sess_cookie.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" ) // CookieSessionStore Cookie SessionStore diff --git a/pkg/adapter/session/sess_file.go b/pkg/adapter/session/sess_file.go index b9648998..5aa5bc1e 100644 --- a/pkg/adapter/session/sess_file.go +++ b/pkg/adapter/session/sess_file.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" ) // FileSessionStore File session store diff --git a/pkg/adapter/session/sess_mem.go b/pkg/adapter/session/sess_mem.go index 818c8329..ac37d5d3 100644 --- a/pkg/adapter/session/sess_mem.go +++ b/pkg/adapter/session/sess_mem.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" ) // MemSessionStore memory session store. diff --git a/pkg/adapter/session/sess_utils.go b/pkg/adapter/session/sess_utils.go index 3d107198..b5cbc5a1 100644 --- a/pkg/adapter/session/sess_utils.go +++ b/pkg/adapter/session/sess_utils.go @@ -15,7 +15,7 @@ package session import ( - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" ) // EncodeGob encode the obj to gob diff --git a/pkg/adapter/session/session.go b/pkg/adapter/session/session.go index eea2f90e..7612854d 100644 --- a/pkg/adapter/session/session.go +++ b/pkg/adapter/session/session.go @@ -32,7 +32,7 @@ import ( "net/http" "os" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" ) // Store contains all data for one session process with specific id. diff --git a/pkg/adapter/session/ssdb/sess_ssdb.go b/pkg/adapter/session/ssdb/sess_ssdb.go index aee3a364..03c11d61 100644 --- a/pkg/adapter/session/ssdb/sess_ssdb.go +++ b/pkg/adapter/session/ssdb/sess_ssdb.go @@ -6,7 +6,7 @@ import ( "github.com/astaxie/beego/pkg/adapter/session" - beeSsdb "github.com/astaxie/beego/pkg/infrastructure/session/ssdb" + beeSsdb "github.com/astaxie/beego/pkg/core/session/ssdb" ) // Provider holds ssdb client and configs diff --git a/pkg/adapter/session/store_adapter.go b/pkg/adapter/session/store_adapter.go index c1a03c38..b8a23937 100644 --- a/pkg/adapter/session/store_adapter.go +++ b/pkg/adapter/session/store_adapter.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" ) type NewToOldStoreAdapter struct { diff --git a/pkg/adapter/toolbox/healthcheck.go b/pkg/adapter/toolbox/healthcheck.go index 56be8089..42b9e7d0 100644 --- a/pkg/adapter/toolbox/healthcheck.go +++ b/pkg/adapter/toolbox/healthcheck.go @@ -31,7 +31,7 @@ package toolbox import ( - "github.com/astaxie/beego/pkg/infrastructure/governor" + "github.com/astaxie/beego/pkg/core/governor" ) // AdminCheckList holds health checker map diff --git a/pkg/adapter/toolbox/profile.go b/pkg/adapter/toolbox/profile.go index 16cf80b1..97da05ac 100644 --- a/pkg/adapter/toolbox/profile.go +++ b/pkg/adapter/toolbox/profile.go @@ -19,7 +19,7 @@ import ( "os" "time" - "github.com/astaxie/beego/pkg/infrastructure/governor" + "github.com/astaxie/beego/pkg/core/governor" ) var startTime = time.Now() diff --git a/pkg/adapter/utils/caller.go b/pkg/adapter/utils/caller.go index d4fcc456..124c68df 100644 --- a/pkg/adapter/utils/caller.go +++ b/pkg/adapter/utils/caller.go @@ -15,7 +15,7 @@ package utils import ( - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) // GetFuncName get function name diff --git a/pkg/adapter/utils/debug.go b/pkg/adapter/utils/debug.go index d39f3d3e..6bb381a1 100644 --- a/pkg/adapter/utils/debug.go +++ b/pkg/adapter/utils/debug.go @@ -15,7 +15,7 @@ package utils import ( - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) // Display print the data in console diff --git a/pkg/adapter/utils/file.go b/pkg/adapter/utils/file.go index 8979389e..4ed2a8e3 100644 --- a/pkg/adapter/utils/file.go +++ b/pkg/adapter/utils/file.go @@ -15,7 +15,7 @@ package utils import ( - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) // SelfPath gets compiled executable file absolute path diff --git a/pkg/adapter/utils/mail.go b/pkg/adapter/utils/mail.go index 35a58756..74ebe25c 100644 --- a/pkg/adapter/utils/mail.go +++ b/pkg/adapter/utils/mail.go @@ -17,7 +17,7 @@ package utils import ( "io" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) // Email is the type used for email messages diff --git a/pkg/adapter/utils/pagination/paginator.go b/pkg/adapter/utils/pagination/paginator.go index 4bd4a1b0..1fefb9e0 100644 --- a/pkg/adapter/utils/pagination/paginator.go +++ b/pkg/adapter/utils/pagination/paginator.go @@ -17,7 +17,7 @@ package pagination import ( "net/http" - "github.com/astaxie/beego/pkg/infrastructure/utils/pagination" + "github.com/astaxie/beego/pkg/core/utils/pagination" ) // Paginator within the state of a http request. diff --git a/pkg/adapter/utils/rand.go b/pkg/adapter/utils/rand.go index ae415cf3..c31b633e 100644 --- a/pkg/adapter/utils/rand.go +++ b/pkg/adapter/utils/rand.go @@ -15,7 +15,7 @@ package utils import ( - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) // RandomCreateBytes generate random []byte by specify chars. diff --git a/pkg/adapter/utils/safemap.go b/pkg/adapter/utils/safemap.go index 13e7bb46..6771aca4 100644 --- a/pkg/adapter/utils/safemap.go +++ b/pkg/adapter/utils/safemap.go @@ -15,7 +15,7 @@ package utils import ( - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) // BeeMap is a map with lock diff --git a/pkg/adapter/utils/slice.go b/pkg/adapter/utils/slice.go index 24d19ad2..a5b852b9 100644 --- a/pkg/adapter/utils/slice.go +++ b/pkg/adapter/utils/slice.go @@ -15,7 +15,7 @@ package utils import ( - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) type reducetype func(interface{}) interface{} diff --git a/pkg/adapter/utils/utils.go b/pkg/adapter/utils/utils.go index 1f3bcd31..21ed49dc 100644 --- a/pkg/adapter/utils/utils.go +++ b/pkg/adapter/utils/utils.go @@ -1,7 +1,7 @@ package utils import ( - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) // GetGOPATHs returns all paths in GOPATH variable. diff --git a/pkg/adapter/validation/util.go b/pkg/adapter/validation/util.go index 729712e0..fb5370bf 100644 --- a/pkg/adapter/validation/util.go +++ b/pkg/adapter/validation/util.go @@ -17,7 +17,7 @@ package validation import ( "reflect" - "github.com/astaxie/beego/pkg/infrastructure/validation" + "github.com/astaxie/beego/pkg/core/validation" ) const ( diff --git a/pkg/adapter/validation/validation.go b/pkg/adapter/validation/validation.go index 1cdb8dda..e95fd408 100644 --- a/pkg/adapter/validation/validation.go +++ b/pkg/adapter/validation/validation.go @@ -50,7 +50,7 @@ import ( "fmt" "regexp" - "github.com/astaxie/beego/pkg/infrastructure/validation" + "github.com/astaxie/beego/pkg/core/validation" ) // ValidFormer valid interface diff --git a/pkg/adapter/validation/validators.go b/pkg/adapter/validation/validators.go index 1a063749..152e8aef 100644 --- a/pkg/adapter/validation/validators.go +++ b/pkg/adapter/validation/validators.go @@ -17,7 +17,7 @@ package validation import ( "sync" - "github.com/astaxie/beego/pkg/infrastructure/validation" + "github.com/astaxie/beego/pkg/core/validation" ) // CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty diff --git a/pkg/client/orm/do_nothing_orm.go b/pkg/client/orm/do_nothing_orm.go index e27e7f3a..42775b54 100644 --- a/pkg/client/orm/do_nothing_orm.go +++ b/pkg/client/orm/do_nothing_orm.go @@ -18,7 +18,7 @@ import ( "context" "database/sql" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) // DoNothingOrm won't do anything, usually you use this to custom your mock Ormer implementation diff --git a/pkg/client/orm/filter/bean/default_value_filter.go b/pkg/client/orm/filter/bean/default_value_filter.go index a367a6e2..7da9408d 100644 --- a/pkg/client/orm/filter/bean/default_value_filter.go +++ b/pkg/client/orm/filter/bean/default_value_filter.go @@ -19,10 +19,10 @@ import ( "reflect" "strings" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" "github.com/astaxie/beego/pkg/client/orm" - "github.com/astaxie/beego/pkg/infrastructure/bean" + "github.com/astaxie/beego/pkg/core/bean" ) // DefaultValueFilterChainBuilder only works for InsertXXX method, diff --git a/pkg/client/orm/filter_orm_decorator.go b/pkg/client/orm/filter_orm_decorator.go index d0c5c537..729c1698 100644 --- a/pkg/client/orm/filter_orm_decorator.go +++ b/pkg/client/orm/filter_orm_decorator.go @@ -20,7 +20,7 @@ import ( "reflect" "time" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) const ( diff --git a/pkg/client/orm/filter_orm_decorator_test.go b/pkg/client/orm/filter_orm_decorator_test.go index 7acb7d46..40a3fa2e 100644 --- a/pkg/client/orm/filter_orm_decorator_test.go +++ b/pkg/client/orm/filter_orm_decorator_test.go @@ -21,7 +21,7 @@ import ( "sync" "testing" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" "github.com/stretchr/testify/assert" ) diff --git a/pkg/client/orm/hints/db_hints.go b/pkg/client/orm/hints/db_hints.go index 7340bd07..c6180529 100644 --- a/pkg/client/orm/hints/db_hints.go +++ b/pkg/client/orm/hints/db_hints.go @@ -15,7 +15,7 @@ package hints import ( - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) const ( diff --git a/pkg/client/orm/migration/ddl.go b/pkg/client/orm/migration/ddl.go index e8b13212..e351b4cc 100644 --- a/pkg/client/orm/migration/ddl.go +++ b/pkg/client/orm/migration/ddl.go @@ -17,7 +17,7 @@ package migration import ( "fmt" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" ) // Index struct defines the structure of Index Columns diff --git a/pkg/client/orm/migration/migration.go b/pkg/client/orm/migration/migration.go index 3f740594..4a56be58 100644 --- a/pkg/client/orm/migration/migration.go +++ b/pkg/client/orm/migration/migration.go @@ -34,7 +34,7 @@ import ( "time" "github.com/astaxie/beego/pkg/client/orm" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" ) // const the data format for the bee generate migration datatype diff --git a/pkg/client/orm/orm.go b/pkg/client/orm/orm.go index bfb710d1..7d1aace0 100644 --- a/pkg/client/orm/orm.go +++ b/pkg/client/orm/orm.go @@ -63,9 +63,9 @@ import ( "time" "github.com/astaxie/beego/pkg/client/orm/hints" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" ) // DebugQueries define the debug diff --git a/pkg/client/orm/types.go b/pkg/client/orm/types.go index b0c793b7..e43bfd2c 100644 --- a/pkg/client/orm/types.go +++ b/pkg/client/orm/types.go @@ -20,7 +20,7 @@ import ( "reflect" "time" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) // TableNaming is usually used by model diff --git a/pkg/infrastructure/bean/context.go b/pkg/core/bean/context.go similarity index 100% rename from pkg/infrastructure/bean/context.go rename to pkg/core/bean/context.go diff --git a/pkg/infrastructure/bean/doc.go b/pkg/core/bean/doc.go similarity index 100% rename from pkg/infrastructure/bean/doc.go rename to pkg/core/bean/doc.go diff --git a/pkg/infrastructure/bean/factory.go b/pkg/core/bean/factory.go similarity index 100% rename from pkg/infrastructure/bean/factory.go rename to pkg/core/bean/factory.go diff --git a/pkg/infrastructure/bean/metadata.go b/pkg/core/bean/metadata.go similarity index 100% rename from pkg/infrastructure/bean/metadata.go rename to pkg/core/bean/metadata.go diff --git a/pkg/infrastructure/bean/tag_auto_wire_bean_factory.go b/pkg/core/bean/tag_auto_wire_bean_factory.go similarity index 99% rename from pkg/infrastructure/bean/tag_auto_wire_bean_factory.go rename to pkg/core/bean/tag_auto_wire_bean_factory.go index f80388f9..595b3a02 100644 --- a/pkg/infrastructure/bean/tag_auto_wire_bean_factory.go +++ b/pkg/core/bean/tag_auto_wire_bean_factory.go @@ -22,7 +22,7 @@ import ( "github.com/pkg/errors" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" ) const DefaultValueTagKey = "default" diff --git a/pkg/infrastructure/bean/tag_auto_wire_bean_factory_test.go b/pkg/core/bean/tag_auto_wire_bean_factory_test.go similarity index 100% rename from pkg/infrastructure/bean/tag_auto_wire_bean_factory_test.go rename to pkg/core/bean/tag_auto_wire_bean_factory_test.go diff --git a/pkg/infrastructure/bean/time_type_adapter.go b/pkg/core/bean/time_type_adapter.go similarity index 100% rename from pkg/infrastructure/bean/time_type_adapter.go rename to pkg/core/bean/time_type_adapter.go diff --git a/pkg/infrastructure/bean/time_type_adapter_test.go b/pkg/core/bean/time_type_adapter_test.go similarity index 100% rename from pkg/infrastructure/bean/time_type_adapter_test.go rename to pkg/core/bean/time_type_adapter_test.go diff --git a/pkg/infrastructure/bean/type_adapter.go b/pkg/core/bean/type_adapter.go similarity index 100% rename from pkg/infrastructure/bean/type_adapter.go rename to pkg/core/bean/type_adapter.go diff --git a/pkg/infrastructure/config/base_config_test.go b/pkg/core/config/base_config_test.go similarity index 100% rename from pkg/infrastructure/config/base_config_test.go rename to pkg/core/config/base_config_test.go diff --git a/pkg/infrastructure/config/config.go b/pkg/core/config/config.go similarity index 100% rename from pkg/infrastructure/config/config.go rename to pkg/core/config/config.go diff --git a/pkg/infrastructure/config/config_test.go b/pkg/core/config/config_test.go similarity index 100% rename from pkg/infrastructure/config/config_test.go rename to pkg/core/config/config_test.go diff --git a/pkg/infrastructure/config/env/env.go b/pkg/core/config/env/env.go similarity index 97% rename from pkg/infrastructure/config/env/env.go rename to pkg/core/config/env/env.go index 83155b34..0cf1582b 100644 --- a/pkg/infrastructure/config/env/env.go +++ b/pkg/core/config/env/env.go @@ -21,7 +21,7 @@ import ( "os" "strings" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) var env *utils.BeeMap diff --git a/pkg/infrastructure/config/env/env_test.go b/pkg/core/config/env/env_test.go similarity index 100% rename from pkg/infrastructure/config/env/env_test.go rename to pkg/core/config/env/env_test.go diff --git a/pkg/infrastructure/config/etcd/config.go b/pkg/core/config/etcd/config.go similarity index 98% rename from pkg/infrastructure/config/etcd/config.go rename to pkg/core/config/etcd/config.go index 94057d73..278cbaa9 100644 --- a/pkg/infrastructure/config/etcd/config.go +++ b/pkg/core/config/etcd/config.go @@ -26,8 +26,8 @@ import ( "github.com/pkg/errors" "google.golang.org/grpc" - "github.com/astaxie/beego/pkg/infrastructure/config" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/config" + "github.com/astaxie/beego/pkg/core/logs" ) const etcdOpts = "etcdOpts" diff --git a/pkg/infrastructure/config/etcd/config_test.go b/pkg/core/config/etcd/config_test.go similarity index 100% rename from pkg/infrastructure/config/etcd/config_test.go rename to pkg/core/config/etcd/config_test.go diff --git a/pkg/infrastructure/config/fake.go b/pkg/core/config/fake.go similarity index 100% rename from pkg/infrastructure/config/fake.go rename to pkg/core/config/fake.go diff --git a/pkg/infrastructure/config/ini.go b/pkg/core/config/ini.go similarity index 100% rename from pkg/infrastructure/config/ini.go rename to pkg/core/config/ini.go diff --git a/pkg/infrastructure/config/ini_test.go b/pkg/core/config/ini_test.go similarity index 100% rename from pkg/infrastructure/config/ini_test.go rename to pkg/core/config/ini_test.go diff --git a/pkg/infrastructure/config/json/json.go b/pkg/core/config/json/json.go similarity index 98% rename from pkg/infrastructure/config/json/json.go rename to pkg/core/config/json/json.go index c65eff4d..66546d89 100644 --- a/pkg/infrastructure/config/json/json.go +++ b/pkg/core/config/json/json.go @@ -27,8 +27,8 @@ import ( "github.com/mitchellh/mapstructure" - "github.com/astaxie/beego/pkg/infrastructure/config" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/config" + "github.com/astaxie/beego/pkg/core/logs" ) // JSONConfig is a json config parser and implements Config interface. diff --git a/pkg/infrastructure/config/json/json_test.go b/pkg/core/config/json/json_test.go similarity index 99% rename from pkg/infrastructure/config/json/json_test.go rename to pkg/core/config/json/json_test.go index 5275ee57..d4601e39 100644 --- a/pkg/infrastructure/config/json/json_test.go +++ b/pkg/core/config/json/json_test.go @@ -22,7 +22,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/core/config" ) func TestJsonStartsWithArray(t *testing.T) { diff --git a/pkg/infrastructure/config/xml/xml.go b/pkg/core/config/xml/xml.go similarity index 98% rename from pkg/infrastructure/config/xml/xml.go rename to pkg/core/config/xml/xml.go index e5096b9b..47d4b8c8 100644 --- a/pkg/infrastructure/config/xml/xml.go +++ b/pkg/core/config/xml/xml.go @@ -42,8 +42,8 @@ import ( "github.com/mitchellh/mapstructure" - "github.com/astaxie/beego/pkg/infrastructure/config" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/config" + "github.com/astaxie/beego/pkg/core/logs" "github.com/beego/x2j" ) diff --git a/pkg/infrastructure/config/xml/xml_test.go b/pkg/core/config/xml/xml_test.go similarity index 98% rename from pkg/infrastructure/config/xml/xml_test.go rename to pkg/core/config/xml/xml_test.go index 0a3eb313..b110f813 100644 --- a/pkg/infrastructure/config/xml/xml_test.go +++ b/pkg/core/config/xml/xml_test.go @@ -22,7 +22,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/core/config" ) func TestXML(t *testing.T) { diff --git a/pkg/infrastructure/config/yaml/yaml.go b/pkg/core/config/yaml/yaml.go similarity index 98% rename from pkg/infrastructure/config/yaml/yaml.go rename to pkg/core/config/yaml/yaml.go index 61ea45b9..5a4bfad0 100644 --- a/pkg/infrastructure/config/yaml/yaml.go +++ b/pkg/core/config/yaml/yaml.go @@ -44,8 +44,8 @@ import ( "github.com/beego/goyaml2" "gopkg.in/yaml.v2" - "github.com/astaxie/beego/pkg/infrastructure/config" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/config" + "github.com/astaxie/beego/pkg/core/logs" ) // Config is a yaml config parser and implements Config interface. diff --git a/pkg/infrastructure/config/yaml/yaml_test.go b/pkg/core/config/yaml/yaml_test.go similarity index 98% rename from pkg/infrastructure/config/yaml/yaml_test.go rename to pkg/core/config/yaml/yaml_test.go index 1fd4e894..130ce6a2 100644 --- a/pkg/infrastructure/config/yaml/yaml_test.go +++ b/pkg/core/config/yaml/yaml_test.go @@ -22,7 +22,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/core/config" ) func TestYaml(t *testing.T) { diff --git a/pkg/infrastructure/governor/command.go b/pkg/core/governor/command.go similarity index 100% rename from pkg/infrastructure/governor/command.go rename to pkg/core/governor/command.go diff --git a/pkg/infrastructure/governor/healthcheck.go b/pkg/core/governor/healthcheck.go similarity index 100% rename from pkg/infrastructure/governor/healthcheck.go rename to pkg/core/governor/healthcheck.go diff --git a/pkg/infrastructure/governor/profile.go b/pkg/core/governor/profile.go similarity index 98% rename from pkg/infrastructure/governor/profile.go rename to pkg/core/governor/profile.go index c40cf6ba..de6e1995 100644 --- a/pkg/infrastructure/governor/profile.go +++ b/pkg/core/governor/profile.go @@ -26,7 +26,7 @@ import ( "strconv" "time" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) var startTime = time.Now() diff --git a/pkg/infrastructure/governor/profile_test.go b/pkg/core/governor/profile_test.go similarity index 100% rename from pkg/infrastructure/governor/profile_test.go rename to pkg/core/governor/profile_test.go diff --git a/pkg/infrastructure/logs/README.md b/pkg/core/logs/README.md similarity index 100% rename from pkg/infrastructure/logs/README.md rename to pkg/core/logs/README.md diff --git a/pkg/infrastructure/logs/access_log.go b/pkg/core/logs/access_log.go similarity index 100% rename from pkg/infrastructure/logs/access_log.go rename to pkg/core/logs/access_log.go diff --git a/pkg/infrastructure/logs/access_log_test.go b/pkg/core/logs/access_log_test.go similarity index 100% rename from pkg/infrastructure/logs/access_log_test.go rename to pkg/core/logs/access_log_test.go diff --git a/pkg/infrastructure/logs/alils/alils.go b/pkg/core/logs/alils/alils.go similarity index 98% rename from pkg/infrastructure/logs/alils/alils.go rename to pkg/core/logs/alils/alils.go index 0689aae0..812b1b3b 100644 --- a/pkg/infrastructure/logs/alils/alils.go +++ b/pkg/core/logs/alils/alils.go @@ -9,7 +9,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/pkg/errors" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" ) const ( diff --git a/pkg/infrastructure/logs/alils/config.go b/pkg/core/logs/alils/config.go similarity index 100% rename from pkg/infrastructure/logs/alils/config.go rename to pkg/core/logs/alils/config.go diff --git a/pkg/infrastructure/logs/alils/log.pb.go b/pkg/core/logs/alils/log.pb.go similarity index 100% rename from pkg/infrastructure/logs/alils/log.pb.go rename to pkg/core/logs/alils/log.pb.go diff --git a/pkg/infrastructure/logs/alils/log_config.go b/pkg/core/logs/alils/log_config.go similarity index 100% rename from pkg/infrastructure/logs/alils/log_config.go rename to pkg/core/logs/alils/log_config.go diff --git a/pkg/infrastructure/logs/alils/log_project.go b/pkg/core/logs/alils/log_project.go similarity index 100% rename from pkg/infrastructure/logs/alils/log_project.go rename to pkg/core/logs/alils/log_project.go diff --git a/pkg/infrastructure/logs/alils/log_store.go b/pkg/core/logs/alils/log_store.go similarity index 100% rename from pkg/infrastructure/logs/alils/log_store.go rename to pkg/core/logs/alils/log_store.go diff --git a/pkg/infrastructure/logs/alils/machine_group.go b/pkg/core/logs/alils/machine_group.go similarity index 100% rename from pkg/infrastructure/logs/alils/machine_group.go rename to pkg/core/logs/alils/machine_group.go diff --git a/pkg/infrastructure/logs/alils/request.go b/pkg/core/logs/alils/request.go similarity index 100% rename from pkg/infrastructure/logs/alils/request.go rename to pkg/core/logs/alils/request.go diff --git a/pkg/infrastructure/logs/alils/signature.go b/pkg/core/logs/alils/signature.go similarity index 100% rename from pkg/infrastructure/logs/alils/signature.go rename to pkg/core/logs/alils/signature.go diff --git a/pkg/infrastructure/logs/conn.go b/pkg/core/logs/conn.go similarity index 100% rename from pkg/infrastructure/logs/conn.go rename to pkg/core/logs/conn.go diff --git a/pkg/infrastructure/logs/conn_test.go b/pkg/core/logs/conn_test.go similarity index 100% rename from pkg/infrastructure/logs/conn_test.go rename to pkg/core/logs/conn_test.go diff --git a/pkg/infrastructure/logs/console.go b/pkg/core/logs/console.go similarity index 100% rename from pkg/infrastructure/logs/console.go rename to pkg/core/logs/console.go diff --git a/pkg/infrastructure/logs/console_test.go b/pkg/core/logs/console_test.go similarity index 100% rename from pkg/infrastructure/logs/console_test.go rename to pkg/core/logs/console_test.go diff --git a/pkg/infrastructure/logs/es/es.go b/pkg/core/logs/es/es.go similarity index 97% rename from pkg/infrastructure/logs/es/es.go rename to pkg/core/logs/es/es.go index c4090eab..a150c7b3 100644 --- a/pkg/infrastructure/logs/es/es.go +++ b/pkg/core/logs/es/es.go @@ -12,7 +12,7 @@ import ( "github.com/elastic/go-elasticsearch/v6" "github.com/elastic/go-elasticsearch/v6/esapi" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" ) // NewES returns a LoggerInterface diff --git a/pkg/infrastructure/logs/es/index.go b/pkg/core/logs/es/index.go similarity index 95% rename from pkg/infrastructure/logs/es/index.go rename to pkg/core/logs/es/index.go index 9796987e..5b2d3d59 100644 --- a/pkg/infrastructure/logs/es/index.go +++ b/pkg/core/logs/es/index.go @@ -17,7 +17,7 @@ package es import ( "fmt" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" ) // IndexNaming generate the index name diff --git a/pkg/infrastructure/logs/es/index_test.go b/pkg/core/logs/es/index_test.go similarity index 94% rename from pkg/infrastructure/logs/es/index_test.go rename to pkg/core/logs/es/index_test.go index 4cdf9b02..25cfa5ed 100644 --- a/pkg/infrastructure/logs/es/index_test.go +++ b/pkg/core/logs/es/index_test.go @@ -20,7 +20,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" ) func TestDefaultIndexNaming_IndexName(t *testing.T) { diff --git a/pkg/infrastructure/logs/file.go b/pkg/core/logs/file.go similarity index 100% rename from pkg/infrastructure/logs/file.go rename to pkg/core/logs/file.go diff --git a/pkg/infrastructure/logs/file_test.go b/pkg/core/logs/file_test.go similarity index 100% rename from pkg/infrastructure/logs/file_test.go rename to pkg/core/logs/file_test.go diff --git a/pkg/infrastructure/logs/formatter.go b/pkg/core/logs/formatter.go similarity index 100% rename from pkg/infrastructure/logs/formatter.go rename to pkg/core/logs/formatter.go diff --git a/pkg/infrastructure/logs/formatter_test.go b/pkg/core/logs/formatter_test.go similarity index 100% rename from pkg/infrastructure/logs/formatter_test.go rename to pkg/core/logs/formatter_test.go diff --git a/pkg/infrastructure/logs/jianliao.go b/pkg/core/logs/jianliao.go similarity index 100% rename from pkg/infrastructure/logs/jianliao.go rename to pkg/core/logs/jianliao.go diff --git a/pkg/infrastructure/logs/jianliao_test.go b/pkg/core/logs/jianliao_test.go similarity index 100% rename from pkg/infrastructure/logs/jianliao_test.go rename to pkg/core/logs/jianliao_test.go diff --git a/pkg/infrastructure/logs/log.go b/pkg/core/logs/log.go similarity index 100% rename from pkg/infrastructure/logs/log.go rename to pkg/core/logs/log.go diff --git a/pkg/infrastructure/logs/log_msg.go b/pkg/core/logs/log_msg.go similarity index 100% rename from pkg/infrastructure/logs/log_msg.go rename to pkg/core/logs/log_msg.go diff --git a/pkg/infrastructure/logs/log_msg_test.go b/pkg/core/logs/log_msg_test.go similarity index 100% rename from pkg/infrastructure/logs/log_msg_test.go rename to pkg/core/logs/log_msg_test.go diff --git a/pkg/infrastructure/logs/log_test.go b/pkg/core/logs/log_test.go similarity index 100% rename from pkg/infrastructure/logs/log_test.go rename to pkg/core/logs/log_test.go diff --git a/pkg/infrastructure/logs/logger.go b/pkg/core/logs/logger.go similarity index 100% rename from pkg/infrastructure/logs/logger.go rename to pkg/core/logs/logger.go diff --git a/pkg/infrastructure/logs/logger_test.go b/pkg/core/logs/logger_test.go similarity index 100% rename from pkg/infrastructure/logs/logger_test.go rename to pkg/core/logs/logger_test.go diff --git a/pkg/infrastructure/logs/multifile.go b/pkg/core/logs/multifile.go similarity index 100% rename from pkg/infrastructure/logs/multifile.go rename to pkg/core/logs/multifile.go diff --git a/pkg/infrastructure/logs/multifile_test.go b/pkg/core/logs/multifile_test.go similarity index 100% rename from pkg/infrastructure/logs/multifile_test.go rename to pkg/core/logs/multifile_test.go diff --git a/pkg/infrastructure/logs/slack.go b/pkg/core/logs/slack.go similarity index 100% rename from pkg/infrastructure/logs/slack.go rename to pkg/core/logs/slack.go diff --git a/pkg/infrastructure/logs/smtp.go b/pkg/core/logs/smtp.go similarity index 100% rename from pkg/infrastructure/logs/smtp.go rename to pkg/core/logs/smtp.go diff --git a/pkg/infrastructure/logs/smtp_test.go b/pkg/core/logs/smtp_test.go similarity index 100% rename from pkg/infrastructure/logs/smtp_test.go rename to pkg/core/logs/smtp_test.go diff --git a/pkg/infrastructure/session/README.md b/pkg/core/session/README.md similarity index 100% rename from pkg/infrastructure/session/README.md rename to pkg/core/session/README.md diff --git a/pkg/infrastructure/session/couchbase/sess_couchbase.go b/pkg/core/session/couchbase/sess_couchbase.go similarity index 99% rename from pkg/infrastructure/session/couchbase/sess_couchbase.go rename to pkg/core/session/couchbase/sess_couchbase.go index ddb4be58..97463d70 100644 --- a/pkg/infrastructure/session/couchbase/sess_couchbase.go +++ b/pkg/core/session/couchbase/sess_couchbase.go @@ -40,7 +40,7 @@ import ( couchbase "github.com/couchbase/go-couchbase" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" ) var couchbpder = &Provider{} diff --git a/pkg/infrastructure/session/ledis/ledis_session.go b/pkg/core/session/ledis/ledis_session.go similarity index 98% rename from pkg/infrastructure/session/ledis/ledis_session.go rename to pkg/core/session/ledis/ledis_session.go index 74bf9b65..5791059d 100644 --- a/pkg/infrastructure/session/ledis/ledis_session.go +++ b/pkg/core/session/ledis/ledis_session.go @@ -11,7 +11,7 @@ import ( "github.com/ledisdb/ledisdb/config" "github.com/ledisdb/ledisdb/ledis" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" ) var ( diff --git a/pkg/infrastructure/session/memcache/sess_memcache.go b/pkg/core/session/memcache/sess_memcache.go similarity index 99% rename from pkg/infrastructure/session/memcache/sess_memcache.go rename to pkg/core/session/memcache/sess_memcache.go index 57df2844..d2b5ed49 100644 --- a/pkg/infrastructure/session/memcache/sess_memcache.go +++ b/pkg/core/session/memcache/sess_memcache.go @@ -38,7 +38,7 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" "github.com/bradfitz/gomemcache/memcache" ) diff --git a/pkg/infrastructure/session/mysql/sess_mysql.go b/pkg/core/session/mysql/sess_mysql.go similarity index 99% rename from pkg/infrastructure/session/mysql/sess_mysql.go rename to pkg/core/session/mysql/sess_mysql.go index fe1d69dc..964b0b2e 100644 --- a/pkg/infrastructure/session/mysql/sess_mysql.go +++ b/pkg/core/session/mysql/sess_mysql.go @@ -47,7 +47,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" // import mysql driver _ "github.com/go-sql-driver/mysql" ) diff --git a/pkg/infrastructure/session/postgres/sess_postgresql.go b/pkg/core/session/postgres/sess_postgresql.go similarity index 99% rename from pkg/infrastructure/session/postgres/sess_postgresql.go rename to pkg/core/session/postgres/sess_postgresql.go index 2fadbed0..29223d4e 100644 --- a/pkg/infrastructure/session/postgres/sess_postgresql.go +++ b/pkg/core/session/postgres/sess_postgresql.go @@ -57,7 +57,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" // import postgresql Driver _ "github.com/lib/pq" ) diff --git a/pkg/infrastructure/session/redis/sess_redis.go b/pkg/core/session/redis/sess_redis.go similarity index 99% rename from pkg/infrastructure/session/redis/sess_redis.go rename to pkg/core/session/redis/sess_redis.go index c7bfbcbf..bbd94019 100644 --- a/pkg/infrastructure/session/redis/sess_redis.go +++ b/pkg/core/session/redis/sess_redis.go @@ -40,7 +40,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" "github.com/go-redis/redis/v7" ) diff --git a/pkg/infrastructure/session/redis/sess_redis_test.go b/pkg/core/session/redis/sess_redis_test.go similarity index 97% rename from pkg/infrastructure/session/redis/sess_redis_test.go rename to pkg/core/session/redis/sess_redis_test.go index df77204d..f45d3051 100644 --- a/pkg/infrastructure/session/redis/sess_redis_test.go +++ b/pkg/core/session/redis/sess_redis_test.go @@ -7,7 +7,7 @@ import ( "os" "testing" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" ) func TestRedis(t *testing.T) { diff --git a/pkg/infrastructure/session/redis_cluster/redis_cluster.go b/pkg/core/session/redis_cluster/redis_cluster.go similarity index 99% rename from pkg/infrastructure/session/redis_cluster/redis_cluster.go rename to pkg/core/session/redis_cluster/redis_cluster.go index 95907a5f..42841cb4 100644 --- a/pkg/infrastructure/session/redis_cluster/redis_cluster.go +++ b/pkg/core/session/redis_cluster/redis_cluster.go @@ -40,7 +40,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" rediss "github.com/go-redis/redis/v7" ) diff --git a/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel.go b/pkg/core/session/redis_sentinel/sess_redis_sentinel.go similarity index 99% rename from pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel.go rename to pkg/core/session/redis_sentinel/sess_redis_sentinel.go index 1b9c841b..b07acdc0 100644 --- a/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel.go +++ b/pkg/core/session/redis_sentinel/sess_redis_sentinel.go @@ -40,7 +40,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" "github.com/go-redis/redis/v7" ) diff --git a/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel_test.go b/pkg/core/session/redis_sentinel/sess_redis_sentinel_test.go similarity index 97% rename from pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel_test.go rename to pkg/core/session/redis_sentinel/sess_redis_sentinel_test.go index fcec9806..2cc21a5a 100644 --- a/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/pkg/core/session/redis_sentinel/sess_redis_sentinel_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" ) func TestRedisSentinel(t *testing.T) { diff --git a/pkg/infrastructure/session/sess_cookie.go b/pkg/core/session/sess_cookie.go similarity index 100% rename from pkg/infrastructure/session/sess_cookie.go rename to pkg/core/session/sess_cookie.go diff --git a/pkg/infrastructure/session/sess_cookie_test.go b/pkg/core/session/sess_cookie_test.go similarity index 100% rename from pkg/infrastructure/session/sess_cookie_test.go rename to pkg/core/session/sess_cookie_test.go diff --git a/pkg/infrastructure/session/sess_file.go b/pkg/core/session/sess_file.go similarity index 100% rename from pkg/infrastructure/session/sess_file.go rename to pkg/core/session/sess_file.go diff --git a/pkg/infrastructure/session/sess_file_test.go b/pkg/core/session/sess_file_test.go similarity index 100% rename from pkg/infrastructure/session/sess_file_test.go rename to pkg/core/session/sess_file_test.go diff --git a/pkg/infrastructure/session/sess_mem.go b/pkg/core/session/sess_mem.go similarity index 100% rename from pkg/infrastructure/session/sess_mem.go rename to pkg/core/session/sess_mem.go diff --git a/pkg/infrastructure/session/sess_mem_test.go b/pkg/core/session/sess_mem_test.go similarity index 100% rename from pkg/infrastructure/session/sess_mem_test.go rename to pkg/core/session/sess_mem_test.go diff --git a/pkg/infrastructure/session/sess_test.go b/pkg/core/session/sess_test.go similarity index 100% rename from pkg/infrastructure/session/sess_test.go rename to pkg/core/session/sess_test.go diff --git a/pkg/infrastructure/session/sess_utils.go b/pkg/core/session/sess_utils.go similarity index 99% rename from pkg/infrastructure/session/sess_utils.go rename to pkg/core/session/sess_utils.go index 906e1c4b..5f97d1a4 100644 --- a/pkg/infrastructure/session/sess_utils.go +++ b/pkg/core/session/sess_utils.go @@ -29,7 +29,7 @@ import ( "strconv" "time" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) func init() { diff --git a/pkg/infrastructure/session/session.go b/pkg/core/session/session.go similarity index 100% rename from pkg/infrastructure/session/session.go rename to pkg/core/session/session.go diff --git a/pkg/infrastructure/session/ssdb/sess_ssdb.go b/pkg/core/session/ssdb/sess_ssdb.go similarity index 98% rename from pkg/infrastructure/session/ssdb/sess_ssdb.go rename to pkg/core/session/ssdb/sess_ssdb.go index 6e4f341e..274c6b35 100644 --- a/pkg/infrastructure/session/ssdb/sess_ssdb.go +++ b/pkg/core/session/ssdb/sess_ssdb.go @@ -8,7 +8,7 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" "github.com/ssdb/gossdb/ssdb" ) diff --git a/pkg/infrastructure/utils/caller.go b/pkg/core/utils/caller.go similarity index 100% rename from pkg/infrastructure/utils/caller.go rename to pkg/core/utils/caller.go diff --git a/pkg/infrastructure/utils/caller_test.go b/pkg/core/utils/caller_test.go similarity index 100% rename from pkg/infrastructure/utils/caller_test.go rename to pkg/core/utils/caller_test.go diff --git a/pkg/infrastructure/utils/debug.go b/pkg/core/utils/debug.go similarity index 100% rename from pkg/infrastructure/utils/debug.go rename to pkg/core/utils/debug.go diff --git a/pkg/infrastructure/utils/debug_test.go b/pkg/core/utils/debug_test.go similarity index 100% rename from pkg/infrastructure/utils/debug_test.go rename to pkg/core/utils/debug_test.go diff --git a/pkg/infrastructure/utils/file.go b/pkg/core/utils/file.go similarity index 100% rename from pkg/infrastructure/utils/file.go rename to pkg/core/utils/file.go diff --git a/pkg/infrastructure/utils/file_test.go b/pkg/core/utils/file_test.go similarity index 100% rename from pkg/infrastructure/utils/file_test.go rename to pkg/core/utils/file_test.go diff --git a/pkg/infrastructure/utils/kv.go b/pkg/core/utils/kv.go similarity index 100% rename from pkg/infrastructure/utils/kv.go rename to pkg/core/utils/kv.go diff --git a/pkg/infrastructure/utils/kv_test.go b/pkg/core/utils/kv_test.go similarity index 100% rename from pkg/infrastructure/utils/kv_test.go rename to pkg/core/utils/kv_test.go diff --git a/pkg/infrastructure/utils/mail.go b/pkg/core/utils/mail.go similarity index 100% rename from pkg/infrastructure/utils/mail.go rename to pkg/core/utils/mail.go diff --git a/pkg/infrastructure/utils/mail_test.go b/pkg/core/utils/mail_test.go similarity index 100% rename from pkg/infrastructure/utils/mail_test.go rename to pkg/core/utils/mail_test.go diff --git a/pkg/infrastructure/utils/pagination/doc.go b/pkg/core/utils/pagination/doc.go similarity index 95% rename from pkg/infrastructure/utils/pagination/doc.go rename to pkg/core/utils/pagination/doc.go index 86a2ba5d..fb044ff9 100644 --- a/pkg/infrastructure/utils/pagination/doc.go +++ b/pkg/core/utils/pagination/doc.go @@ -8,7 +8,7 @@ In your beego.Controller: package controllers - import "github.com/astaxie/beego/pkg/infrastructure/utils/pagination" + import "github.com/astaxie/beego/pkg/core/utils/pagination" type PostsController struct { beego.Controller diff --git a/pkg/infrastructure/utils/pagination/paginator.go b/pkg/core/utils/pagination/paginator.go similarity index 100% rename from pkg/infrastructure/utils/pagination/paginator.go rename to pkg/core/utils/pagination/paginator.go diff --git a/pkg/infrastructure/utils/pagination/utils.go b/pkg/core/utils/pagination/utils.go similarity index 100% rename from pkg/infrastructure/utils/pagination/utils.go rename to pkg/core/utils/pagination/utils.go diff --git a/pkg/infrastructure/utils/rand.go b/pkg/core/utils/rand.go similarity index 100% rename from pkg/infrastructure/utils/rand.go rename to pkg/core/utils/rand.go diff --git a/pkg/infrastructure/utils/rand_test.go b/pkg/core/utils/rand_test.go similarity index 100% rename from pkg/infrastructure/utils/rand_test.go rename to pkg/core/utils/rand_test.go diff --git a/pkg/infrastructure/utils/safemap.go b/pkg/core/utils/safemap.go similarity index 100% rename from pkg/infrastructure/utils/safemap.go rename to pkg/core/utils/safemap.go diff --git a/pkg/infrastructure/utils/safemap_test.go b/pkg/core/utils/safemap_test.go similarity index 100% rename from pkg/infrastructure/utils/safemap_test.go rename to pkg/core/utils/safemap_test.go diff --git a/pkg/infrastructure/utils/slice.go b/pkg/core/utils/slice.go similarity index 100% rename from pkg/infrastructure/utils/slice.go rename to pkg/core/utils/slice.go diff --git a/pkg/infrastructure/utils/slice_test.go b/pkg/core/utils/slice_test.go similarity index 100% rename from pkg/infrastructure/utils/slice_test.go rename to pkg/core/utils/slice_test.go diff --git a/pkg/infrastructure/utils/testdata/grepe.test b/pkg/core/utils/testdata/grepe.test similarity index 100% rename from pkg/infrastructure/utils/testdata/grepe.test rename to pkg/core/utils/testdata/grepe.test diff --git a/pkg/infrastructure/utils/time.go b/pkg/core/utils/time.go similarity index 100% rename from pkg/infrastructure/utils/time.go rename to pkg/core/utils/time.go diff --git a/pkg/infrastructure/utils/utils.go b/pkg/core/utils/utils.go similarity index 100% rename from pkg/infrastructure/utils/utils.go rename to pkg/core/utils/utils.go diff --git a/pkg/infrastructure/utils/utils_test.go b/pkg/core/utils/utils_test.go similarity index 100% rename from pkg/infrastructure/utils/utils_test.go rename to pkg/core/utils/utils_test.go diff --git a/pkg/infrastructure/validation/README.md b/pkg/core/validation/README.md similarity index 100% rename from pkg/infrastructure/validation/README.md rename to pkg/core/validation/README.md diff --git a/pkg/infrastructure/validation/util.go b/pkg/core/validation/util.go similarity index 100% rename from pkg/infrastructure/validation/util.go rename to pkg/core/validation/util.go diff --git a/pkg/infrastructure/validation/util_test.go b/pkg/core/validation/util_test.go similarity index 100% rename from pkg/infrastructure/validation/util_test.go rename to pkg/core/validation/util_test.go diff --git a/pkg/infrastructure/validation/validation.go b/pkg/core/validation/validation.go similarity index 100% rename from pkg/infrastructure/validation/validation.go rename to pkg/core/validation/validation.go diff --git a/pkg/infrastructure/validation/validation_test.go b/pkg/core/validation/validation_test.go similarity index 100% rename from pkg/infrastructure/validation/validation_test.go rename to pkg/core/validation/validation_test.go diff --git a/pkg/infrastructure/validation/validators.go b/pkg/core/validation/validators.go similarity index 99% rename from pkg/infrastructure/validation/validators.go rename to pkg/core/validation/validators.go index 94152b89..1652ee2c 100644 --- a/pkg/infrastructure/validation/validators.go +++ b/pkg/core/validation/validators.go @@ -23,7 +23,7 @@ import ( "time" "unicode/utf8" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" ) // CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty diff --git a/pkg/server/web/admin.go b/pkg/server/web/admin.go index 084190a9..4cac58ba 100644 --- a/pkg/server/web/admin.go +++ b/pkg/server/web/admin.go @@ -20,7 +20,7 @@ import ( "reflect" "time" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" ) // BeeAdminApp is the default adminApp used by admin module. diff --git a/pkg/server/web/admin_controller.go b/pkg/server/web/admin_controller.go index dc3a40b5..575362d7 100644 --- a/pkg/server/web/admin_controller.go +++ b/pkg/server/web/admin_controller.go @@ -24,7 +24,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/astaxie/beego/pkg/infrastructure/governor" + "github.com/astaxie/beego/pkg/core/governor" ) type adminController struct { diff --git a/pkg/server/web/admin_test.go b/pkg/server/web/admin_test.go index d04ac319..c33bbf2f 100644 --- a/pkg/server/web/admin_test.go +++ b/pkg/server/web/admin_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/infrastructure/governor" + "github.com/astaxie/beego/pkg/core/governor" ) type SampleDatabaseCheck struct { diff --git a/pkg/server/web/captcha/captcha.go b/pkg/server/web/captcha/captcha.go index 36bc0fcb..876e6074 100644 --- a/pkg/server/web/captcha/captcha.go +++ b/pkg/server/web/captcha/captcha.go @@ -67,9 +67,9 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" "github.com/astaxie/beego/pkg/server/web" "github.com/astaxie/beego/pkg/server/web/context" ) diff --git a/pkg/server/web/captcha/image_test.go b/pkg/server/web/captcha/image_test.go index 36cba386..a6b82f56 100644 --- a/pkg/server/web/captcha/image_test.go +++ b/pkg/server/web/captcha/image_test.go @@ -17,7 +17,7 @@ package captcha import ( "testing" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) type byteCounter struct { diff --git a/pkg/server/web/config.go b/pkg/server/web/config.go index bc46b20e..443dfcb8 100644 --- a/pkg/server/web/config.go +++ b/pkg/server/web/config.go @@ -25,11 +25,11 @@ import ( "strings" "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/infrastructure/config" - "github.com/astaxie/beego/pkg/infrastructure/logs" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/config" + "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/pkg/core/session" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" "github.com/astaxie/beego/pkg/server/web/context" ) diff --git a/pkg/server/web/config_test.go b/pkg/server/web/config_test.go index 4961d3a9..ce4a4492 100644 --- a/pkg/server/web/config_test.go +++ b/pkg/server/web/config_test.go @@ -19,7 +19,7 @@ import ( "reflect" "testing" - beeJson "github.com/astaxie/beego/pkg/infrastructure/config/json" + beeJson "github.com/astaxie/beego/pkg/core/config/json" ) func TestDefaults(t *testing.T) { diff --git a/pkg/server/web/context/context.go b/pkg/server/web/context/context.go index 78e0a6d6..1a6c00a8 100644 --- a/pkg/server/web/context/context.go +++ b/pkg/server/web/context/context.go @@ -35,7 +35,7 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) // Commonly used mime-types diff --git a/pkg/server/web/context/input.go b/pkg/server/web/context/input.go index f8657f84..499b61dc 100644 --- a/pkg/server/web/context/input.go +++ b/pkg/server/web/context/input.go @@ -29,7 +29,7 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" ) // Regexes for checking the accept headers diff --git a/pkg/server/web/context/param/conv.go b/pkg/server/web/context/param/conv.go index a96dacdd..73e83e30 100644 --- a/pkg/server/web/context/param/conv.go +++ b/pkg/server/web/context/param/conv.go @@ -4,7 +4,7 @@ import ( "fmt" "reflect" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" beecontext "github.com/astaxie/beego/pkg/server/web/context" ) diff --git a/pkg/server/web/controller.go b/pkg/server/web/controller.go index 2081e647..547c3271 100644 --- a/pkg/server/web/controller.go +++ b/pkg/server/web/controller.go @@ -28,7 +28,7 @@ import ( "strconv" "strings" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/session" "github.com/astaxie/beego/pkg/server/web/context" "github.com/astaxie/beego/pkg/server/web/context/param" diff --git a/pkg/server/web/error.go b/pkg/server/web/error.go index a005c110..d0a8d778 100644 --- a/pkg/server/web/error.go +++ b/pkg/server/web/error.go @@ -24,7 +24,7 @@ import ( "strings" "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" "github.com/astaxie/beego/pkg/server/web/context" ) diff --git a/pkg/server/web/hooks.go b/pkg/server/web/hooks.go index 2f0cb159..15969168 100644 --- a/pkg/server/web/hooks.go +++ b/pkg/server/web/hooks.go @@ -7,8 +7,8 @@ import ( "net/http" "path/filepath" - "github.com/astaxie/beego/pkg/infrastructure/logs" - "github.com/astaxie/beego/pkg/infrastructure/session" + "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/pkg/core/session" "github.com/astaxie/beego/pkg/server/web/context" ) diff --git a/pkg/server/web/pagination/controller.go b/pkg/server/web/pagination/controller.go index 530a72ff..675437f8 100644 --- a/pkg/server/web/pagination/controller.go +++ b/pkg/server/web/pagination/controller.go @@ -15,7 +15,7 @@ package pagination import ( - "github.com/astaxie/beego/pkg/infrastructure/utils/pagination" + "github.com/astaxie/beego/pkg/core/utils/pagination" "github.com/astaxie/beego/pkg/server/web/context" ) diff --git a/pkg/server/web/parser.go b/pkg/server/web/parser.go index a4507010..9dfeca56 100644 --- a/pkg/server/web/parser.go +++ b/pkg/server/web/parser.go @@ -31,9 +31,9 @@ import ( "golang.org/x/tools/go/packages" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" "github.com/astaxie/beego/pkg/server/web/context/param" ) diff --git a/pkg/server/web/router.go b/pkg/server/web/router.go index a9d1b0cf..07007a2b 100644 --- a/pkg/server/web/router.go +++ b/pkg/server/web/router.go @@ -25,9 +25,9 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" beecontext "github.com/astaxie/beego/pkg/server/web/context" "github.com/astaxie/beego/pkg/server/web/context/param" ) diff --git a/pkg/server/web/router_test.go b/pkg/server/web/router_test.go index a59cde8b..2bc7990c 100644 --- a/pkg/server/web/router_test.go +++ b/pkg/server/web/router_test.go @@ -21,7 +21,7 @@ import ( "strings" "testing" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" "github.com/astaxie/beego/pkg/server/web/context" ) diff --git a/pkg/server/web/server.go b/pkg/server/web/server.go index 7bd9023d..75523d0c 100644 --- a/pkg/server/web/server.go +++ b/pkg/server/web/server.go @@ -31,10 +31,10 @@ import ( "golang.org/x/crypto/acme/autocert" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" beecontext "github.com/astaxie/beego/pkg/server/web/context" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" "github.com/astaxie/beego/pkg/server/web/grace" ) diff --git a/pkg/server/web/staticfile.go b/pkg/server/web/staticfile.go index 7b9942f4..4aabbc60 100644 --- a/pkg/server/web/staticfile.go +++ b/pkg/server/web/staticfile.go @@ -26,7 +26,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/core/logs" lru "github.com/hashicorp/golang-lru" "github.com/astaxie/beego/pkg/server/web/context" diff --git a/pkg/server/web/statistics.go b/pkg/server/web/statistics.go index ccc3a1fc..7d5d5800 100644 --- a/pkg/server/web/statistics.go +++ b/pkg/server/web/statistics.go @@ -19,7 +19,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" ) // Statistics struct diff --git a/pkg/server/web/template.go b/pkg/server/web/template.go index 1192a3f2..4a07b0da 100644 --- a/pkg/server/web/template.go +++ b/pkg/server/web/template.go @@ -27,8 +27,8 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/infrastructure/logs" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/pkg/core/utils" ) var ( diff --git a/pkg/server/web/tree.go b/pkg/server/web/tree.go index 55f68076..9038c010 100644 --- a/pkg/server/web/tree.go +++ b/pkg/server/web/tree.go @@ -19,7 +19,7 @@ import ( "regexp" "strings" - "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/core/utils" "github.com/astaxie/beego/pkg/server/web/context" ) diff --git a/pkg/task/govenor_command.go b/pkg/task/govenor_command.go index be351dc1..a9583970 100644 --- a/pkg/task/govenor_command.go +++ b/pkg/task/govenor_command.go @@ -21,7 +21,7 @@ import ( "github.com/pkg/errors" - "github.com/astaxie/beego/pkg/infrastructure/governor" + "github.com/astaxie/beego/pkg/core/governor" ) type listTaskCommand struct { diff --git a/pkg/task/task.go b/pkg/task/task.go index e76706e3..8f25a0f3 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -39,7 +39,7 @@ type taskManager struct { started bool } -func newTaskManager()*taskManager{ +func newTaskManager() *taskManager { return &taskManager{ adminTaskList: make(map[string]Tasker), taskLock: sync.RWMutex{}, @@ -53,11 +53,11 @@ func newTaskManager()*taskManager{ var ( globalTaskManager *taskManager - seconds = bounds{0, 59, nil} - minutes = bounds{0, 59, nil} - hours = bounds{0, 23, nil} - days = bounds{1, 31, nil} - months = bounds{1, 12, map[string]uint{ + seconds = bounds{0, 59, nil} + minutes = bounds{0, 59, nil} + hours = bounds{0, 23, nil} + days = bounds{1, 31, nil} + months = bounds{1, 12, map[string]uint{ "jan": 1, "feb": 2, "mar": 3, @@ -436,7 +436,6 @@ func ClearTask() { globalTaskManager.ClearTask() } - // StartTask start all tasks func (m *taskManager) StartTask() { m.taskLock.Lock() @@ -451,7 +450,7 @@ func (m *taskManager) StartTask() { go m.run() } -func(m *taskManager) run() { +func (m *taskManager) run() { now := time.Now().Local() for _, t := range m.adminTaskList { t.SetNext(nil, now) @@ -503,14 +502,14 @@ func(m *taskManager) run() { } // StopTask stop all tasks -func(m *taskManager) StopTask() { +func (m *taskManager) StopTask() { go func() { m.stop <- true }() } // AddTask add task with name -func (m *taskManager)AddTask(taskname string, t Tasker) { +func (m *taskManager) AddTask(taskname string, t Tasker) { isChanged := false m.taskLock.Lock() t.SetNext(nil, time.Now().Local()) @@ -529,7 +528,7 @@ func (m *taskManager)AddTask(taskname string, t Tasker) { } // DeleteTask delete task with name -func(m *taskManager) DeleteTask(taskname string) { +func (m *taskManager) DeleteTask(taskname string) { isChanged := false m.taskLock.Lock() @@ -547,7 +546,7 @@ func(m *taskManager) DeleteTask(taskname string) { } // ClearTask clear all tasks -func(m *taskManager) ClearTask() { +func (m *taskManager) ClearTask() { isChanged := false m.taskLock.Lock() From d8e8f412305b22cb8314f379fe55ef1852ccdecb Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 5 Oct 2020 19:04:22 +0800 Subject: [PATCH 187/207] move core/session to web/session --- pkg/adapter/session/couchbase/sess_couchbase.go | 2 +- pkg/adapter/session/ledis/ledis_session.go | 2 +- pkg/adapter/session/memcache/sess_memcache.go | 2 +- pkg/adapter/session/mysql/sess_mysql.go | 2 +- pkg/adapter/session/postgres/sess_postgresql.go | 2 +- pkg/adapter/session/provider_adapter.go | 2 +- pkg/adapter/session/redis/sess_redis.go | 2 +- pkg/adapter/session/redis_cluster/redis_cluster.go | 2 +- pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go | 2 +- pkg/adapter/session/sess_cookie.go | 2 +- pkg/adapter/session/sess_file.go | 2 +- pkg/adapter/session/sess_mem.go | 2 +- pkg/adapter/session/sess_utils.go | 2 +- pkg/adapter/session/session.go | 2 +- pkg/adapter/session/ssdb/sess_ssdb.go | 2 +- pkg/adapter/session/store_adapter.go | 2 +- pkg/server/web/config.go | 2 +- pkg/server/web/context/input.go | 2 +- pkg/server/web/controller.go | 2 +- pkg/server/web/hooks.go | 2 +- pkg/{core => server/web}/session/README.md | 0 pkg/{core => server/web}/session/couchbase/sess_couchbase.go | 2 +- pkg/{core => server/web}/session/ledis/ledis_session.go | 2 +- pkg/{core => server/web}/session/memcache/sess_memcache.go | 2 +- pkg/{core => server/web}/session/mysql/sess_mysql.go | 2 +- pkg/{core => server/web}/session/postgres/sess_postgresql.go | 2 +- pkg/{core => server/web}/session/redis/sess_redis.go | 2 +- pkg/{core => server/web}/session/redis/sess_redis_test.go | 2 +- .../web}/session/redis_cluster/redis_cluster.go | 2 +- .../web}/session/redis_sentinel/sess_redis_sentinel.go | 2 +- .../web}/session/redis_sentinel/sess_redis_sentinel_test.go | 2 +- pkg/{core => server/web}/session/sess_cookie.go | 0 pkg/{core => server/web}/session/sess_cookie_test.go | 0 pkg/{core => server/web}/session/sess_file.go | 0 pkg/{core => server/web}/session/sess_file_test.go | 0 pkg/{core => server/web}/session/sess_mem.go | 0 pkg/{core => server/web}/session/sess_mem_test.go | 0 pkg/{core => server/web}/session/sess_test.go | 0 pkg/{core => server/web}/session/sess_utils.go | 0 pkg/{core => server/web}/session/session.go | 0 pkg/{core => server/web}/session/ssdb/sess_ssdb.go | 5 +++-- 41 files changed, 33 insertions(+), 32 deletions(-) rename pkg/{core => server/web}/session/README.md (100%) rename pkg/{core => server/web}/session/couchbase/sess_couchbase.go (99%) rename pkg/{core => server/web}/session/ledis/ledis_session.go (98%) rename pkg/{core => server/web}/session/memcache/sess_memcache.go (99%) rename pkg/{core => server/web}/session/mysql/sess_mysql.go (99%) rename pkg/{core => server/web}/session/postgres/sess_postgresql.go (99%) rename pkg/{core => server/web}/session/redis/sess_redis.go (99%) rename pkg/{core => server/web}/session/redis/sess_redis_test.go (97%) rename pkg/{core => server/web}/session/redis_cluster/redis_cluster.go (99%) rename pkg/{core => server/web}/session/redis_sentinel/sess_redis_sentinel.go (99%) rename pkg/{core => server/web}/session/redis_sentinel/sess_redis_sentinel_test.go (97%) rename pkg/{core => server/web}/session/sess_cookie.go (100%) rename pkg/{core => server/web}/session/sess_cookie_test.go (100%) rename pkg/{core => server/web}/session/sess_file.go (100%) rename pkg/{core => server/web}/session/sess_file_test.go (100%) rename pkg/{core => server/web}/session/sess_mem.go (100%) rename pkg/{core => server/web}/session/sess_mem_test.go (100%) rename pkg/{core => server/web}/session/sess_test.go (100%) rename pkg/{core => server/web}/session/sess_utils.go (100%) rename pkg/{core => server/web}/session/session.go (100%) rename pkg/{core => server/web}/session/ssdb/sess_ssdb.go (98%) diff --git a/pkg/adapter/session/couchbase/sess_couchbase.go b/pkg/adapter/session/couchbase/sess_couchbase.go index aa3bc724..2903dae5 100644 --- a/pkg/adapter/session/couchbase/sess_couchbase.go +++ b/pkg/adapter/session/couchbase/sess_couchbase.go @@ -37,7 +37,7 @@ import ( "net/http" "github.com/astaxie/beego/pkg/adapter/session" - beecb "github.com/astaxie/beego/pkg/core/session/couchbase" + beecb "github.com/astaxie/beego/pkg/server/web/session/couchbase" ) // SessionStore store each session diff --git a/pkg/adapter/session/ledis/ledis_session.go b/pkg/adapter/session/ledis/ledis_session.go index db47b375..92d3c96a 100644 --- a/pkg/adapter/session/ledis/ledis_session.go +++ b/pkg/adapter/session/ledis/ledis_session.go @@ -6,7 +6,7 @@ import ( "net/http" "github.com/astaxie/beego/pkg/adapter/session" - beeLedis "github.com/astaxie/beego/pkg/core/session/ledis" + beeLedis "github.com/astaxie/beego/pkg/server/web/session/ledis" ) // SessionStore ledis session store diff --git a/pkg/adapter/session/memcache/sess_memcache.go b/pkg/adapter/session/memcache/sess_memcache.go index 9f39cf5c..ae985881 100644 --- a/pkg/adapter/session/memcache/sess_memcache.go +++ b/pkg/adapter/session/memcache/sess_memcache.go @@ -38,7 +38,7 @@ import ( "github.com/astaxie/beego/pkg/adapter/session" - beemem "github.com/astaxie/beego/pkg/core/session/memcache" + beemem "github.com/astaxie/beego/pkg/server/web/session/memcache" ) // SessionStore memcache session store diff --git a/pkg/adapter/session/mysql/sess_mysql.go b/pkg/adapter/session/mysql/sess_mysql.go index 550556c8..73113792 100644 --- a/pkg/adapter/session/mysql/sess_mysql.go +++ b/pkg/adapter/session/mysql/sess_mysql.go @@ -45,7 +45,7 @@ import ( "net/http" "github.com/astaxie/beego/pkg/adapter/session" - "github.com/astaxie/beego/pkg/core/session/mysql" + "github.com/astaxie/beego/pkg/server/web/session/mysql" // import mysql driver _ "github.com/go-sql-driver/mysql" diff --git a/pkg/adapter/session/postgres/sess_postgresql.go b/pkg/adapter/session/postgres/sess_postgresql.go index 76361533..21a360e8 100644 --- a/pkg/adapter/session/postgres/sess_postgresql.go +++ b/pkg/adapter/session/postgres/sess_postgresql.go @@ -58,7 +58,7 @@ import ( // import postgresql Driver _ "github.com/lib/pq" - "github.com/astaxie/beego/pkg/core/session/postgres" + "github.com/astaxie/beego/pkg/server/web/session/postgres" ) // SessionStore postgresql session store diff --git a/pkg/adapter/session/provider_adapter.go b/pkg/adapter/session/provider_adapter.go index 259e998c..84cb4c85 100644 --- a/pkg/adapter/session/provider_adapter.go +++ b/pkg/adapter/session/provider_adapter.go @@ -17,7 +17,7 @@ package session import ( "context" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" ) type oldToNewProviderAdapter struct { diff --git a/pkg/adapter/session/redis/sess_redis.go b/pkg/adapter/session/redis/sess_redis.go index d4a17b84..47c9d437 100644 --- a/pkg/adapter/session/redis/sess_redis.go +++ b/pkg/adapter/session/redis/sess_redis.go @@ -38,7 +38,7 @@ import ( "github.com/astaxie/beego/pkg/adapter/session" - beeRedis "github.com/astaxie/beego/pkg/core/session/redis" + beeRedis "github.com/astaxie/beego/pkg/server/web/session/redis" ) // MaxPoolSize redis max pool size diff --git a/pkg/adapter/session/redis_cluster/redis_cluster.go b/pkg/adapter/session/redis_cluster/redis_cluster.go index 325efa25..b741b9ff 100644 --- a/pkg/adapter/session/redis_cluster/redis_cluster.go +++ b/pkg/adapter/session/redis_cluster/redis_cluster.go @@ -37,7 +37,7 @@ import ( "net/http" "github.com/astaxie/beego/pkg/adapter/session" - cluster "github.com/astaxie/beego/pkg/core/session/redis_cluster" + cluster "github.com/astaxie/beego/pkg/server/web/session/redis_cluster" ) // MaxPoolSize redis_cluster max pool size diff --git a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go index 0306400d..99ff7898 100644 --- a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go +++ b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go @@ -38,7 +38,7 @@ import ( "github.com/astaxie/beego/pkg/adapter/session" - sentinel "github.com/astaxie/beego/pkg/core/session/redis_sentinel" + sentinel "github.com/astaxie/beego/pkg/server/web/session/redis_sentinel" ) // DefaultPoolSize redis_sentinel default pool size diff --git a/pkg/adapter/session/sess_cookie.go b/pkg/adapter/session/sess_cookie.go index f28b0620..8c6c1dc7 100644 --- a/pkg/adapter/session/sess_cookie.go +++ b/pkg/adapter/session/sess_cookie.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" ) // CookieSessionStore Cookie SessionStore diff --git a/pkg/adapter/session/sess_file.go b/pkg/adapter/session/sess_file.go index 5aa5bc1e..870b62a6 100644 --- a/pkg/adapter/session/sess_file.go +++ b/pkg/adapter/session/sess_file.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" ) // FileSessionStore File session store diff --git a/pkg/adapter/session/sess_mem.go b/pkg/adapter/session/sess_mem.go index ac37d5d3..faaab548 100644 --- a/pkg/adapter/session/sess_mem.go +++ b/pkg/adapter/session/sess_mem.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" ) // MemSessionStore memory session store. diff --git a/pkg/adapter/session/sess_utils.go b/pkg/adapter/session/sess_utils.go index b5cbc5a1..8cf036e4 100644 --- a/pkg/adapter/session/sess_utils.go +++ b/pkg/adapter/session/sess_utils.go @@ -15,7 +15,7 @@ package session import ( - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" ) // EncodeGob encode the obj to gob diff --git a/pkg/adapter/session/session.go b/pkg/adapter/session/session.go index 7612854d..24f587b6 100644 --- a/pkg/adapter/session/session.go +++ b/pkg/adapter/session/session.go @@ -32,7 +32,7 @@ import ( "net/http" "os" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" ) // Store contains all data for one session process with specific id. diff --git a/pkg/adapter/session/ssdb/sess_ssdb.go b/pkg/adapter/session/ssdb/sess_ssdb.go index 03c11d61..3f2d08d9 100644 --- a/pkg/adapter/session/ssdb/sess_ssdb.go +++ b/pkg/adapter/session/ssdb/sess_ssdb.go @@ -6,7 +6,7 @@ import ( "github.com/astaxie/beego/pkg/adapter/session" - beeSsdb "github.com/astaxie/beego/pkg/core/session/ssdb" + beeSsdb "github.com/astaxie/beego/pkg/server/web/session/ssdb" ) // Provider holds ssdb client and configs diff --git a/pkg/adapter/session/store_adapter.go b/pkg/adapter/session/store_adapter.go index b8a23937..c0de6ac3 100644 --- a/pkg/adapter/session/store_adapter.go +++ b/pkg/adapter/session/store_adapter.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" ) type NewToOldStoreAdapter struct { diff --git a/pkg/server/web/config.go b/pkg/server/web/config.go index 443dfcb8..9e2a2885 100644 --- a/pkg/server/web/config.go +++ b/pkg/server/web/config.go @@ -27,7 +27,7 @@ import ( "github.com/astaxie/beego/pkg" "github.com/astaxie/beego/pkg/core/config" "github.com/astaxie/beego/pkg/core/logs" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" "github.com/astaxie/beego/pkg/core/utils" "github.com/astaxie/beego/pkg/server/web/context" diff --git a/pkg/server/web/context/input.go b/pkg/server/web/context/input.go index 499b61dc..746822aa 100644 --- a/pkg/server/web/context/input.go +++ b/pkg/server/web/context/input.go @@ -29,7 +29,7 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" ) // Regexes for checking the accept headers diff --git a/pkg/server/web/controller.go b/pkg/server/web/controller.go index 547c3271..a8e2ae63 100644 --- a/pkg/server/web/controller.go +++ b/pkg/server/web/controller.go @@ -28,7 +28,7 @@ import ( "strconv" "strings" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" "github.com/astaxie/beego/pkg/server/web/context" "github.com/astaxie/beego/pkg/server/web/context/param" diff --git a/pkg/server/web/hooks.go b/pkg/server/web/hooks.go index 15969168..d7c6cf16 100644 --- a/pkg/server/web/hooks.go +++ b/pkg/server/web/hooks.go @@ -8,8 +8,8 @@ import ( "path/filepath" "github.com/astaxie/beego/pkg/core/logs" - "github.com/astaxie/beego/pkg/core/session" "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/session" ) // register MIME type with content type diff --git a/pkg/core/session/README.md b/pkg/server/web/session/README.md similarity index 100% rename from pkg/core/session/README.md rename to pkg/server/web/session/README.md diff --git a/pkg/core/session/couchbase/sess_couchbase.go b/pkg/server/web/session/couchbase/sess_couchbase.go similarity index 99% rename from pkg/core/session/couchbase/sess_couchbase.go rename to pkg/server/web/session/couchbase/sess_couchbase.go index 97463d70..dc616539 100644 --- a/pkg/core/session/couchbase/sess_couchbase.go +++ b/pkg/server/web/session/couchbase/sess_couchbase.go @@ -40,7 +40,7 @@ import ( couchbase "github.com/couchbase/go-couchbase" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" ) var couchbpder = &Provider{} diff --git a/pkg/core/session/ledis/ledis_session.go b/pkg/server/web/session/ledis/ledis_session.go similarity index 98% rename from pkg/core/session/ledis/ledis_session.go rename to pkg/server/web/session/ledis/ledis_session.go index 5791059d..e4061a39 100644 --- a/pkg/core/session/ledis/ledis_session.go +++ b/pkg/server/web/session/ledis/ledis_session.go @@ -11,7 +11,7 @@ import ( "github.com/ledisdb/ledisdb/config" "github.com/ledisdb/ledisdb/ledis" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" ) var ( diff --git a/pkg/core/session/memcache/sess_memcache.go b/pkg/server/web/session/memcache/sess_memcache.go similarity index 99% rename from pkg/core/session/memcache/sess_memcache.go rename to pkg/server/web/session/memcache/sess_memcache.go index d2b5ed49..731df01e 100644 --- a/pkg/core/session/memcache/sess_memcache.go +++ b/pkg/server/web/session/memcache/sess_memcache.go @@ -38,7 +38,7 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" "github.com/bradfitz/gomemcache/memcache" ) diff --git a/pkg/core/session/mysql/sess_mysql.go b/pkg/server/web/session/mysql/sess_mysql.go similarity index 99% rename from pkg/core/session/mysql/sess_mysql.go rename to pkg/server/web/session/mysql/sess_mysql.go index 964b0b2e..d9b0f6b4 100644 --- a/pkg/core/session/mysql/sess_mysql.go +++ b/pkg/server/web/session/mysql/sess_mysql.go @@ -47,7 +47,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" // import mysql driver _ "github.com/go-sql-driver/mysql" ) diff --git a/pkg/core/session/postgres/sess_postgresql.go b/pkg/server/web/session/postgres/sess_postgresql.go similarity index 99% rename from pkg/core/session/postgres/sess_postgresql.go rename to pkg/server/web/session/postgres/sess_postgresql.go index 29223d4e..c5f8f3fa 100644 --- a/pkg/core/session/postgres/sess_postgresql.go +++ b/pkg/server/web/session/postgres/sess_postgresql.go @@ -57,7 +57,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" // import postgresql Driver _ "github.com/lib/pq" ) diff --git a/pkg/core/session/redis/sess_redis.go b/pkg/server/web/session/redis/sess_redis.go similarity index 99% rename from pkg/core/session/redis/sess_redis.go rename to pkg/server/web/session/redis/sess_redis.go index bbd94019..3b25ae65 100644 --- a/pkg/core/session/redis/sess_redis.go +++ b/pkg/server/web/session/redis/sess_redis.go @@ -40,7 +40,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" "github.com/go-redis/redis/v7" ) diff --git a/pkg/core/session/redis/sess_redis_test.go b/pkg/server/web/session/redis/sess_redis_test.go similarity index 97% rename from pkg/core/session/redis/sess_redis_test.go rename to pkg/server/web/session/redis/sess_redis_test.go index f45d3051..7e63361a 100644 --- a/pkg/core/session/redis/sess_redis_test.go +++ b/pkg/server/web/session/redis/sess_redis_test.go @@ -7,7 +7,7 @@ import ( "os" "testing" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" ) func TestRedis(t *testing.T) { diff --git a/pkg/core/session/redis_cluster/redis_cluster.go b/pkg/server/web/session/redis_cluster/redis_cluster.go similarity index 99% rename from pkg/core/session/redis_cluster/redis_cluster.go rename to pkg/server/web/session/redis_cluster/redis_cluster.go index 42841cb4..635ba915 100644 --- a/pkg/core/session/redis_cluster/redis_cluster.go +++ b/pkg/server/web/session/redis_cluster/redis_cluster.go @@ -40,7 +40,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" rediss "github.com/go-redis/redis/v7" ) diff --git a/pkg/core/session/redis_sentinel/sess_redis_sentinel.go b/pkg/server/web/session/redis_sentinel/sess_redis_sentinel.go similarity index 99% rename from pkg/core/session/redis_sentinel/sess_redis_sentinel.go rename to pkg/server/web/session/redis_sentinel/sess_redis_sentinel.go index b07acdc0..4b21242d 100644 --- a/pkg/core/session/redis_sentinel/sess_redis_sentinel.go +++ b/pkg/server/web/session/redis_sentinel/sess_redis_sentinel.go @@ -40,7 +40,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" "github.com/go-redis/redis/v7" ) diff --git a/pkg/core/session/redis_sentinel/sess_redis_sentinel_test.go b/pkg/server/web/session/redis_sentinel/sess_redis_sentinel_test.go similarity index 97% rename from pkg/core/session/redis_sentinel/sess_redis_sentinel_test.go rename to pkg/server/web/session/redis_sentinel/sess_redis_sentinel_test.go index 2cc21a5a..8d565bf9 100644 --- a/pkg/core/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/pkg/server/web/session/redis_sentinel/sess_redis_sentinel_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/astaxie/beego/pkg/core/session" + "github.com/astaxie/beego/pkg/server/web/session" ) func TestRedisSentinel(t *testing.T) { diff --git a/pkg/core/session/sess_cookie.go b/pkg/server/web/session/sess_cookie.go similarity index 100% rename from pkg/core/session/sess_cookie.go rename to pkg/server/web/session/sess_cookie.go diff --git a/pkg/core/session/sess_cookie_test.go b/pkg/server/web/session/sess_cookie_test.go similarity index 100% rename from pkg/core/session/sess_cookie_test.go rename to pkg/server/web/session/sess_cookie_test.go diff --git a/pkg/core/session/sess_file.go b/pkg/server/web/session/sess_file.go similarity index 100% rename from pkg/core/session/sess_file.go rename to pkg/server/web/session/sess_file.go diff --git a/pkg/core/session/sess_file_test.go b/pkg/server/web/session/sess_file_test.go similarity index 100% rename from pkg/core/session/sess_file_test.go rename to pkg/server/web/session/sess_file_test.go diff --git a/pkg/core/session/sess_mem.go b/pkg/server/web/session/sess_mem.go similarity index 100% rename from pkg/core/session/sess_mem.go rename to pkg/server/web/session/sess_mem.go diff --git a/pkg/core/session/sess_mem_test.go b/pkg/server/web/session/sess_mem_test.go similarity index 100% rename from pkg/core/session/sess_mem_test.go rename to pkg/server/web/session/sess_mem_test.go diff --git a/pkg/core/session/sess_test.go b/pkg/server/web/session/sess_test.go similarity index 100% rename from pkg/core/session/sess_test.go rename to pkg/server/web/session/sess_test.go diff --git a/pkg/core/session/sess_utils.go b/pkg/server/web/session/sess_utils.go similarity index 100% rename from pkg/core/session/sess_utils.go rename to pkg/server/web/session/sess_utils.go diff --git a/pkg/core/session/session.go b/pkg/server/web/session/session.go similarity index 100% rename from pkg/core/session/session.go rename to pkg/server/web/session/session.go diff --git a/pkg/core/session/ssdb/sess_ssdb.go b/pkg/server/web/session/ssdb/sess_ssdb.go similarity index 98% rename from pkg/core/session/ssdb/sess_ssdb.go rename to pkg/server/web/session/ssdb/sess_ssdb.go index 274c6b35..d15f2171 100644 --- a/pkg/core/session/ssdb/sess_ssdb.go +++ b/pkg/server/web/session/ssdb/sess_ssdb.go @@ -8,8 +8,9 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/core/session" "github.com/ssdb/gossdb/ssdb" + + "github.com/astaxie/beego/pkg/server/web/session" ) var ssdbProvider = &Provider{} @@ -87,7 +88,7 @@ func (p *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { // SessionRegenerate regenerate session with new sid and delete oldsid func (p *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { - //conn.Do("setx", key, v, ttl) + // conn.Do("setx", key, v, ttl) if p.client == nil { if err := p.connectInit(); err != nil { return nil, err From 6aa6c55f07be018183f72c436578d4eb00a4a6f0 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 5 Oct 2020 21:54:38 +0800 Subject: [PATCH 188/207] logs Adapter --- pkg/adapter/logs/accesslog.go | 27 +++ pkg/adapter/logs/alils/alils.go | 5 + pkg/adapter/logs/es/es.go | 5 + pkg/adapter/logs/log.go | 346 ++++++++++++++++++++++++++++++++ pkg/adapter/logs/log_adapter.go | 69 +++++++ pkg/adapter/logs/logger.go | 38 ++++ pkg/adapter/logs/logger_test.go | 24 +++ 7 files changed, 514 insertions(+) create mode 100644 pkg/adapter/logs/accesslog.go create mode 100644 pkg/adapter/logs/alils/alils.go create mode 100644 pkg/adapter/logs/es/es.go create mode 100644 pkg/adapter/logs/log.go create mode 100644 pkg/adapter/logs/log_adapter.go create mode 100644 pkg/adapter/logs/logger.go create mode 100644 pkg/adapter/logs/logger_test.go diff --git a/pkg/adapter/logs/accesslog.go b/pkg/adapter/logs/accesslog.go new file mode 100644 index 00000000..cebee92b --- /dev/null +++ b/pkg/adapter/logs/accesslog.go @@ -0,0 +1,27 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "github.com/astaxie/beego/pkg/core/logs" +) + +// AccessLogRecord struct for holding access log data. +type AccessLogRecord logs.AccessLogRecord + +// AccessLog - Format and print access log. +func AccessLog(r *AccessLogRecord, format string) { + logs.AccessLog((*logs.AccessLogRecord)(r), format) +} diff --git a/pkg/adapter/logs/alils/alils.go b/pkg/adapter/logs/alils/alils.go new file mode 100644 index 00000000..5abbc29f --- /dev/null +++ b/pkg/adapter/logs/alils/alils.go @@ -0,0 +1,5 @@ +package alils + +import ( + _ "github.com/astaxie/beego/pkg/core/logs/alils" +) diff --git a/pkg/adapter/logs/es/es.go b/pkg/adapter/logs/es/es.go new file mode 100644 index 00000000..e0759485 --- /dev/null +++ b/pkg/adapter/logs/es/es.go @@ -0,0 +1,5 @@ +package es + +import ( + _ "github.com/astaxie/beego/pkg/core/logs/es" +) diff --git a/pkg/adapter/logs/log.go b/pkg/adapter/logs/log.go new file mode 100644 index 00000000..6a7045fd --- /dev/null +++ b/pkg/adapter/logs/log.go @@ -0,0 +1,346 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package logs provide a general log interface +// Usage: +// +// import "github.com/astaxie/beego/logs" +// +// log := NewLogger(10000) +// log.SetLogger("console", "") +// +// > the first params stand for how many channel +// +// Use it like this: +// +// log.Trace("trace") +// log.Info("info") +// log.Warn("warning") +// log.Debug("debug") +// log.Critical("critical") +// +// more docs http://beego.me/docs/module/logs.md +package logs + +import ( + "log" + "time" + + "github.com/astaxie/beego/pkg/core/logs" +) + +// RFC5424 log message levels. +const ( + LevelEmergency = iota + LevelAlert + LevelCritical + LevelError + LevelWarning + LevelNotice + LevelInformational + LevelDebug +) + +// levelLogLogger is defined to implement log.Logger +// the real log level will be LevelEmergency +const levelLoggerImpl = -1 + +// Name for adapter with beego official support +const ( + AdapterConsole = "console" + AdapterFile = "file" + AdapterMultiFile = "multifile" + AdapterMail = "smtp" + AdapterConn = "conn" + AdapterEs = "es" + AdapterJianLiao = "jianliao" + AdapterSlack = "slack" + AdapterAliLS = "alils" +) + +// Legacy log level constants to ensure backwards compatibility. +const ( + LevelInfo = LevelInformational + LevelTrace = LevelDebug + LevelWarn = LevelWarning +) + +type newLoggerFunc func() Logger + +// Logger defines the behavior of a log provider. +type Logger interface { + Init(config string) error + WriteMsg(when time.Time, msg string, level int) error + Destroy() + Flush() +} + +var adapters = make(map[string]newLoggerFunc) +var levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"} + +// Register makes a log provide available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, log newLoggerFunc) { + logs.Register(name, func() logs.Logger { + return &oldToNewAdapter{ + old: log(), + } + }) +} + +// BeeLogger is default logger in beego application. +// it can contain several providers and log message into all providers. +type BeeLogger logs.BeeLogger + +const defaultAsyncMsgLen = 1e3 + +// NewLogger returns a new BeeLogger. +// channelLen means the number of messages in chan(used where asynchronous is true). +// if the buffering chan is full, logger adapters write to file or other way. +func NewLogger(channelLens ...int64) *BeeLogger { + return (*BeeLogger)(logs.NewLogger(channelLens...)) +} + +// Async set the log to asynchronous and start the goroutine +func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { + (*logs.BeeLogger)(bl).Async(msgLen...) + return bl +} + +// SetLogger provides a given logger adapter into BeeLogger with config string. +// config need to be correct JSON as string: {"interval":360}. +func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error { + return (*logs.BeeLogger)(bl).SetLogger(adapterName, configs...) +} + +// DelLogger remove a logger adapter in BeeLogger. +func (bl *BeeLogger) DelLogger(adapterName string) error { + return (*logs.BeeLogger)(bl).DelLogger(adapterName) +} + +func (bl *BeeLogger) Write(p []byte) (n int, err error) { + return (*logs.BeeLogger)(bl).Write(p) +} + +// SetLevel Set log message level. +// If message level (such as LevelDebug) is higher than logger level (such as LevelWarning), +// log providers will not even be sent the message. +func (bl *BeeLogger) SetLevel(l int) { + (*logs.BeeLogger)(bl).SetLevel(l) +} + +// GetLevel Get Current log message level. +func (bl *BeeLogger) GetLevel() int { + return (*logs.BeeLogger)(bl).GetLevel() +} + +// SetLogFuncCallDepth set log funcCallDepth +func (bl *BeeLogger) SetLogFuncCallDepth(d int) { + (*logs.BeeLogger)(bl).SetLogFuncCallDepth(d) +} + +// GetLogFuncCallDepth return log funcCallDepth for wrapper +func (bl *BeeLogger) GetLogFuncCallDepth() int { + return (*logs.BeeLogger)(bl).GetLogFuncCallDepth() +} + +// EnableFuncCallDepth enable log funcCallDepth +func (bl *BeeLogger) EnableFuncCallDepth(b bool) { + (*logs.BeeLogger)(bl).EnableFuncCallDepth(b) +} + +// set prefix +func (bl *BeeLogger) SetPrefix(s string) { + (*logs.BeeLogger)(bl).SetPrefix(s) +} + +// Emergency Log EMERGENCY level message. +func (bl *BeeLogger) Emergency(format string, v ...interface{}) { + (*logs.BeeLogger)(bl).Emergency(format, v...) +} + +// Alert Log ALERT level message. +func (bl *BeeLogger) Alert(format string, v ...interface{}) { + (*logs.BeeLogger)(bl).Alert(format, v...) +} + +// Critical Log CRITICAL level message. +func (bl *BeeLogger) Critical(format string, v ...interface{}) { + (*logs.BeeLogger)(bl).Critical(format, v...) +} + +// Error Log ERROR level message. +func (bl *BeeLogger) Error(format string, v ...interface{}) { + (*logs.BeeLogger)(bl).Error(format, v...) +} + +// Warning Log WARNING level message. +func (bl *BeeLogger) Warning(format string, v ...interface{}) { + (*logs.BeeLogger)(bl).Warning(format, v...) +} + +// Notice Log NOTICE level message. +func (bl *BeeLogger) Notice(format string, v ...interface{}) { + (*logs.BeeLogger)(bl).Notice(format, v...) +} + +// Informational Log INFORMATIONAL level message. +func (bl *BeeLogger) Informational(format string, v ...interface{}) { + (*logs.BeeLogger)(bl).Informational(format, v...) +} + +// Debug Log DEBUG level message. +func (bl *BeeLogger) Debug(format string, v ...interface{}) { + (*logs.BeeLogger)(bl).Debug(format, v...) +} + +// Warn Log WARN level message. +// compatibility alias for Warning() +func (bl *BeeLogger) Warn(format string, v ...interface{}) { + (*logs.BeeLogger)(bl).Warn(format, v...) +} + +// Info Log INFO level message. +// compatibility alias for Informational() +func (bl *BeeLogger) Info(format string, v ...interface{}) { + (*logs.BeeLogger)(bl).Info(format, v...) +} + +// Trace Log TRACE level message. +// compatibility alias for Debug() +func (bl *BeeLogger) Trace(format string, v ...interface{}) { + (*logs.BeeLogger)(bl).Trace(format, v...) +} + +// Flush flush all chan data. +func (bl *BeeLogger) Flush() { + (*logs.BeeLogger)(bl).Flush() +} + +// Close close logger, flush all chan data and destroy all adapters in BeeLogger. +func (bl *BeeLogger) Close() { + (*logs.BeeLogger)(bl).Close() +} + +// Reset close all outputs, and set bl.outputs to nil +func (bl *BeeLogger) Reset() { + (*logs.BeeLogger)(bl).Reset() +} + +// GetBeeLogger returns the default BeeLogger +func GetBeeLogger() *BeeLogger { + return (*BeeLogger)(logs.GetBeeLogger()) +} + +// GetLogger returns the default BeeLogger +func GetLogger(prefixes ...string) *log.Logger { + return logs.GetLogger(prefixes...) +} + +// Reset will remove all the adapter +func Reset() { + logs.Reset() +} + +// Async set the beelogger with Async mode and hold msglen messages +func Async(msgLen ...int64) *BeeLogger { + return (*BeeLogger)(logs.Async(msgLen...)) +} + +// SetLevel sets the global log level used by the simple logger. +func SetLevel(l int) { + logs.SetLevel(l) +} + +// SetPrefix sets the prefix +func SetPrefix(s string) { + logs.SetPrefix(s) +} + +// EnableFuncCallDepth enable log funcCallDepth +func EnableFuncCallDepth(b bool) { + logs.EnableFuncCallDepth(b) +} + +// SetLogFuncCall set the CallDepth, default is 4 +func SetLogFuncCall(b bool) { + logs.SetLogFuncCall(b) +} + +// SetLogFuncCallDepth set log funcCallDepth +func SetLogFuncCallDepth(d int) { + logs.SetLogFuncCallDepth(d) +} + +// SetLogger sets a new logger. +func SetLogger(adapter string, config ...string) error { + return logs.SetLogger(adapter, config...) +} + +// Emergency logs a message at emergency level. +func Emergency(f interface{}, v ...interface{}) { + logs.Emergency(f, v...) +} + +// Alert logs a message at alert level. +func Alert(f interface{}, v ...interface{}) { + logs.Alert(f, v...) +} + +// Critical logs a message at critical level. +func Critical(f interface{}, v ...interface{}) { + logs.Critical(f, v...) +} + +// Error logs a message at error level. +func Error(f interface{}, v ...interface{}) { + logs.Error(f, v...) +} + +// Warning logs a message at warning level. +func Warning(f interface{}, v ...interface{}) { + logs.Warning(f, v...) +} + +// Warn compatibility alias for Warning() +func Warn(f interface{}, v ...interface{}) { + logs.Warn(f, v...) +} + +// Notice logs a message at notice level. +func Notice(f interface{}, v ...interface{}) { + logs.Notice(f, v...) +} + +// Informational logs a message at info level. +func Informational(f interface{}, v ...interface{}) { + logs.Informational(f, v...) +} + +// Info compatibility alias for Warning() +func Info(f interface{}, v ...interface{}) { + logs.Info(f, v...) +} + +// Debug logs a message at debug level. +func Debug(f interface{}, v ...interface{}) { + logs.Debug(f, v...) +} + +// Trace logs a message at trace level. +// compatibility alias for Warning() +func Trace(f interface{}, v ...interface{}) { + logs.Trace(f, v...) +} diff --git a/pkg/adapter/logs/log_adapter.go b/pkg/adapter/logs/log_adapter.go new file mode 100644 index 00000000..ee517bf0 --- /dev/null +++ b/pkg/adapter/logs/log_adapter.go @@ -0,0 +1,69 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "time" + + "github.com/astaxie/beego/pkg/core/logs" +) + +type oldToNewAdapter struct { + old Logger +} + +func (o *oldToNewAdapter) Init(config string) error { + return o.old.Init(config) +} + +func (o *oldToNewAdapter) WriteMsg(lm *logs.LogMsg) error { + return o.old.WriteMsg(lm.When, lm.OldStyleFormat(), lm.Level) +} + +func (o *oldToNewAdapter) Destroy() { + o.old.Destroy() +} + +func (o *oldToNewAdapter) Flush() { + o.old.Flush() +} + +func (o *oldToNewAdapter) SetFormatter(f logs.LogFormatter) { + panic("unsupported operation, you should not invoke this method") +} + +type newToOldAdapter struct { + n logs.Logger +} + +func (n *newToOldAdapter) Init(config string) error { + return n.n.Init(config) +} + +func (n *newToOldAdapter) WriteMsg(when time.Time, msg string, level int) error { + return n.n.WriteMsg(&logs.LogMsg{ + When: when, + Msg: msg, + Level: level, + }) +} + +func (n *newToOldAdapter) Destroy() { + panic("implement me") +} + +func (n *newToOldAdapter) Flush() { + panic("implement me") +} diff --git a/pkg/adapter/logs/logger.go b/pkg/adapter/logs/logger.go new file mode 100644 index 00000000..419ac9c4 --- /dev/null +++ b/pkg/adapter/logs/logger.go @@ -0,0 +1,38 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "github.com/astaxie/beego/pkg/core/logs" +) + +// ColorByStatus return color by http code +// 2xx return Green +// 3xx return White +// 4xx return Yellow +// 5xx return Red +func ColorByStatus(code int) string { + return logs.ColorByStatus(code) +} + +// ColorByMethod return color by http code +func ColorByMethod(method string) string { + return logs.ColorByMethod(method) +} + +// ResetColor return reset color +func ResetColor() string { + return logs.ResetColor() +} diff --git a/pkg/adapter/logs/logger_test.go b/pkg/adapter/logs/logger_test.go new file mode 100644 index 00000000..9f2cc5a5 --- /dev/null +++ b/pkg/adapter/logs/logger_test.go @@ -0,0 +1,24 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" +) + +func TestBeeLogger_Info(t *testing.T) { + log := NewLogger(1000) + log.SetLogger("file", `{"net":"tcp","addr":":7020"}`) +} From 8cc74652a2c9de34c987665c53ac116e9c7dc1b4 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 5 Oct 2020 22:45:48 +0800 Subject: [PATCH 189/207] Fix: adapter's controller must implement ControllerInterface --- pkg/adapter/controller.go | 10 ++++------ pkg/adapter/orm/db_alias.go | 11 +++++------ 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/pkg/adapter/controller.go b/pkg/adapter/controller.go index 010add64..c0616962 100644 --- a/pkg/adapter/controller.go +++ b/pkg/adapter/controller.go @@ -18,7 +18,6 @@ import ( "mime/multipart" "net/url" - "github.com/astaxie/beego/pkg/adapter/context" "github.com/astaxie/beego/pkg/adapter/session" webContext "github.com/astaxie/beego/pkg/server/web/context" @@ -61,14 +60,13 @@ func (p ControllerCommentsSlice) Swap(i, j int) { // http context, template and view, session and xsrf. type Controller web.Controller +func (c *Controller) Init(ctx *webContext.Context, controllerName, actionName string, app interface{}) { + (*web.Controller)(c).Init(ctx, controllerName, actionName, app) +} + // ControllerInterface is an interface to uniform all controller handler. type ControllerInterface web.ControllerInterface -// Init generates default values of controller operations. -func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { - (*web.Controller)(c).Init((*webContext.Context)(ctx), controllerName, actionName, app) -} - // Prepare runs after Init before request function execution. func (c *Controller) Prepare() { (*web.Controller)(c).Prepare() diff --git a/pkg/adapter/orm/db_alias.go b/pkg/adapter/orm/db_alias.go index b1f1a724..523b6aee 100644 --- a/pkg/adapter/orm/db_alias.go +++ b/pkg/adapter/orm/db_alias.go @@ -27,12 +27,11 @@ type DriverType orm.DriverType // Enum the Database driver const ( - _ DriverType = iota // int enum type - DRMySQL = orm.DRMySQL - DRSqlite = orm.DRSqlite // sqlite - DROracle = orm.DROracle // oracle - DRPostgres = orm.DRPostgres // pgsql - DRTiDB = orm.DRTiDB // TiDB + DRMySQL = DriverType(orm.DRMySQL) + DRSqlite = DriverType(orm.DRSqlite) // sqlite + DROracle = DriverType(orm.DROracle) // oracle + DRPostgres = DriverType(orm.DRPostgres) // pgsql + DRTiDB = DriverType(orm.DRTiDB) // TiDB ) type DB orm.DB From 66804324f23de1c6487d306edc505d383eac947a Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 6 Oct 2020 11:52:24 +0800 Subject: [PATCH 190/207] Fix: Set func call depth as 3 --- pkg/adapter/logs/log.go | 7 ++++--- pkg/core/logs/log.go | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pkg/adapter/logs/log.go b/pkg/adapter/logs/log.go index 6a7045fd..185b2a47 100644 --- a/pkg/adapter/logs/log.go +++ b/pkg/adapter/logs/log.go @@ -86,9 +86,6 @@ type Logger interface { Flush() } -var adapters = make(map[string]newLoggerFunc) -var levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"} - // Register makes a log provide available by the provided name. // If Register is called twice with the same name or if driver is nil, // it panics. @@ -344,3 +341,7 @@ func Debug(f interface{}, v ...interface{}) { func Trace(f interface{}, v ...interface{}) { logs.Trace(f, v...) } + +func init() { + SetLogFuncCallDepth(4) +} diff --git a/pkg/core/logs/log.go b/pkg/core/logs/log.go index cec8d51d..b05abd3b 100644 --- a/pkg/core/logs/log.go +++ b/pkg/core/logs/log.go @@ -142,7 +142,7 @@ var logMsgPool *sync.Pool func NewLogger(channelLens ...int64) *BeeLogger { bl := new(BeeLogger) bl.level = LevelDebug - bl.loggerFuncCallDepth = 2 + bl.loggerFuncCallDepth = 3 bl.msgChanLen = append(channelLens, 0)[0] if bl.msgChanLen <= 0 { bl.msgChanLen = defaultAsyncMsgLen From 034cb3222e7b51a835333facbad5bba9226b815d Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 6 Oct 2020 11:53:59 +0800 Subject: [PATCH 191/207] Add adapter script which is used to replace v1 package with v2 adapter package --- scripts/adapter.sh | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 scripts/adapter.sh diff --git a/scripts/adapter.sh b/scripts/adapter.sh new file mode 100644 index 00000000..ce2d319a --- /dev/null +++ b/scripts/adapter.sh @@ -0,0 +1,6 @@ +#!/bin/sh + +# using pkg/adapter. Usually you want to migrate to V2 smoothly, you could running this script + +find ./ -name '*.go' -type f -exec sed -i '' -e 's/github.com\/astaxie\/beego/github.com\/astaxie\/beego\/pkg\/adapter/g' {} \; +find ./ -name '*.go' -type f -exec sed -i '' -e 's/"github.com\/astaxie\/beego\/pkg\/adapter"/beego "github.com\/astaxie\/beego\/pkg\/adapter"/g' {} \; From 14c1b765695af47609d2d0cf73a53f7923a7d936 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Thu, 8 Oct 2020 17:17:15 +0800 Subject: [PATCH 192/207] remove pkg directory; remove build directory; remove githook directory; --- .travis.yml | 2 +- {pkg/adapter => adapter}/admin.go | 4 +- {pkg/adapter => adapter}/app.go | 6 +- {pkg/adapter => adapter}/beego.go | 6 +- {pkg/adapter => adapter}/build_info.go | 0 {pkg/adapter => adapter}/cache/cache.go | 0 .../cache/cache_adapter.go | 2 +- {pkg/adapter => adapter}/cache/cache_test.go | 0 {pkg/adapter => adapter}/cache/conv.go | 2 +- {pkg/adapter => adapter}/cache/conv_test.go | 0 {pkg/adapter => adapter}/cache/file.go | 2 +- .../cache/memcache/memcache.go | 4 +- .../cache/memcache/memcache_test.go | 2 +- {pkg/adapter => adapter}/cache/memory.go | 2 +- {pkg/adapter => adapter}/cache/redis/redis.go | 4 +- .../cache/redis/redis_test.go | 2 +- {pkg/adapter => adapter}/cache/ssdb/ssdb.go | 4 +- .../cache/ssdb/ssdb_test.go | 2 +- {pkg/adapter => adapter}/config.go | 6 +- {pkg/adapter => adapter}/config/adapter.go | 2 +- {pkg/adapter => adapter}/config/config.go | 2 +- .../adapter => adapter}/config/config_test.go | 0 {pkg/adapter => adapter}/config/env/env.go | 2 +- .../config/env/env_test.go | 0 {pkg/adapter => adapter}/config/fake.go | 2 +- {pkg/adapter => adapter}/config/ini_test.go | 0 {pkg/adapter => adapter}/config/json.go | 2 +- {pkg/adapter => adapter}/config/json_test.go | 0 {pkg/adapter => adapter}/config/xml/xml.go | 2 +- .../config/xml/xml_test.go | 2 +- {pkg/adapter => adapter}/config/yaml/yaml.go | 2 +- .../config/yaml/yaml_test.go | 2 +- .../context/acceptencoder.go | 2 +- {pkg/adapter => adapter}/context/context.go | 2 +- {pkg/adapter => adapter}/context/input.go | 2 +- {pkg/adapter => adapter}/context/output.go | 2 +- {pkg/adapter => adapter}/context/renderer.go | 2 +- {pkg/adapter => adapter}/context/response.go | 0 {pkg/adapter => adapter}/controller.go | 6 +- {pkg/adapter => adapter}/doc.go | 0 {pkg/adapter => adapter}/error.go | 6 +- {pkg/adapter => adapter}/filter.go | 6 +- {pkg/adapter => adapter}/flash.go | 2 +- {pkg/adapter => adapter}/fs.go | 2 +- {pkg/adapter => adapter}/grace/grace.go | 2 +- {pkg/adapter => adapter}/grace/server.go | 2 +- {pkg/adapter => adapter}/httplib/httplib.go | 2 +- .../httplib/httplib_test.go | 0 {pkg/adapter => adapter}/log.go | 4 +- {pkg/adapter => adapter}/logs/accesslog.go | 2 +- adapter/logs/alils/alils.go | 5 + adapter/logs/es/es.go | 5 + {pkg/adapter => adapter}/logs/log.go | 2 +- {pkg/adapter => adapter}/logs/log_adapter.go | 2 +- {pkg/adapter => adapter}/logs/logger.go | 2 +- {pkg/adapter => adapter}/logs/logger_test.go | 0 {pkg/adapter => adapter}/metric/prometheus.go | 20 ++-- .../metric/prometheus_test.go | 2 +- {pkg/adapter => adapter}/migration/ddl.go | 2 +- {pkg/adapter => adapter}/migration/doc.go | 0 .../migration/migration.go | 2 +- {pkg/adapter => adapter}/namespace.go | 6 +- {pkg/adapter => adapter}/orm/cmd.go | 2 +- {pkg/adapter => adapter}/orm/db.go | 2 +- {pkg/adapter => adapter}/orm/db_alias.go | 2 +- {pkg/adapter => adapter}/orm/models.go | 2 +- {pkg/adapter => adapter}/orm/models_boot.go | 2 +- {pkg/adapter => adapter}/orm/models_fields.go | 2 +- {pkg/adapter => adapter}/orm/orm.go | 6 +- {pkg/adapter => adapter}/orm/orm_conds.go | 2 +- {pkg/adapter => adapter}/orm/orm_log.go | 2 +- {pkg/adapter => adapter}/orm/orm_queryset.go | 2 +- {pkg/adapter => adapter}/orm/qb.go | 2 +- {pkg/adapter => adapter}/orm/qb_mysql.go | 2 +- {pkg/adapter => adapter}/orm/qb_tidb.go | 2 +- .../orm/query_setter_adapter.go | 2 +- {pkg/adapter => adapter}/orm/types.go | 2 +- {pkg/adapter => adapter}/orm/utils.go | 2 +- {pkg/adapter => adapter}/orm/utils_test.go | 0 .../plugins/apiauth/apiauth.go | 8 +- .../plugins/apiauth/apiauth_test.go | 0 .../adapter => adapter}/plugins/auth/basic.go | 8 +- .../plugins/authz/authz.go | 8 +- .../plugins/authz/authz_model.conf | 0 .../plugins/authz/authz_policy.csv | 0 .../plugins/authz/authz_test.go | 6 +- {pkg/adapter => adapter}/plugins/cors/cors.go | 8 +- {pkg/adapter => adapter}/policy.go | 6 +- {pkg/adapter => adapter}/router.go | 6 +- .../session/couchbase/sess_couchbase.go | 4 +- .../session/ledis/ledis_session.go | 4 +- .../session/memcache/sess_memcache.go | 4 +- .../session/mysql/sess_mysql.go | 4 +- .../session/postgres/sess_postgresql.go | 4 +- .../session/provider_adapter.go | 2 +- .../session/redis/sess_redis.go | 4 +- .../session/redis_cluster/redis_cluster.go | 4 +- .../redis_sentinel/sess_redis_sentinel.go | 4 +- .../sess_redis_sentinel_test.go | 2 +- .../session/sess_cookie.go | 2 +- .../session/sess_cookie_test.go | 0 {pkg/adapter => adapter}/session/sess_file.go | 2 +- .../session/sess_file_test.go | 0 {pkg/adapter => adapter}/session/sess_mem.go | 2 +- .../session/sess_mem_test.go | 0 {pkg/adapter => adapter}/session/sess_test.go | 0 .../adapter => adapter}/session/sess_utils.go | 2 +- {pkg/adapter => adapter}/session/session.go | 2 +- .../session/ssdb/sess_ssdb.go | 4 +- .../session/store_adapter.go | 2 +- {pkg/adapter => adapter}/swagger/swagger.go | 2 +- {pkg/adapter => adapter}/template.go | 2 +- {pkg/adapter => adapter}/templatefunc.go | 2 +- {pkg/adapter => adapter}/templatefunc_test.go | 0 {pkg/adapter => adapter}/testing/client.go | 2 +- .../toolbox/healthcheck.go | 2 +- {pkg/adapter => adapter}/toolbox/profile.go | 2 +- .../toolbox/profile_test.go | 0 .../adapter => adapter}/toolbox/statistics.go | 2 +- .../toolbox/statistics_test.go | 0 {pkg/adapter => adapter}/toolbox/task.go | 2 +- {pkg/adapter => adapter}/toolbox/task_test.go | 0 {pkg/adapter => adapter}/tree.go | 6 +- {pkg/adapter => adapter}/tree_test.go | 4 +- {pkg/adapter => adapter}/utils/caller.go | 2 +- {pkg/adapter => adapter}/utils/caller_test.go | 0 .../adapter => adapter}/utils/captcha/LICENSE | 0 .../utils/captcha/README.md | 0 .../utils/captcha/captcha.go | 8 +- .../utils/captcha/image.go | 2 +- .../utils/captcha/image_test.go | 2 +- {pkg/adapter => adapter}/utils/debug.go | 2 +- {pkg/adapter => adapter}/utils/debug_test.go | 0 {pkg/adapter => adapter}/utils/file.go | 2 +- {pkg/adapter => adapter}/utils/mail.go | 2 +- {pkg/adapter => adapter}/utils/mail_test.go | 0 .../utils/pagination/controller.go | 6 +- .../utils/pagination/doc.go | 0 .../utils/pagination/paginator.go | 2 +- {pkg/adapter => adapter}/utils/rand.go | 2 +- {pkg/adapter => adapter}/utils/rand_test.go | 0 {pkg/adapter => adapter}/utils/safemap.go | 2 +- .../adapter => adapter}/utils/safemap_test.go | 0 {pkg/adapter => adapter}/utils/slice.go | 2 +- {pkg/adapter => adapter}/utils/slice_test.go | 0 {pkg/adapter => adapter}/utils/utils.go | 2 +- {pkg/adapter => adapter}/validation/util.go | 2 +- .../validation/validation.go | 2 +- .../validation/validation_test.go | 0 .../validation/validators.go | 2 +- build/gobuild-sample.sh | 112 ------------------ build/report_build_info.sh | 52 -------- pkg/build_info.go => build_info.go | 2 +- {pkg/client => client}/cache/README.md | 0 {pkg/client => client}/cache/cache.go | 0 {pkg/client => client}/cache/cache_test.go | 0 {pkg/client => client}/cache/conv.go | 0 {pkg/client => client}/cache/conv_test.go | 0 {pkg/client => client}/cache/file.go | 0 .../cache/memcache/memcache.go | 2 +- .../cache/memcache/memcache_test.go | 2 +- {pkg/client => client}/cache/memory.go | 0 {pkg/client => client}/cache/redis/redis.go | 2 +- .../cache/redis/redis_test.go | 2 +- {pkg/client => client}/cache/ssdb/ssdb.go | 2 +- .../client => client}/cache/ssdb/ssdb_test.go | 2 +- {pkg/client => client}/httplib/README.md | 0 {pkg/client => client}/httplib/filter.go | 0 .../httplib/filter/opentracing/filter.go | 2 +- .../httplib/filter/opentracing/filter_test.go | 2 +- .../httplib/filter/prometheus/filter.go | 2 +- .../httplib/filter/prometheus/filter_test.go | 2 +- {pkg/client => client}/httplib/httplib.go | 0 .../client => client}/httplib/httplib_test.go | 0 .../httplib/testing/client.go | 2 +- {pkg/client => client}/orm/README.md | 0 {pkg/client => client}/orm/cmd.go | 0 {pkg/client => client}/orm/cmd_utils.go | 0 {pkg/client => client}/orm/db.go | 2 +- {pkg/client => client}/orm/db_alias.go | 0 {pkg/client => client}/orm/db_alias_test.go | 0 {pkg/client => client}/orm/db_mysql.go | 0 {pkg/client => client}/orm/db_oracle.go | 2 +- {pkg/client => client}/orm/db_postgres.go | 0 {pkg/client => client}/orm/db_sqlite.go | 2 +- {pkg/client => client}/orm/db_tables.go | 0 {pkg/client => client}/orm/db_tidb.go | 0 {pkg/client => client}/orm/db_utils.go | 0 {pkg/client => client}/orm/do_nothing_orm.go | 2 +- .../orm/do_nothing_orm_test.go | 0 {pkg/client => client}/orm/filter.go | 0 .../orm/filter/bean/default_value_filter.go | 6 +- .../filter/bean/default_value_filter_test.go | 2 +- .../orm/filter/opentracing/filter.go | 2 +- .../orm/filter/opentracing/filter_test.go | 2 +- .../orm/filter/prometheus/filter.go | 2 +- .../orm/filter/prometheus/filter_test.go | 2 +- .../orm/filter_orm_decorator.go | 2 +- .../orm/filter_orm_decorator_test.go | 2 +- {pkg/client => client}/orm/filter_test.go | 0 {pkg/client => client}/orm/hints/db_hints.go | 2 +- .../orm/hints/db_hints_test.go | 0 {pkg/client => client}/orm/invocation.go | 0 {pkg/client => client}/orm/migration/ddl.go | 2 +- {pkg/client => client}/orm/migration/doc.go | 0 .../orm/migration/migration.go | 4 +- .../client => client}/orm/model_utils_test.go | 0 {pkg/client => client}/orm/models.go | 0 {pkg/client => client}/orm/models_boot.go | 0 {pkg/client => client}/orm/models_fields.go | 0 {pkg/client => client}/orm/models_info_f.go | 0 {pkg/client => client}/orm/models_info_m.go | 0 {pkg/client => client}/orm/models_test.go | 12 +- {pkg/client => client}/orm/models_utils.go | 0 .../orm/models_utils_test.go | 0 {pkg/client => client}/orm/orm.go | 8 +- {pkg/client => client}/orm/orm_conds.go | 0 {pkg/client => client}/orm/orm_log.go | 0 {pkg/client => client}/orm/orm_object.go | 0 {pkg/client => client}/orm/orm_querym2m.go | 0 {pkg/client => client}/orm/orm_queryset.go | 2 +- {pkg/client => client}/orm/orm_raw.go | 0 {pkg/client => client}/orm/orm_test.go | 2 +- {pkg/client => client}/orm/qb.go | 0 {pkg/client => client}/orm/qb_mysql.go | 0 {pkg/client => client}/orm/qb_postgres.go | 0 {pkg/client => client}/orm/qb_tidb.go | 0 {pkg/client => client}/orm/types.go | 2 +- {pkg/client => client}/orm/utils.go | 0 {pkg/client => client}/orm/utils_test.go | 0 {pkg/core => core}/bean/context.go | 0 {pkg/core => core}/bean/doc.go | 0 {pkg/core => core}/bean/factory.go | 0 {pkg/core => core}/bean/metadata.go | 0 .../bean/tag_auto_wire_bean_factory.go | 2 +- .../bean/tag_auto_wire_bean_factory_test.go | 0 {pkg/core => core}/bean/time_type_adapter.go | 0 .../bean/time_type_adapter_test.go | 0 {pkg/core => core}/bean/type_adapter.go | 0 {pkg/core => core}/config/base_config_test.go | 0 {pkg/core => core}/config/config.go | 0 {pkg/core => core}/config/config_test.go | 0 {pkg/core => core}/config/env/env.go | 2 +- {pkg/core => core}/config/env/env_test.go | 0 {pkg/core => core}/config/etcd/config.go | 4 +- {pkg/core => core}/config/etcd/config_test.go | 0 {pkg/core => core}/config/fake.go | 0 {pkg/core => core}/config/ini.go | 0 {pkg/core => core}/config/ini_test.go | 0 {pkg/core => core}/config/json/json.go | 4 +- {pkg/core => core}/config/json/json_test.go | 2 +- {pkg/core => core}/config/xml/xml.go | 4 +- {pkg/core => core}/config/xml/xml_test.go | 2 +- {pkg/core => core}/config/yaml/yaml.go | 5 +- {pkg/core => core}/config/yaml/yaml_test.go | 2 +- {pkg/core => core}/governor/command.go | 0 {pkg/core => core}/governor/healthcheck.go | 0 {pkg/core => core}/governor/profile.go | 2 +- {pkg/core => core}/governor/profile_test.go | 0 {pkg/core => core}/logs/README.md | 0 {pkg/core => core}/logs/access_log.go | 0 {pkg/core => core}/logs/access_log_test.go | 0 {pkg/core => core}/logs/alils/alils.go | 2 +- {pkg/core => core}/logs/alils/config.go | 0 {pkg/core => core}/logs/alils/log.pb.go | 0 {pkg/core => core}/logs/alils/log_config.go | 0 {pkg/core => core}/logs/alils/log_project.go | 0 {pkg/core => core}/logs/alils/log_store.go | 0 .../core => core}/logs/alils/machine_group.go | 0 {pkg/core => core}/logs/alils/request.go | 0 {pkg/core => core}/logs/alils/signature.go | 0 {pkg/core => core}/logs/conn.go | 0 {pkg/core => core}/logs/conn_test.go | 0 {pkg/core => core}/logs/console.go | 0 {pkg/core => core}/logs/console_test.go | 0 {pkg/core => core}/logs/es/es.go | 2 +- {pkg/core => core}/logs/es/index.go | 2 +- {pkg/core => core}/logs/es/index_test.go | 2 +- {pkg/core => core}/logs/file.go | 0 {pkg/core => core}/logs/file_test.go | 0 {pkg/core => core}/logs/formatter.go | 0 {pkg/core => core}/logs/formatter_test.go | 0 {pkg/core => core}/logs/jianliao.go | 0 {pkg/core => core}/logs/jianliao_test.go | 0 {pkg/core => core}/logs/log.go | 0 {pkg/core => core}/logs/log_msg.go | 0 {pkg/core => core}/logs/log_msg_test.go | 0 {pkg/core => core}/logs/log_test.go | 0 {pkg/core => core}/logs/logger.go | 0 {pkg/core => core}/logs/logger_test.go | 0 {pkg/core => core}/logs/multifile.go | 0 {pkg/core => core}/logs/multifile_test.go | 0 {pkg/core => core}/logs/slack.go | 0 {pkg/core => core}/logs/smtp.go | 0 {pkg/core => core}/logs/smtp_test.go | 0 {pkg/core => core}/utils/caller.go | 0 {pkg/core => core}/utils/caller_test.go | 0 {pkg/core => core}/utils/debug.go | 0 {pkg/core => core}/utils/debug_test.go | 0 {pkg/core => core}/utils/file.go | 0 {pkg/core => core}/utils/file_test.go | 0 {pkg/core => core}/utils/kv.go | 0 {pkg/core => core}/utils/kv_test.go | 0 {pkg/core => core}/utils/mail.go | 0 {pkg/core => core}/utils/mail_test.go | 0 {pkg/core => core}/utils/pagination/doc.go | 2 +- .../utils/pagination/paginator.go | 0 {pkg/core => core}/utils/pagination/utils.go | 0 {pkg/core => core}/utils/rand.go | 0 {pkg/core => core}/utils/rand_test.go | 0 {pkg/core => core}/utils/safemap.go | 0 {pkg/core => core}/utils/safemap_test.go | 0 {pkg/core => core}/utils/slice.go | 0 {pkg/core => core}/utils/slice_test.go | 0 {pkg/core => core}/utils/testdata/grepe.test | 0 {pkg/core => core}/utils/time.go | 0 {pkg/core => core}/utils/utils.go | 0 {pkg/core => core}/utils/utils_test.go | 0 {pkg/core => core}/validation/README.md | 0 {pkg/core => core}/validation/util.go | 0 {pkg/core => core}/validation/util_test.go | 0 {pkg/core => core}/validation/validation.go | 0 .../validation/validation_test.go | 0 {pkg/core => core}/validation/validators.go | 2 +- pkg/doc.go => doc.go | 2 +- githook/pre-commit | 7 -- pkg/adapter/logs/alils/alils.go | 5 - pkg/adapter/logs/es/es.go | 5 - {pkg/server => server}/web/LICENSE | 0 {pkg/server => server}/web/admin.go | 2 +- .../server => server}/web/admin_controller.go | 2 +- {pkg/server => server}/web/admin_test.go | 2 +- {pkg/server => server}/web/adminui.go | 0 {pkg/server => server}/web/beego.go | 0 {pkg/server => server}/web/captcha/LICENSE | 0 {pkg/server => server}/web/captcha/README.md | 0 {pkg/server => server}/web/captcha/captcha.go | 8 +- {pkg/server => server}/web/captcha/image.go | 0 .../web/captcha/image_test.go | 2 +- {pkg/server => server}/web/captcha/siprng.go | 0 .../web/captcha/siprng_test.go | 0 {pkg/server => server}/web/config.go | 14 +-- {pkg/server => server}/web/config_test.go | 2 +- .../web/context/acceptencoder.go | 0 .../web/context/acceptencoder_test.go | 0 {pkg/server => server}/web/context/context.go | 2 +- .../web/context/context_test.go | 0 {pkg/server => server}/web/context/input.go | 2 +- .../web/context/input_test.go | 0 {pkg/server => server}/web/context/output.go | 0 .../web/context/param/conv.go | 4 +- .../web/context/param/methodparams.go | 0 .../web/context/param/options.go | 0 .../web/context/param/parsers.go | 0 .../web/context/param/parsers_test.go | 0 .../server => server}/web/context/renderer.go | 0 .../server => server}/web/context/response.go | 3 +- {pkg/server => server}/web/controller.go | 6 +- {pkg/server => server}/web/controller_test.go | 2 +- {pkg/server => server}/web/doc.go | 2 +- {pkg/server => server}/web/error.go | 10 +- {pkg/server => server}/web/error_test.go | 0 {pkg/server => server}/web/filter.go | 2 +- .../web/filter/apiauth/apiauth.go | 4 +- .../web/filter/apiauth/apiauth_test.go | 0 .../web/filter/auth/basic.go | 4 +- .../web/filter/authz/authz.go | 4 +- .../web/filter/authz/authz_model.conf | 0 .../web/filter/authz/authz_policy.csv | 0 .../web/filter/authz/authz_test.go | 6 +- .../server => server}/web/filter/cors/cors.go | 4 +- .../web/filter/cors/cors_test.go | 4 +- .../web/filter/opentracing/filter.go | 5 +- .../web/filter/opentracing/filter_test.go | 2 +- .../web/filter/prometheus/filter.go | 20 ++-- .../web/filter/prometheus/filter_test.go | 2 +- .../web/filter_chain_test.go | 2 +- {pkg/server => server}/web/filter_test.go | 2 +- {pkg/server => server}/web/flash.go | 0 {pkg/server => server}/web/flash_test.go | 0 {pkg/server => server}/web/fs.go | 0 {pkg/server => server}/web/grace/grace.go | 0 {pkg/server => server}/web/grace/server.go | 0 {pkg/server => server}/web/hooks.go | 6 +- {pkg/server => server}/web/mime.go | 0 {pkg/server => server}/web/namespace.go | 28 ++--- {pkg/server => server}/web/namespace_test.go | 2 +- .../web/pagination/controller.go | 4 +- {pkg/server => server}/web/parser.go | 10 +- {pkg/server => server}/web/policy.go | 2 +- {pkg/server => server}/web/router.go | 8 +- {pkg/server => server}/web/router_test.go | 4 +- {pkg/server => server}/web/server.go | 8 +- {pkg/server => server}/web/server_test.go | 0 {pkg/server => server}/web/session/README.md | 0 .../web/session/couchbase/sess_couchbase.go | 2 +- .../web/session/ledis/ledis_session.go | 2 +- .../web/session/memcache/sess_memcache.go | 2 +- .../web/session/mysql/sess_mysql.go | 2 +- .../web/session/postgres/sess_postgresql.go | 2 +- .../web/session/redis/sess_redis.go | 4 +- .../web/session/redis/sess_redis_test.go | 2 +- .../session/redis_cluster/redis_cluster.go | 2 +- .../redis_sentinel/sess_redis_sentinel.go | 3 +- .../sess_redis_sentinel_test.go | 2 +- .../web/session/sess_cookie.go | 0 .../web/session/sess_cookie_test.go | 0 .../web/session/sess_file.go | 0 .../web/session/sess_file_test.go | 0 .../server => server}/web/session/sess_mem.go | 0 .../web/session/sess_mem_test.go | 0 .../web/session/sess_test.go | 0 .../web/session/sess_utils.go | 2 +- {pkg/server => server}/web/session/session.go | 0 .../web/session/ssdb/sess_ssdb.go | 2 +- {pkg/server => server}/web/staticfile.go | 4 +- {pkg/server => server}/web/staticfile_test.go | 0 {pkg/server => server}/web/statistics.go | 2 +- {pkg/server => server}/web/statistics_test.go | 0 {pkg/server => server}/web/swagger/swagger.go | 0 {pkg/server => server}/web/template.go | 4 +- {pkg/server => server}/web/template_test.go | 0 {pkg/server => server}/web/templatefunc.go | 0 .../web/templatefunc_test.go | 0 {pkg/server => server}/web/tree.go | 4 +- {pkg/server => server}/web/tree_test.go | 2 +- {pkg/server => server}/web/unregroute_test.go | 0 {pkg/task => task}/govenor_command.go | 2 +- {pkg/task => task}/governor_command_test.go | 0 {pkg/task => task}/task.go | 0 {pkg/task => task}/task_test.go | 0 431 files changed, 372 insertions(+), 545 deletions(-) rename {pkg/adapter => adapter}/admin.go (93%) rename {pkg/adapter => adapter}/app.go (98%) rename {pkg/adapter => adapter}/beego.go (95%) rename {pkg/adapter => adapter}/build_info.go (100%) rename {pkg/adapter => adapter}/cache/cache.go (100%) rename {pkg/adapter => adapter}/cache/cache_adapter.go (98%) rename {pkg/adapter => adapter}/cache/cache_test.go (100%) rename {pkg/adapter => adapter}/cache/conv.go (96%) rename {pkg/adapter => adapter}/cache/conv_test.go (100%) rename {pkg/adapter => adapter}/cache/file.go (95%) rename {pkg/adapter => adapter}/cache/memcache/memcache.go (92%) rename {pkg/adapter => adapter}/cache/memcache/memcache_test.go (98%) rename {pkg/adapter => adapter}/cache/memory.go (94%) rename {pkg/adapter => adapter}/cache/redis/redis.go (92%) rename {pkg/adapter => adapter}/cache/redis/redis_test.go (98%) rename {pkg/adapter => adapter}/cache/ssdb/ssdb.go (68%) rename {pkg/adapter => adapter}/cache/ssdb/ssdb_test.go (98%) rename {pkg/adapter => adapter}/config.go (97%) rename {pkg/adapter => adapter}/config/adapter.go (99%) rename {pkg/adapter => adapter}/config/config.go (99%) rename {pkg/adapter => adapter}/config/config_test.go (100%) rename {pkg/adapter => adapter}/config/env/env.go (97%) rename {pkg/adapter => adapter}/config/env/env_test.go (100%) rename {pkg/adapter => adapter}/config/fake.go (94%) rename {pkg/adapter => adapter}/config/ini_test.go (100%) rename {pkg/adapter => adapter}/config/json.go (92%) rename {pkg/adapter => adapter}/config/json_test.go (100%) rename {pkg/adapter => adapter}/config/xml/xml.go (95%) rename {pkg/adapter => adapter}/config/xml/xml_test.go (98%) rename {pkg/adapter => adapter}/config/yaml/yaml.go (95%) rename {pkg/adapter => adapter}/config/yaml/yaml_test.go (98%) rename {pkg/adapter => adapter}/context/acceptencoder.go (96%) rename {pkg/adapter => adapter}/context/context.go (98%) rename {pkg/adapter => adapter}/context/input.go (99%) rename {pkg/adapter => adapter}/context/output.go (99%) rename {pkg/adapter => adapter}/context/renderer.go (67%) rename {pkg/adapter => adapter}/context/response.go (100%) rename {pkg/adapter => adapter}/controller.go (98%) rename {pkg/adapter => adapter}/doc.go (100%) rename {pkg/adapter => adapter}/error.go (96%) rename {pkg/adapter => adapter}/filter.go (90%) rename {pkg/adapter => adapter}/flash.go (97%) rename {pkg/adapter => adapter}/fs.go (96%) rename {pkg/adapter => adapter}/grace/grace.go (98%) rename {pkg/adapter => adapter}/grace/server.go (97%) rename {pkg/adapter => adapter}/httplib/httplib.go (99%) rename {pkg/adapter => adapter}/httplib/httplib_test.go (100%) rename {pkg/adapter => adapter}/log.go (97%) rename {pkg/adapter => adapter}/logs/accesslog.go (95%) create mode 100644 adapter/logs/alils/alils.go create mode 100644 adapter/logs/es/es.go rename {pkg/adapter => adapter}/logs/log.go (99%) rename {pkg/adapter => adapter}/logs/log_adapter.go (97%) rename {pkg/adapter => adapter}/logs/logger.go (96%) rename {pkg/adapter => adapter}/logs/logger_test.go (100%) rename {pkg/adapter => adapter}/metric/prometheus.go (87%) rename {pkg/adapter => adapter}/metric/prometheus_test.go (96%) rename {pkg/adapter => adapter}/migration/ddl.go (99%) rename {pkg/adapter => adapter}/migration/doc.go (100%) rename {pkg/adapter => adapter}/migration/migration.go (98%) rename {pkg/adapter => adapter}/namespace.go (98%) rename {pkg/adapter => adapter}/orm/cmd.go (95%) rename {pkg/adapter => adapter}/orm/db.go (94%) rename {pkg/adapter => adapter}/orm/db_alias.go (98%) rename {pkg/adapter => adapter}/orm/models.go (94%) rename {pkg/adapter => adapter}/orm/models_boot.go (96%) rename {pkg/adapter => adapter}/orm/models_fields.go (99%) rename {pkg/adapter => adapter}/orm/orm.go (98%) rename {pkg/adapter => adapter}/orm/orm_conds.go (98%) rename {pkg/adapter => adapter}/orm/orm_log.go (95%) rename {pkg/adapter => adapter}/orm/orm_queryset.go (95%) rename {pkg/adapter => adapter}/orm/qb.go (95%) rename {pkg/adapter => adapter}/orm/qb_mysql.go (99%) rename {pkg/adapter => adapter}/orm/qb_tidb.go (99%) rename {pkg/adapter => adapter}/orm/query_setter_adapter.go (95%) rename {pkg/adapter => adapter}/orm/types.go (99%) rename {pkg/adapter => adapter}/orm/utils.go (99%) rename {pkg/adapter => adapter}/orm/utils_test.go (100%) rename {pkg/adapter => adapter}/plugins/apiauth/apiauth.go (92%) rename {pkg/adapter => adapter}/plugins/apiauth/apiauth_test.go (100%) rename {pkg/adapter => adapter}/plugins/auth/basic.go (92%) rename {pkg/adapter => adapter}/plugins/authz/authz.go (91%) rename {pkg/adapter => adapter}/plugins/authz/authz_model.conf (100%) rename {pkg/adapter => adapter}/plugins/authz/authz_policy.csv (100%) rename {pkg/adapter => adapter}/plugins/authz/authz_test.go (96%) rename {pkg/adapter => adapter}/plugins/cors/cors.go (91%) rename {pkg/adapter => adapter}/policy.go (91%) rename {pkg/adapter => adapter}/router.go (98%) rename {pkg/adapter => adapter}/session/couchbase/sess_couchbase.go (97%) rename {pkg/adapter => adapter}/session/ledis/ledis_session.go (95%) rename {pkg/adapter => adapter}/session/memcache/sess_memcache.go (97%) rename {pkg/adapter => adapter}/session/mysql/sess_mysql.go (97%) rename {pkg/adapter => adapter}/session/postgres/sess_postgresql.go (97%) rename {pkg/adapter => adapter}/session/provider_adapter.go (98%) rename {pkg/adapter => adapter}/session/redis/sess_redis.go (97%) rename {pkg/adapter => adapter}/session/redis_cluster/redis_cluster.go (97%) rename {pkg/adapter => adapter}/session/redis_sentinel/sess_redis_sentinel.go (97%) rename {pkg/adapter => adapter}/session/redis_sentinel/sess_redis_sentinel_test.go (97%) rename {pkg/adapter => adapter}/session/sess_cookie.go (98%) rename {pkg/adapter => adapter}/session/sess_cookie_test.go (100%) rename {pkg/adapter => adapter}/session/sess_file.go (98%) rename {pkg/adapter => adapter}/session/sess_file_test.go (100%) rename {pkg/adapter => adapter}/session/sess_mem.go (98%) rename {pkg/adapter => adapter}/session/sess_mem_test.go (100%) rename {pkg/adapter => adapter}/session/sess_test.go (100%) rename {pkg/adapter => adapter}/session/sess_utils.go (94%) rename {pkg/adapter => adapter}/session/session.go (99%) rename {pkg/adapter => adapter}/session/ssdb/sess_ssdb.go (95%) rename {pkg/adapter => adapter}/session/store_adapter.go (97%) rename {pkg/adapter => adapter}/swagger/swagger.go (98%) rename {pkg/adapter => adapter}/template.go (98%) rename {pkg/adapter => adapter}/templatefunc.go (98%) rename {pkg/adapter => adapter}/templatefunc_test.go (100%) rename {pkg/adapter => adapter}/testing/client.go (96%) rename {pkg/adapter => adapter}/toolbox/healthcheck.go (96%) rename {pkg/adapter => adapter}/toolbox/profile.go (96%) rename {pkg/adapter => adapter}/toolbox/profile_test.go (100%) rename {pkg/adapter => adapter}/toolbox/statistics.go (97%) rename {pkg/adapter => adapter}/toolbox/statistics_test.go (100%) rename {pkg/adapter => adapter}/toolbox/task.go (99%) rename {pkg/adapter => adapter}/toolbox/task_test.go (100%) rename {pkg/adapter => adapter}/tree.go (90%) rename {pkg/adapter => adapter}/tree_test.go (99%) rename {pkg/adapter => adapter}/utils/caller.go (94%) rename {pkg/adapter => adapter}/utils/caller_test.go (100%) rename {pkg/adapter => adapter}/utils/captcha/LICENSE (100%) rename {pkg/adapter => adapter}/utils/captcha/README.md (100%) rename {pkg/adapter => adapter}/utils/captcha/captcha.go (93%) rename {pkg/adapter => adapter}/utils/captcha/image.go (95%) rename {pkg/adapter => adapter}/utils/captcha/image_test.go (96%) rename {pkg/adapter => adapter}/utils/debug.go (95%) rename {pkg/adapter => adapter}/utils/debug_test.go (100%) rename {pkg/adapter => adapter}/utils/file.go (97%) rename {pkg/adapter => adapter}/utils/mail.go (97%) rename {pkg/adapter => adapter}/utils/mail_test.go (100%) rename {pkg/adapter => adapter}/utils/pagination/controller.go (84%) rename {pkg/adapter => adapter}/utils/pagination/doc.go (100%) rename {pkg/adapter => adapter}/utils/pagination/paginator.go (98%) rename {pkg/adapter => adapter}/utils/rand.go (94%) rename {pkg/adapter => adapter}/utils/rand_test.go (100%) rename {pkg/adapter => adapter}/utils/safemap.go (97%) rename {pkg/adapter => adapter}/utils/safemap_test.go (100%) rename {pkg/adapter => adapter}/utils/slice.go (98%) rename {pkg/adapter => adapter}/utils/slice_test.go (100%) rename {pkg/adapter => adapter}/utils/utils.go (76%) rename {pkg/adapter => adapter}/validation/util.go (97%) rename {pkg/adapter => adapter}/validation/validation.go (99%) rename {pkg/adapter => adapter}/validation/validation_test.go (100%) rename {pkg/adapter => adapter}/validation/validators.go (99%) delete mode 100755 build/gobuild-sample.sh delete mode 100755 build/report_build_info.sh rename pkg/build_info.go => build_info.go (98%) rename {pkg/client => client}/cache/README.md (100%) rename {pkg/client => client}/cache/cache.go (100%) rename {pkg/client => client}/cache/cache_test.go (100%) rename {pkg/client => client}/cache/conv.go (100%) rename {pkg/client => client}/cache/conv_test.go (100%) rename {pkg/client => client}/cache/file.go (100%) rename {pkg/client => client}/cache/memcache/memcache.go (98%) rename {pkg/client => client}/cache/memcache/memcache_test.go (98%) rename {pkg/client => client}/cache/memory.go (100%) rename {pkg/client => client}/cache/redis/redis.go (99%) rename {pkg/client => client}/cache/redis/redis_test.go (98%) rename {pkg/client => client}/cache/ssdb/ssdb.go (99%) rename {pkg/client => client}/cache/ssdb/ssdb_test.go (98%) rename {pkg/client => client}/httplib/README.md (100%) rename {pkg/client => client}/httplib/filter.go (100%) rename {pkg/client => client}/httplib/filter/opentracing/filter.go (98%) rename {pkg/client => client}/httplib/filter/opentracing/filter_test.go (96%) rename {pkg/client => client}/httplib/filter/prometheus/filter.go (97%) rename {pkg/client => client}/httplib/filter/prometheus/filter_test.go (96%) rename {pkg/client => client}/httplib/httplib.go (100%) rename {pkg/client => client}/httplib/httplib_test.go (100%) rename {pkg/client => client}/httplib/testing/client.go (97%) rename {pkg/client => client}/orm/README.md (100%) rename {pkg/client => client}/orm/cmd.go (100%) rename {pkg/client => client}/orm/cmd_utils.go (100%) rename {pkg/client => client}/orm/db.go (99%) rename {pkg/client => client}/orm/db_alias.go (100%) rename {pkg/client => client}/orm/db_alias_test.go (100%) rename {pkg/client => client}/orm/db_mysql.go (100%) rename {pkg/client => client}/orm/db_oracle.go (98%) rename {pkg/client => client}/orm/db_postgres.go (100%) rename {pkg/client => client}/orm/db_sqlite.go (99%) rename {pkg/client => client}/orm/db_tables.go (100%) rename {pkg/client => client}/orm/db_tidb.go (100%) rename {pkg/client => client}/orm/db_utils.go (100%) rename {pkg/client => client}/orm/do_nothing_orm.go (99%) rename {pkg/client => client}/orm/do_nothing_orm_test.go (100%) rename {pkg/client => client}/orm/filter.go (100%) rename {pkg/client => client}/orm/filter/bean/default_value_filter.go (97%) rename {pkg/client => client}/orm/filter/bean/default_value_filter_test.go (97%) rename {pkg/client => client}/orm/filter/opentracing/filter.go (98%) rename {pkg/client => client}/orm/filter/opentracing/filter_test.go (96%) rename {pkg/client => client}/orm/filter/prometheus/filter.go (98%) rename {pkg/client => client}/orm/filter/prometheus/filter_test.go (97%) rename {pkg/client => client}/orm/filter_orm_decorator.go (99%) rename {pkg/client => client}/orm/filter_orm_decorator_test.go (99%) rename {pkg/client => client}/orm/filter_test.go (100%) rename {pkg/client => client}/orm/hints/db_hints.go (98%) rename {pkg/client => client}/orm/hints/db_hints_test.go (100%) rename {pkg/client => client}/orm/invocation.go (100%) rename {pkg/client => client}/orm/migration/ddl.go (99%) rename {pkg/client => client}/orm/migration/doc.go (100%) rename {pkg/client => client}/orm/migration/migration.go (98%) rename {pkg/client => client}/orm/model_utils_test.go (100%) rename {pkg/client => client}/orm/models.go (100%) rename {pkg/client => client}/orm/models_boot.go (100%) rename {pkg/client => client}/orm/models_fields.go (100%) rename {pkg/client => client}/orm/models_info_f.go (100%) rename {pkg/client => client}/orm/models_info_m.go (100%) rename {pkg/client => client}/orm/models_test.go (97%) rename {pkg/client => client}/orm/models_utils.go (100%) rename {pkg/client => client}/orm/models_utils_test.go (100%) rename {pkg/client => client}/orm/orm.go (98%) rename {pkg/client => client}/orm/orm_conds.go (100%) rename {pkg/client => client}/orm/orm_log.go (100%) rename {pkg/client => client}/orm/orm_object.go (100%) rename {pkg/client => client}/orm/orm_querym2m.go (100%) rename {pkg/client => client}/orm/orm_queryset.go (99%) rename {pkg/client => client}/orm/orm_raw.go (100%) rename {pkg/client => client}/orm/orm_test.go (99%) rename {pkg/client => client}/orm/qb.go (100%) rename {pkg/client => client}/orm/qb_mysql.go (100%) rename {pkg/client => client}/orm/qb_postgres.go (100%) rename {pkg/client => client}/orm/qb_tidb.go (100%) rename {pkg/client => client}/orm/types.go (99%) rename {pkg/client => client}/orm/utils.go (100%) rename {pkg/client => client}/orm/utils_test.go (100%) rename {pkg/core => core}/bean/context.go (100%) rename {pkg/core => core}/bean/doc.go (100%) rename {pkg/core => core}/bean/factory.go (100%) rename {pkg/core => core}/bean/metadata.go (100%) rename {pkg/core => core}/bean/tag_auto_wire_bean_factory.go (99%) rename {pkg/core => core}/bean/tag_auto_wire_bean_factory_test.go (100%) rename {pkg/core => core}/bean/time_type_adapter.go (100%) rename {pkg/core => core}/bean/time_type_adapter_test.go (100%) rename {pkg/core => core}/bean/type_adapter.go (100%) rename {pkg/core => core}/config/base_config_test.go (100%) rename {pkg/core => core}/config/config.go (100%) rename {pkg/core => core}/config/config_test.go (100%) rename {pkg/core => core}/config/env/env.go (98%) rename {pkg/core => core}/config/env/env_test.go (100%) rename {pkg/core => core}/config/etcd/config.go (98%) rename {pkg/core => core}/config/etcd/config_test.go (100%) rename {pkg/core => core}/config/fake.go (100%) rename {pkg/core => core}/config/ini.go (100%) rename {pkg/core => core}/config/ini_test.go (100%) rename {pkg/core => core}/config/json/json.go (98%) rename {pkg/core => core}/config/json/json_test.go (99%) rename {pkg/core => core}/config/xml/xml.go (98%) rename {pkg/core => core}/config/xml/xml_test.go (98%) rename {pkg/core => core}/config/yaml/yaml.go (99%) rename {pkg/core => core}/config/yaml/yaml_test.go (98%) rename {pkg/core => core}/governor/command.go (100%) rename {pkg/core => core}/governor/healthcheck.go (100%) rename {pkg/core => core}/governor/profile.go (99%) rename {pkg/core => core}/governor/profile_test.go (100%) rename {pkg/core => core}/logs/README.md (100%) rename {pkg/core => core}/logs/access_log.go (100%) rename {pkg/core => core}/logs/access_log_test.go (100%) rename {pkg/core => core}/logs/alils/alils.go (98%) rename {pkg/core => core}/logs/alils/config.go (100%) rename {pkg/core => core}/logs/alils/log.pb.go (100%) rename {pkg/core => core}/logs/alils/log_config.go (100%) rename {pkg/core => core}/logs/alils/log_project.go (100%) rename {pkg/core => core}/logs/alils/log_store.go (100%) rename {pkg/core => core}/logs/alils/machine_group.go (100%) rename {pkg/core => core}/logs/alils/request.go (100%) rename {pkg/core => core}/logs/alils/signature.go (100%) rename {pkg/core => core}/logs/conn.go (100%) rename {pkg/core => core}/logs/conn_test.go (100%) rename {pkg/core => core}/logs/console.go (100%) rename {pkg/core => core}/logs/console_test.go (100%) rename {pkg/core => core}/logs/es/es.go (98%) rename {pkg/core => core}/logs/es/index.go (96%) rename {pkg/core => core}/logs/es/index_test.go (95%) rename {pkg/core => core}/logs/file.go (100%) rename {pkg/core => core}/logs/file_test.go (100%) rename {pkg/core => core}/logs/formatter.go (100%) rename {pkg/core => core}/logs/formatter_test.go (100%) rename {pkg/core => core}/logs/jianliao.go (100%) rename {pkg/core => core}/logs/jianliao_test.go (100%) rename {pkg/core => core}/logs/log.go (100%) rename {pkg/core => core}/logs/log_msg.go (100%) rename {pkg/core => core}/logs/log_msg_test.go (100%) rename {pkg/core => core}/logs/log_test.go (100%) rename {pkg/core => core}/logs/logger.go (100%) rename {pkg/core => core}/logs/logger_test.go (100%) rename {pkg/core => core}/logs/multifile.go (100%) rename {pkg/core => core}/logs/multifile_test.go (100%) rename {pkg/core => core}/logs/slack.go (100%) rename {pkg/core => core}/logs/smtp.go (100%) rename {pkg/core => core}/logs/smtp_test.go (100%) rename {pkg/core => core}/utils/caller.go (100%) rename {pkg/core => core}/utils/caller_test.go (100%) rename {pkg/core => core}/utils/debug.go (100%) rename {pkg/core => core}/utils/debug_test.go (100%) rename {pkg/core => core}/utils/file.go (100%) rename {pkg/core => core}/utils/file_test.go (100%) rename {pkg/core => core}/utils/kv.go (100%) rename {pkg/core => core}/utils/kv_test.go (100%) rename {pkg/core => core}/utils/mail.go (100%) rename {pkg/core => core}/utils/mail_test.go (100%) rename {pkg/core => core}/utils/pagination/doc.go (96%) rename {pkg/core => core}/utils/pagination/paginator.go (100%) rename {pkg/core => core}/utils/pagination/utils.go (100%) rename {pkg/core => core}/utils/rand.go (100%) rename {pkg/core => core}/utils/rand_test.go (100%) rename {pkg/core => core}/utils/safemap.go (100%) rename {pkg/core => core}/utils/safemap_test.go (100%) rename {pkg/core => core}/utils/slice.go (100%) rename {pkg/core => core}/utils/slice_test.go (100%) rename {pkg/core => core}/utils/testdata/grepe.test (100%) rename {pkg/core => core}/utils/time.go (100%) rename {pkg/core => core}/utils/utils.go (100%) rename {pkg/core => core}/utils/utils_test.go (100%) rename {pkg/core => core}/validation/README.md (100%) rename {pkg/core => core}/validation/util.go (100%) rename {pkg/core => core}/validation/util_test.go (100%) rename {pkg/core => core}/validation/validation.go (100%) rename {pkg/core => core}/validation/validation_test.go (100%) rename {pkg/core => core}/validation/validators.go (99%) rename pkg/doc.go => doc.go (97%) delete mode 100755 githook/pre-commit delete mode 100644 pkg/adapter/logs/alils/alils.go delete mode 100644 pkg/adapter/logs/es/es.go rename {pkg/server => server}/web/LICENSE (100%) rename {pkg/server => server}/web/admin.go (98%) rename {pkg/server => server}/web/admin_controller.go (99%) rename {pkg/server => server}/web/admin_test.go (99%) rename {pkg/server => server}/web/adminui.go (100%) rename {pkg/server => server}/web/beego.go (100%) rename {pkg/server => server}/web/captcha/LICENSE (100%) rename {pkg/server => server}/web/captcha/README.md (100%) rename {pkg/server => server}/web/captcha/captcha.go (97%) rename {pkg/server => server}/web/captcha/image.go (100%) rename {pkg/server => server}/web/captcha/image_test.go (96%) rename {pkg/server => server}/web/captcha/siprng.go (100%) rename {pkg/server => server}/web/captcha/siprng_test.go (100%) rename {pkg/server => server}/web/config.go (97%) rename {pkg/server => server}/web/config_test.go (98%) rename {pkg/server => server}/web/context/acceptencoder.go (100%) rename {pkg/server => server}/web/context/acceptencoder_test.go (100%) rename {pkg/server => server}/web/context/context.go (99%) rename {pkg/server => server}/web/context/context_test.go (100%) rename {pkg/server => server}/web/context/input.go (99%) rename {pkg/server => server}/web/context/input_test.go (100%) rename {pkg/server => server}/web/context/output.go (100%) rename {pkg/server => server}/web/context/param/conv.go (95%) rename {pkg/server => server}/web/context/param/methodparams.go (100%) rename {pkg/server => server}/web/context/param/options.go (100%) rename {pkg/server => server}/web/context/param/parsers.go (100%) rename {pkg/server => server}/web/context/param/parsers_test.go (100%) rename {pkg/server => server}/web/context/renderer.go (100%) rename {pkg/server => server}/web/context/response.go (99%) rename {pkg/server => server}/web/controller.go (99%) rename {pkg/server => server}/web/controller_test.go (98%) rename {pkg/server => server}/web/doc.go (91%) rename {pkg/server => server}/web/error.go (98%) rename {pkg/server => server}/web/error_test.go (100%) rename {pkg/server => server}/web/filter.go (98%) rename {pkg/server => server}/web/filter/apiauth/apiauth.go (97%) rename {pkg/server => server}/web/filter/apiauth/apiauth_test.go (100%) rename {pkg/server => server}/web/filter/auth/basic.go (97%) rename {pkg/server => server}/web/filter/authz/authz.go (96%) rename {pkg/server => server}/web/filter/authz/authz_model.conf (100%) rename {pkg/server => server}/web/filter/authz/authz_policy.csv (100%) rename {pkg/server => server}/web/filter/authz/authz_test.go (96%) rename {pkg/server => server}/web/filter/cors/cors.go (98%) rename {pkg/server => server}/web/filter/cors/cors_test.go (98%) rename {pkg/server => server}/web/filter/opentracing/filter.go (96%) rename {pkg/server => server}/web/filter/opentracing/filter_test.go (96%) rename {pkg/server => server}/web/filter/prometheus/filter.go (84%) rename {pkg/server => server}/web/filter/prometheus/filter_test.go (95%) rename {pkg/server => server}/web/filter_chain_test.go (95%) rename {pkg/server => server}/web/filter_test.go (97%) rename {pkg/server => server}/web/flash.go (100%) rename {pkg/server => server}/web/flash_test.go (100%) rename {pkg/server => server}/web/fs.go (100%) rename {pkg/server => server}/web/grace/grace.go (100%) rename {pkg/server => server}/web/grace/server.go (100%) rename {pkg/server => server}/web/hooks.go (95%) rename {pkg/server => server}/web/mime.go (100%) rename {pkg/server => server}/web/namespace.go (91%) rename {pkg/server => server}/web/namespace_test.go (98%) rename {pkg/server => server}/web/pagination/controller.go (90%) rename {pkg/server => server}/web/parser.go (98%) rename {pkg/server => server}/web/policy.go (98%) rename {pkg/server => server}/web/router.go (99%) rename {pkg/server => server}/web/router_test.go (99%) rename {pkg/server => server}/web/server.go (99%) rename {pkg/server => server}/web/server_test.go (100%) rename {pkg/server => server}/web/session/README.md (100%) rename {pkg/server => server}/web/session/couchbase/sess_couchbase.go (99%) rename {pkg/server => server}/web/session/ledis/ledis_session.go (98%) rename {pkg/server => server}/web/session/memcache/sess_memcache.go (99%) rename {pkg/server => server}/web/session/mysql/sess_mysql.go (99%) rename {pkg/server => server}/web/session/postgres/sess_postgresql.go (99%) rename {pkg/server => server}/web/session/redis/sess_redis.go (99%) rename {pkg/server => server}/web/session/redis/sess_redis_test.go (97%) rename {pkg/server => server}/web/session/redis_cluster/redis_cluster.go (99%) rename {pkg/server => server}/web/session/redis_sentinel/sess_redis_sentinel.go (99%) rename {pkg/server => server}/web/session/redis_sentinel/sess_redis_sentinel_test.go (97%) rename {pkg/server => server}/web/session/sess_cookie.go (100%) rename {pkg/server => server}/web/session/sess_cookie_test.go (100%) rename {pkg/server => server}/web/session/sess_file.go (100%) rename {pkg/server => server}/web/session/sess_file_test.go (100%) rename {pkg/server => server}/web/session/sess_mem.go (100%) rename {pkg/server => server}/web/session/sess_mem_test.go (100%) rename {pkg/server => server}/web/session/sess_test.go (100%) rename {pkg/server => server}/web/session/sess_utils.go (99%) rename {pkg/server => server}/web/session/session.go (100%) rename {pkg/server => server}/web/session/ssdb/sess_ssdb.go (98%) rename {pkg/server => server}/web/staticfile.go (98%) rename {pkg/server => server}/web/staticfile_test.go (100%) rename {pkg/server => server}/web/statistics.go (99%) rename {pkg/server => server}/web/statistics_test.go (100%) rename {pkg/server => server}/web/swagger/swagger.go (100%) rename {pkg/server => server}/web/template.go (99%) rename {pkg/server => server}/web/template_test.go (100%) rename {pkg/server => server}/web/templatefunc.go (100%) rename {pkg/server => server}/web/templatefunc_test.go (100%) rename {pkg/server => server}/web/tree.go (99%) rename {pkg/server => server}/web/tree_test.go (99%) rename {pkg/server => server}/web/unregroute_test.go (100%) rename {pkg/task => task}/govenor_command.go (97%) rename {pkg/task => task}/governor_command_test.go (100%) rename {pkg/task => task}/task.go (100%) rename {pkg/task => task}/task_test.go (100%) diff --git a/.travis.yml b/.travis.yml index 67efe057..973b40ef 100644 --- a/.travis.yml +++ b/.travis.yml @@ -96,7 +96,7 @@ after_script: - rm -rf ./res/var/* script: - go test ./... - - staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" ./pkg + - staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" ./ - unconvert $(go list ./... | grep -v /vendor/) - ineffassign . - find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s diff --git a/pkg/adapter/admin.go b/adapter/admin.go similarity index 93% rename from pkg/adapter/admin.go rename to adapter/admin.go index 5ba78511..e555f59e 100644 --- a/pkg/adapter/admin.go +++ b/adapter/admin.go @@ -17,8 +17,8 @@ package adapter import ( "time" - _ "github.com/astaxie/beego/pkg/core/governor" - "github.com/astaxie/beego/pkg/server/web" + _ "github.com/astaxie/beego/core/governor" + "github.com/astaxie/beego/server/web" ) // FilterMonitorFunc is default monitor filter when admin module is enable. diff --git a/pkg/adapter/app.go b/adapter/app.go similarity index 98% rename from pkg/adapter/app.go rename to adapter/app.go index 10ffa96a..e20cd9d2 100644 --- a/pkg/adapter/app.go +++ b/adapter/app.go @@ -17,9 +17,9 @@ package adapter import ( "net/http" - context2 "github.com/astaxie/beego/pkg/adapter/context" - "github.com/astaxie/beego/pkg/server/web" - "github.com/astaxie/beego/pkg/server/web/context" + context2 "github.com/astaxie/beego/adapter/context" + "github.com/astaxie/beego/server/web" + "github.com/astaxie/beego/server/web/context" ) var ( diff --git a/pkg/adapter/beego.go b/adapter/beego.go similarity index 95% rename from pkg/adapter/beego.go rename to adapter/beego.go index eb7be3f6..bbe37db8 100644 --- a/pkg/adapter/beego.go +++ b/adapter/beego.go @@ -15,14 +15,14 @@ package adapter import ( - "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego" + "github.com/astaxie/beego/server/web" ) const ( // VERSION represent beego web framework version. - VERSION = pkg.VERSION + VERSION = beego.VERSION // DEV is for develop DEV = web.DEV diff --git a/pkg/adapter/build_info.go b/adapter/build_info.go similarity index 100% rename from pkg/adapter/build_info.go rename to adapter/build_info.go diff --git a/pkg/adapter/cache/cache.go b/adapter/cache/cache.go similarity index 100% rename from pkg/adapter/cache/cache.go rename to adapter/cache/cache.go diff --git a/pkg/adapter/cache/cache_adapter.go b/adapter/cache/cache_adapter.go similarity index 98% rename from pkg/adapter/cache/cache_adapter.go rename to adapter/cache/cache_adapter.go index f1441ac8..3bfd0bf8 100644 --- a/pkg/adapter/cache/cache_adapter.go +++ b/adapter/cache/cache_adapter.go @@ -18,7 +18,7 @@ import ( "context" "time" - "github.com/astaxie/beego/pkg/client/cache" + "github.com/astaxie/beego/client/cache" ) type newToOldCacheAdapter struct { diff --git a/pkg/adapter/cache/cache_test.go b/adapter/cache/cache_test.go similarity index 100% rename from pkg/adapter/cache/cache_test.go rename to adapter/cache/cache_test.go diff --git a/pkg/adapter/cache/conv.go b/adapter/cache/conv.go similarity index 96% rename from pkg/adapter/cache/conv.go rename to adapter/cache/conv.go index d46cc31c..18b8a255 100644 --- a/pkg/adapter/cache/conv.go +++ b/adapter/cache/conv.go @@ -15,7 +15,7 @@ package cache import ( - "github.com/astaxie/beego/pkg/client/cache" + "github.com/astaxie/beego/client/cache" ) // GetString convert interface to string. diff --git a/pkg/adapter/cache/conv_test.go b/adapter/cache/conv_test.go similarity index 100% rename from pkg/adapter/cache/conv_test.go rename to adapter/cache/conv_test.go diff --git a/pkg/adapter/cache/file.go b/adapter/cache/file.go similarity index 95% rename from pkg/adapter/cache/file.go rename to adapter/cache/file.go index 04598d27..74eb980a 100644 --- a/pkg/adapter/cache/file.go +++ b/adapter/cache/file.go @@ -15,7 +15,7 @@ package cache import ( - "github.com/astaxie/beego/pkg/client/cache" + "github.com/astaxie/beego/client/cache" ) // NewFileCache Create new file cache with no config. diff --git a/pkg/adapter/cache/memcache/memcache.go b/adapter/cache/memcache/memcache.go similarity index 92% rename from pkg/adapter/cache/memcache/memcache.go rename to adapter/cache/memcache/memcache.go index f2acffca..b4da1bfe 100644 --- a/pkg/adapter/cache/memcache/memcache.go +++ b/adapter/cache/memcache/memcache.go @@ -30,8 +30,8 @@ package memcache import ( - "github.com/astaxie/beego/pkg/adapter/cache" - "github.com/astaxie/beego/pkg/client/cache/memcache" + "github.com/astaxie/beego/adapter/cache" + "github.com/astaxie/beego/client/cache/memcache" ) // NewMemCache create new memcache adapter. diff --git a/pkg/adapter/cache/memcache/memcache_test.go b/adapter/cache/memcache/memcache_test.go similarity index 98% rename from pkg/adapter/cache/memcache/memcache_test.go rename to adapter/cache/memcache/memcache_test.go index e6e605a4..b9b6dc6b 100644 --- a/pkg/adapter/cache/memcache/memcache_test.go +++ b/adapter/cache/memcache/memcache_test.go @@ -21,7 +21,7 @@ import ( "testing" "time" - "github.com/astaxie/beego/pkg/adapter/cache" + "github.com/astaxie/beego/adapter/cache" ) func TestMemcacheCache(t *testing.T) { diff --git a/pkg/adapter/cache/memory.go b/adapter/cache/memory.go similarity index 94% rename from pkg/adapter/cache/memory.go rename to adapter/cache/memory.go index 2d734bc0..cf6e3992 100644 --- a/pkg/adapter/cache/memory.go +++ b/adapter/cache/memory.go @@ -15,7 +15,7 @@ package cache import ( - "github.com/astaxie/beego/pkg/client/cache" + "github.com/astaxie/beego/client/cache" ) // NewMemoryCache returns a new MemoryCache. diff --git a/pkg/adapter/cache/redis/redis.go b/adapter/cache/redis/redis.go similarity index 92% rename from pkg/adapter/cache/redis/redis.go rename to adapter/cache/redis/redis.go index 3aeb8691..3562057d 100644 --- a/pkg/adapter/cache/redis/redis.go +++ b/adapter/cache/redis/redis.go @@ -30,8 +30,8 @@ package redis import ( - "github.com/astaxie/beego/pkg/adapter/cache" - redis2 "github.com/astaxie/beego/pkg/client/cache/redis" + "github.com/astaxie/beego/adapter/cache" + redis2 "github.com/astaxie/beego/client/cache/redis" ) var ( diff --git a/pkg/adapter/cache/redis/redis_test.go b/adapter/cache/redis/redis_test.go similarity index 98% rename from pkg/adapter/cache/redis/redis_test.go rename to adapter/cache/redis/redis_test.go index 165ad0a7..7ae12197 100644 --- a/pkg/adapter/cache/redis/redis_test.go +++ b/adapter/cache/redis/redis_test.go @@ -22,7 +22,7 @@ import ( "github.com/gomodule/redigo/redis" - "github.com/astaxie/beego/pkg/adapter/cache" + "github.com/astaxie/beego/adapter/cache" ) func TestRedisCache(t *testing.T) { diff --git a/pkg/adapter/cache/ssdb/ssdb.go b/adapter/cache/ssdb/ssdb.go similarity index 68% rename from pkg/adapter/cache/ssdb/ssdb.go rename to adapter/cache/ssdb/ssdb.go index 9a252b55..df552043 100644 --- a/pkg/adapter/cache/ssdb/ssdb.go +++ b/adapter/cache/ssdb/ssdb.go @@ -1,8 +1,8 @@ package ssdb import ( - "github.com/astaxie/beego/pkg/adapter/cache" - ssdb2 "github.com/astaxie/beego/pkg/client/cache/ssdb" + "github.com/astaxie/beego/adapter/cache" + ssdb2 "github.com/astaxie/beego/client/cache/ssdb" ) // NewSsdbCache create new ssdb adapter. diff --git a/pkg/adapter/cache/ssdb/ssdb_test.go b/adapter/cache/ssdb/ssdb_test.go similarity index 98% rename from pkg/adapter/cache/ssdb/ssdb_test.go rename to adapter/cache/ssdb/ssdb_test.go index 0f9dabba..080167cd 100644 --- a/pkg/adapter/cache/ssdb/ssdb_test.go +++ b/adapter/cache/ssdb/ssdb_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/astaxie/beego/pkg/adapter/cache" + "github.com/astaxie/beego/adapter/cache" ) func TestSsdbcacheCache(t *testing.T) { diff --git a/pkg/adapter/config.go b/adapter/config.go similarity index 97% rename from pkg/adapter/config.go rename to adapter/config.go index 3975f5eb..46f965ee 100644 --- a/pkg/adapter/config.go +++ b/adapter/config.go @@ -17,9 +17,9 @@ package adapter import ( context2 "context" - "github.com/astaxie/beego/pkg/adapter/session" - newCfg "github.com/astaxie/beego/pkg/core/config" - "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/adapter/session" + newCfg "github.com/astaxie/beego/core/config" + "github.com/astaxie/beego/server/web" ) // Config is the main struct for BConfig diff --git a/pkg/adapter/config/adapter.go b/adapter/config/adapter.go similarity index 99% rename from pkg/adapter/config/adapter.go rename to adapter/config/adapter.go index 8506228f..6dc538ea 100644 --- a/pkg/adapter/config/adapter.go +++ b/adapter/config/adapter.go @@ -19,7 +19,7 @@ import ( "github.com/pkg/errors" - "github.com/astaxie/beego/pkg/core/config" + "github.com/astaxie/beego/core/config" ) type newToOldConfigerAdapter struct { diff --git a/pkg/adapter/config/config.go b/adapter/config/config.go similarity index 99% rename from pkg/adapter/config/config.go rename to adapter/config/config.go index 821379f4..703555cd 100644 --- a/pkg/adapter/config/config.go +++ b/adapter/config/config.go @@ -41,7 +41,7 @@ package config import ( - "github.com/astaxie/beego/pkg/core/config" + "github.com/astaxie/beego/core/config" ) // Configer defines how to get and set value from configuration raw data. diff --git a/pkg/adapter/config/config_test.go b/adapter/config/config_test.go similarity index 100% rename from pkg/adapter/config/config_test.go rename to adapter/config/config_test.go diff --git a/pkg/adapter/config/env/env.go b/adapter/config/env/env.go similarity index 97% rename from pkg/adapter/config/env/env.go rename to adapter/config/env/env.go index bac80576..839c60c1 100644 --- a/pkg/adapter/config/env/env.go +++ b/adapter/config/env/env.go @@ -17,7 +17,7 @@ package env import ( - "github.com/astaxie/beego/pkg/core/config/env" + "github.com/astaxie/beego/core/config/env" ) // Get returns a value by key. diff --git a/pkg/adapter/config/env/env_test.go b/adapter/config/env/env_test.go similarity index 100% rename from pkg/adapter/config/env/env_test.go rename to adapter/config/env/env_test.go diff --git a/pkg/adapter/config/fake.go b/adapter/config/fake.go similarity index 94% rename from pkg/adapter/config/fake.go rename to adapter/config/fake.go index acbd52e5..050f0252 100644 --- a/pkg/adapter/config/fake.go +++ b/adapter/config/fake.go @@ -15,7 +15,7 @@ package config import ( - "github.com/astaxie/beego/pkg/core/config" + "github.com/astaxie/beego/core/config" ) // NewFakeConfig return a fake Configer diff --git a/pkg/adapter/config/ini_test.go b/adapter/config/ini_test.go similarity index 100% rename from pkg/adapter/config/ini_test.go rename to adapter/config/ini_test.go diff --git a/pkg/adapter/config/json.go b/adapter/config/json.go similarity index 92% rename from pkg/adapter/config/json.go rename to adapter/config/json.go index 69c87568..d77e6146 100644 --- a/pkg/adapter/config/json.go +++ b/adapter/config/json.go @@ -15,5 +15,5 @@ package config import ( - _ "github.com/astaxie/beego/pkg/core/config/json" + _ "github.com/astaxie/beego/core/config/json" ) diff --git a/pkg/adapter/config/json_test.go b/adapter/config/json_test.go similarity index 100% rename from pkg/adapter/config/json_test.go rename to adapter/config/json_test.go diff --git a/pkg/adapter/config/xml/xml.go b/adapter/config/xml/xml.go similarity index 95% rename from pkg/adapter/config/xml/xml.go rename to adapter/config/xml/xml.go index 2744e335..28d5f44e 100644 --- a/pkg/adapter/config/xml/xml.go +++ b/adapter/config/xml/xml.go @@ -30,5 +30,5 @@ package xml import ( - _ "github.com/astaxie/beego/pkg/core/config/xml" + _ "github.com/astaxie/beego/core/config/xml" ) diff --git a/pkg/adapter/config/xml/xml_test.go b/adapter/config/xml/xml_test.go similarity index 98% rename from pkg/adapter/config/xml/xml_test.go rename to adapter/config/xml/xml_test.go index 122c5027..ae9b209e 100644 --- a/pkg/adapter/config/xml/xml_test.go +++ b/adapter/config/xml/xml_test.go @@ -19,7 +19,7 @@ import ( "os" "testing" - "github.com/astaxie/beego/pkg/adapter/config" + "github.com/astaxie/beego/adapter/config" ) func TestXML(t *testing.T) { diff --git a/pkg/adapter/config/yaml/yaml.go b/adapter/config/yaml/yaml.go similarity index 95% rename from pkg/adapter/config/yaml/yaml.go rename to adapter/config/yaml/yaml.go index c5325ccd..196c9725 100644 --- a/pkg/adapter/config/yaml/yaml.go +++ b/adapter/config/yaml/yaml.go @@ -30,5 +30,5 @@ package yaml import ( - _ "github.com/astaxie/beego/pkg/core/config/yaml" + _ "github.com/astaxie/beego/core/config/yaml" ) diff --git a/pkg/adapter/config/yaml/yaml_test.go b/adapter/config/yaml/yaml_test.go similarity index 98% rename from pkg/adapter/config/yaml/yaml_test.go rename to adapter/config/yaml/yaml_test.go index e4e309a2..a72e435e 100644 --- a/pkg/adapter/config/yaml/yaml_test.go +++ b/adapter/config/yaml/yaml_test.go @@ -19,7 +19,7 @@ import ( "os" "testing" - "github.com/astaxie/beego/pkg/adapter/config" + "github.com/astaxie/beego/adapter/config" ) func TestYaml(t *testing.T) { diff --git a/pkg/adapter/context/acceptencoder.go b/adapter/context/acceptencoder.go similarity index 96% rename from pkg/adapter/context/acceptencoder.go rename to adapter/context/acceptencoder.go index e578de45..4bfef95e 100644 --- a/pkg/adapter/context/acceptencoder.go +++ b/adapter/context/acceptencoder.go @@ -19,7 +19,7 @@ import ( "net/http" "os" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) // InitGzip init the gzipcompress diff --git a/pkg/adapter/context/context.go b/adapter/context/context.go similarity index 98% rename from pkg/adapter/context/context.go rename to adapter/context/context.go index f9d8c624..123fdb2c 100644 --- a/pkg/adapter/context/context.go +++ b/adapter/context/context.go @@ -27,7 +27,7 @@ import ( "net" "net/http" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) // commonly used mime-types diff --git a/pkg/adapter/context/input.go b/adapter/context/input.go similarity index 99% rename from pkg/adapter/context/input.go rename to adapter/context/input.go index a1d08855..4d62d3c1 100644 --- a/pkg/adapter/context/input.go +++ b/adapter/context/input.go @@ -15,7 +15,7 @@ package context import ( - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) // BeegoInput operates the http request header, data, cookie and body. diff --git a/pkg/adapter/context/output.go b/adapter/context/output.go similarity index 99% rename from pkg/adapter/context/output.go rename to adapter/context/output.go index 8e2a7f7d..0223679b 100644 --- a/pkg/adapter/context/output.go +++ b/adapter/context/output.go @@ -15,7 +15,7 @@ package context import ( - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) // BeegoOutput does work for sending response header. diff --git a/pkg/adapter/context/renderer.go b/adapter/context/renderer.go similarity index 67% rename from pkg/adapter/context/renderer.go rename to adapter/context/renderer.go index 763fb9c4..1309365a 100644 --- a/pkg/adapter/context/renderer.go +++ b/adapter/context/renderer.go @@ -1,7 +1,7 @@ package context import ( - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) // Renderer defines an http response renderer diff --git a/pkg/adapter/context/response.go b/adapter/context/response.go similarity index 100% rename from pkg/adapter/context/response.go rename to adapter/context/response.go diff --git a/pkg/adapter/controller.go b/adapter/controller.go similarity index 98% rename from pkg/adapter/controller.go rename to adapter/controller.go index c0616962..14dc9b97 100644 --- a/pkg/adapter/controller.go +++ b/adapter/controller.go @@ -18,10 +18,10 @@ import ( "mime/multipart" "net/url" - "github.com/astaxie/beego/pkg/adapter/session" - webContext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/adapter/session" + webContext "github.com/astaxie/beego/server/web/context" - "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/server/web" ) var ( diff --git a/pkg/adapter/doc.go b/adapter/doc.go similarity index 100% rename from pkg/adapter/doc.go rename to adapter/doc.go diff --git a/pkg/adapter/error.go b/adapter/error.go similarity index 96% rename from pkg/adapter/error.go rename to adapter/error.go index 4f08aa8c..35ff7f35 100644 --- a/pkg/adapter/error.go +++ b/adapter/error.go @@ -17,10 +17,10 @@ package adapter import ( "net/http" - "github.com/astaxie/beego/pkg/adapter/context" - beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/adapter/context" + beecontext "github.com/astaxie/beego/server/web/context" - "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/server/web" ) const ( diff --git a/pkg/adapter/filter.go b/adapter/filter.go similarity index 90% rename from pkg/adapter/filter.go rename to adapter/filter.go index cafed773..283d8879 100644 --- a/pkg/adapter/filter.go +++ b/adapter/filter.go @@ -15,9 +15,9 @@ package adapter import ( - "github.com/astaxie/beego/pkg/adapter/context" - "github.com/astaxie/beego/pkg/server/web" - beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/adapter/context" + "github.com/astaxie/beego/server/web" + beecontext "github.com/astaxie/beego/server/web/context" ) // FilterFunc defines a filter function which is invoked before the controller handler is executed. diff --git a/pkg/adapter/flash.go b/adapter/flash.go similarity index 97% rename from pkg/adapter/flash.go rename to adapter/flash.go index 02e75ed6..2b47ee62 100644 --- a/pkg/adapter/flash.go +++ b/adapter/flash.go @@ -15,7 +15,7 @@ package adapter import ( - "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/server/web" ) // FlashData is a tools to maintain data when using across request. diff --git a/pkg/adapter/fs.go b/adapter/fs.go similarity index 96% rename from pkg/adapter/fs.go rename to adapter/fs.go index 07054ca3..e48e75b5 100644 --- a/pkg/adapter/fs.go +++ b/adapter/fs.go @@ -18,7 +18,7 @@ import ( "net/http" "path/filepath" - "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/server/web" ) type FileSystem web.FileSystem diff --git a/pkg/adapter/grace/grace.go b/adapter/grace/grace.go similarity index 98% rename from pkg/adapter/grace/grace.go rename to adapter/grace/grace.go index 3775e395..75ceef21 100644 --- a/pkg/adapter/grace/grace.go +++ b/adapter/grace/grace.go @@ -46,7 +46,7 @@ import ( "net/http" "time" - "github.com/astaxie/beego/pkg/server/web/grace" + "github.com/astaxie/beego/server/web/grace" ) const ( diff --git a/pkg/adapter/grace/server.go b/adapter/grace/server.go similarity index 97% rename from pkg/adapter/grace/server.go rename to adapter/grace/server.go index 31c13f18..0dfb2fd6 100644 --- a/pkg/adapter/grace/server.go +++ b/adapter/grace/server.go @@ -3,7 +3,7 @@ package grace import ( "os" - "github.com/astaxie/beego/pkg/server/web/grace" + "github.com/astaxie/beego/server/web/grace" ) // Server embedded http.Server diff --git a/pkg/adapter/httplib/httplib.go b/adapter/httplib/httplib.go similarity index 99% rename from pkg/adapter/httplib/httplib.go rename to adapter/httplib/httplib.go index d2ef36c1..d9ff1ea5 100644 --- a/pkg/adapter/httplib/httplib.go +++ b/adapter/httplib/httplib.go @@ -38,7 +38,7 @@ import ( "net/url" "time" - "github.com/astaxie/beego/pkg/client/httplib" + "github.com/astaxie/beego/client/httplib" ) // SetDefaultSetting Overwrite default settings diff --git a/pkg/adapter/httplib/httplib_test.go b/adapter/httplib/httplib_test.go similarity index 100% rename from pkg/adapter/httplib/httplib_test.go rename to adapter/httplib/httplib_test.go diff --git a/pkg/adapter/log.go b/adapter/log.go similarity index 97% rename from pkg/adapter/log.go rename to adapter/log.go index 0d7d94c0..9d07ec1a 100644 --- a/pkg/adapter/log.go +++ b/adapter/log.go @@ -17,9 +17,9 @@ package adapter import ( "strings" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" - webLog "github.com/astaxie/beego/pkg/core/logs" + webLog "github.com/astaxie/beego/core/logs" ) // Log levels to control the logging output. diff --git a/pkg/adapter/logs/accesslog.go b/adapter/logs/accesslog.go similarity index 95% rename from pkg/adapter/logs/accesslog.go rename to adapter/logs/accesslog.go index cebee92b..a2150884 100644 --- a/pkg/adapter/logs/accesslog.go +++ b/adapter/logs/accesslog.go @@ -15,7 +15,7 @@ package logs import ( - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" ) // AccessLogRecord struct for holding access log data. diff --git a/adapter/logs/alils/alils.go b/adapter/logs/alils/alils.go new file mode 100644 index 00000000..941cba4c --- /dev/null +++ b/adapter/logs/alils/alils.go @@ -0,0 +1,5 @@ +package alils + +import ( + _ "github.com/astaxie/beego/core/logs/alils" +) diff --git a/adapter/logs/es/es.go b/adapter/logs/es/es.go new file mode 100644 index 00000000..0f0fd607 --- /dev/null +++ b/adapter/logs/es/es.go @@ -0,0 +1,5 @@ +package es + +import ( + _ "github.com/astaxie/beego/core/logs/es" +) diff --git a/pkg/adapter/logs/log.go b/adapter/logs/log.go similarity index 99% rename from pkg/adapter/logs/log.go rename to adapter/logs/log.go index 185b2a47..54eb24d5 100644 --- a/pkg/adapter/logs/log.go +++ b/adapter/logs/log.go @@ -37,7 +37,7 @@ import ( "log" "time" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" ) // RFC5424 log message levels. diff --git a/pkg/adapter/logs/log_adapter.go b/adapter/logs/log_adapter.go similarity index 97% rename from pkg/adapter/logs/log_adapter.go rename to adapter/logs/log_adapter.go index ee517bf0..6b7022d6 100644 --- a/pkg/adapter/logs/log_adapter.go +++ b/adapter/logs/log_adapter.go @@ -17,7 +17,7 @@ package logs import ( "time" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" ) type oldToNewAdapter struct { diff --git a/pkg/adapter/logs/logger.go b/adapter/logs/logger.go similarity index 96% rename from pkg/adapter/logs/logger.go rename to adapter/logs/logger.go index 419ac9c4..5a8e0a1c 100644 --- a/pkg/adapter/logs/logger.go +++ b/adapter/logs/logger.go @@ -15,7 +15,7 @@ package logs import ( - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" ) // ColorByStatus return color by http code diff --git a/pkg/adapter/logs/logger_test.go b/adapter/logs/logger_test.go similarity index 100% rename from pkg/adapter/logs/logger_test.go rename to adapter/logs/logger_test.go diff --git a/pkg/adapter/metric/prometheus.go b/adapter/metric/prometheus.go similarity index 87% rename from pkg/adapter/metric/prometheus.go rename to adapter/metric/prometheus.go index df5db84f..4660f626 100644 --- a/pkg/adapter/metric/prometheus.go +++ b/adapter/metric/prometheus.go @@ -23,9 +23,9 @@ import ( "github.com/prometheus/client_golang/prometheus" - "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/core/logs" - "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego" + "github.com/astaxie/beego/core/logs" + "github.com/astaxie/beego/server/web" ) func PrometheusMiddleWare(next http.Handler) http.Handler { @@ -59,13 +59,13 @@ func registerBuildInfo() { Help: "The building information", ConstLabels: map[string]string{ "appname": web.BConfig.AppName, - "build_version": pkg.BuildVersion, - "build_revision": pkg.BuildGitRevision, - "build_status": pkg.BuildStatus, - "build_tag": pkg.BuildTag, - "build_time": strings.Replace(pkg.BuildTime, "--", " ", 1), - "go_version": pkg.GoVersion, - "git_branch": pkg.GitBranch, + "build_version": beego.BuildVersion, + "build_revision": beego.BuildGitRevision, + "build_status": beego.BuildStatus, + "build_tag": beego.BuildTag, + "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), + "go_version": beego.GoVersion, + "git_branch": beego.GitBranch, "start_time": time.Now().Format("2006-01-02 15:04:05"), }, }, []string{}) diff --git a/pkg/adapter/metric/prometheus_test.go b/adapter/metric/prometheus_test.go similarity index 96% rename from pkg/adapter/metric/prometheus_test.go rename to adapter/metric/prometheus_test.go index 87286e02..751348bf 100644 --- a/pkg/adapter/metric/prometheus_test.go +++ b/adapter/metric/prometheus_test.go @@ -22,7 +22,7 @@ import ( "github.com/prometheus/client_golang/prometheus" - "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/adapter/context" ) func TestPrometheusMiddleWare(t *testing.T) { diff --git a/pkg/adapter/migration/ddl.go b/adapter/migration/ddl.go similarity index 99% rename from pkg/adapter/migration/ddl.go rename to adapter/migration/ddl.go index 97e45dec..b43b4d34 100644 --- a/pkg/adapter/migration/ddl.go +++ b/adapter/migration/ddl.go @@ -15,7 +15,7 @@ package migration import ( - "github.com/astaxie/beego/pkg/client/orm/migration" + "github.com/astaxie/beego/client/orm/migration" ) // Index struct defines the structure of Index Columns diff --git a/pkg/adapter/migration/doc.go b/adapter/migration/doc.go similarity index 100% rename from pkg/adapter/migration/doc.go rename to adapter/migration/doc.go diff --git a/pkg/adapter/migration/migration.go b/adapter/migration/migration.go similarity index 98% rename from pkg/adapter/migration/migration.go rename to adapter/migration/migration.go index 4ee22e5a..677c35ca 100644 --- a/pkg/adapter/migration/migration.go +++ b/adapter/migration/migration.go @@ -28,7 +28,7 @@ package migration import ( - "github.com/astaxie/beego/pkg/client/orm/migration" + "github.com/astaxie/beego/client/orm/migration" ) // const the data format for the bee generate migration datatype diff --git a/pkg/adapter/namespace.go b/adapter/namespace.go similarity index 98% rename from pkg/adapter/namespace.go rename to adapter/namespace.go index 609402cf..98cbd8a5 100644 --- a/pkg/adapter/namespace.go +++ b/adapter/namespace.go @@ -17,10 +17,10 @@ package adapter import ( "net/http" - adtContext "github.com/astaxie/beego/pkg/adapter/context" - "github.com/astaxie/beego/pkg/server/web/context" + adtContext "github.com/astaxie/beego/adapter/context" + "github.com/astaxie/beego/server/web/context" - "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/server/web" ) type namespaceCond func(*adtContext.Context) bool diff --git a/pkg/adapter/orm/cmd.go b/adapter/orm/cmd.go similarity index 95% rename from pkg/adapter/orm/cmd.go rename to adapter/orm/cmd.go index 6fee237c..fcbd1be4 100644 --- a/pkg/adapter/orm/cmd.go +++ b/adapter/orm/cmd.go @@ -15,7 +15,7 @@ package orm import ( - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) // RunCommand listen for orm command and then run it if command arguments passed. diff --git a/pkg/adapter/orm/db.go b/adapter/orm/db.go similarity index 94% rename from pkg/adapter/orm/db.go rename to adapter/orm/db.go index 74bca8c0..fd878732 100644 --- a/pkg/adapter/orm/db.go +++ b/adapter/orm/db.go @@ -15,7 +15,7 @@ package orm import ( - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) var ( diff --git a/pkg/adapter/orm/db_alias.go b/adapter/orm/db_alias.go similarity index 98% rename from pkg/adapter/orm/db_alias.go rename to adapter/orm/db_alias.go index 523b6aee..81a07207 100644 --- a/pkg/adapter/orm/db_alias.go +++ b/adapter/orm/db_alias.go @@ -19,7 +19,7 @@ import ( "database/sql" "time" - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) // DriverType database driver constant int. diff --git a/pkg/adapter/orm/models.go b/adapter/orm/models.go similarity index 94% rename from pkg/adapter/orm/models.go rename to adapter/orm/models.go index 3215f5b5..5df64d6d 100644 --- a/pkg/adapter/orm/models.go +++ b/adapter/orm/models.go @@ -15,7 +15,7 @@ package orm import ( - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) // ResetModelCache Clean model cache. Then you can re-RegisterModel. diff --git a/pkg/adapter/orm/models_boot.go b/adapter/orm/models_boot.go similarity index 96% rename from pkg/adapter/orm/models_boot.go rename to adapter/orm/models_boot.go index 8888ef65..0b07de59 100644 --- a/pkg/adapter/orm/models_boot.go +++ b/adapter/orm/models_boot.go @@ -15,7 +15,7 @@ package orm import ( - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) // RegisterModel register models diff --git a/pkg/adapter/orm/models_fields.go b/adapter/orm/models_fields.go similarity index 99% rename from pkg/adapter/orm/models_fields.go rename to adapter/orm/models_fields.go index 666a97dc..6210567b 100644 --- a/pkg/adapter/orm/models_fields.go +++ b/adapter/orm/models_fields.go @@ -17,7 +17,7 @@ package orm import ( "time" - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) // Define the Type enum diff --git a/pkg/adapter/orm/orm.go b/adapter/orm/orm.go similarity index 98% rename from pkg/adapter/orm/orm.go rename to adapter/orm/orm.go index 61990256..15df76ed 100644 --- a/pkg/adapter/orm/orm.go +++ b/adapter/orm/orm.go @@ -58,9 +58,9 @@ import ( "database/sql" "errors" - "github.com/astaxie/beego/pkg/client/orm" - "github.com/astaxie/beego/pkg/client/orm/hints" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/client/orm" + "github.com/astaxie/beego/client/orm/hints" + "github.com/astaxie/beego/core/utils" ) // DebugQueries define the debug diff --git a/pkg/adapter/orm/orm_conds.go b/adapter/orm/orm_conds.go similarity index 98% rename from pkg/adapter/orm/orm_conds.go rename to adapter/orm/orm_conds.go index 986b4858..f70f0f5b 100644 --- a/pkg/adapter/orm/orm_conds.go +++ b/adapter/orm/orm_conds.go @@ -15,7 +15,7 @@ package orm import ( - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) // ExprSep define the expression separation diff --git a/pkg/adapter/orm/orm_log.go b/adapter/orm/orm_log.go similarity index 95% rename from pkg/adapter/orm/orm_log.go rename to adapter/orm/orm_log.go index 6b2b4a9b..3ff7f01c 100644 --- a/pkg/adapter/orm/orm_log.go +++ b/adapter/orm/orm_log.go @@ -17,7 +17,7 @@ package orm import ( "io" - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) // Log implement the log.Logger diff --git a/pkg/adapter/orm/orm_queryset.go b/adapter/orm/orm_queryset.go similarity index 95% rename from pkg/adapter/orm/orm_queryset.go rename to adapter/orm/orm_queryset.go index 5f211644..1926a6c0 100644 --- a/pkg/adapter/orm/orm_queryset.go +++ b/adapter/orm/orm_queryset.go @@ -15,7 +15,7 @@ package orm import ( - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) // define Col operations diff --git a/pkg/adapter/orm/qb.go b/adapter/orm/qb.go similarity index 95% rename from pkg/adapter/orm/qb.go rename to adapter/orm/qb.go index 90b97797..63eaed8a 100644 --- a/pkg/adapter/orm/qb.go +++ b/adapter/orm/qb.go @@ -15,7 +15,7 @@ package orm import ( - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) // QueryBuilder is the Query builder interface diff --git a/pkg/adapter/orm/qb_mysql.go b/adapter/orm/qb_mysql.go similarity index 99% rename from pkg/adapter/orm/qb_mysql.go rename to adapter/orm/qb_mysql.go index 9566068f..ef87ebab 100644 --- a/pkg/adapter/orm/qb_mysql.go +++ b/adapter/orm/qb_mysql.go @@ -15,7 +15,7 @@ package orm import ( - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) // CommaSpace is the separation diff --git a/pkg/adapter/orm/qb_tidb.go b/adapter/orm/qb_tidb.go similarity index 99% rename from pkg/adapter/orm/qb_tidb.go rename to adapter/orm/qb_tidb.go index 05c91a26..18631ef0 100644 --- a/pkg/adapter/orm/qb_tidb.go +++ b/adapter/orm/qb_tidb.go @@ -15,7 +15,7 @@ package orm import ( - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) // TiDBQueryBuilder is the SQL build diff --git a/pkg/adapter/orm/query_setter_adapter.go b/adapter/orm/query_setter_adapter.go similarity index 95% rename from pkg/adapter/orm/query_setter_adapter.go rename to adapter/orm/query_setter_adapter.go index cc24ef6b..d6c268b6 100644 --- a/pkg/adapter/orm/query_setter_adapter.go +++ b/adapter/orm/query_setter_adapter.go @@ -15,7 +15,7 @@ package orm import ( - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) type baseQuerySetter struct { diff --git a/pkg/adapter/orm/types.go b/adapter/orm/types.go similarity index 99% rename from pkg/adapter/orm/types.go rename to adapter/orm/types.go index 3372e301..6db5066c 100644 --- a/pkg/adapter/orm/types.go +++ b/adapter/orm/types.go @@ -18,7 +18,7 @@ import ( "context" "database/sql" - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) // Params stores the Params diff --git a/pkg/adapter/orm/utils.go b/adapter/orm/utils.go similarity index 99% rename from pkg/adapter/orm/utils.go rename to adapter/orm/utils.go index 16d0e4e5..37ba86d8 100644 --- a/pkg/adapter/orm/utils.go +++ b/adapter/orm/utils.go @@ -21,7 +21,7 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) type fn func(string) string diff --git a/pkg/adapter/orm/utils_test.go b/adapter/orm/utils_test.go similarity index 100% rename from pkg/adapter/orm/utils_test.go rename to adapter/orm/utils_test.go diff --git a/pkg/adapter/plugins/apiauth/apiauth.go b/adapter/plugins/apiauth/apiauth.go similarity index 92% rename from pkg/adapter/plugins/apiauth/apiauth.go rename to adapter/plugins/apiauth/apiauth.go index ed43f8a0..90311d8f 100644 --- a/pkg/adapter/plugins/apiauth/apiauth.go +++ b/adapter/plugins/apiauth/apiauth.go @@ -58,10 +58,10 @@ package apiauth import ( "net/url" - beego "github.com/astaxie/beego/pkg/adapter" - "github.com/astaxie/beego/pkg/adapter/context" - beecontext "github.com/astaxie/beego/pkg/server/web/context" - "github.com/astaxie/beego/pkg/server/web/filter/apiauth" + beego "github.com/astaxie/beego/adapter" + "github.com/astaxie/beego/adapter/context" + beecontext "github.com/astaxie/beego/server/web/context" + "github.com/astaxie/beego/server/web/filter/apiauth" ) // AppIDToAppSecret is used to get appsecret throw appid diff --git a/pkg/adapter/plugins/apiauth/apiauth_test.go b/adapter/plugins/apiauth/apiauth_test.go similarity index 100% rename from pkg/adapter/plugins/apiauth/apiauth_test.go rename to adapter/plugins/apiauth/apiauth_test.go diff --git a/pkg/adapter/plugins/auth/basic.go b/adapter/plugins/auth/basic.go similarity index 92% rename from pkg/adapter/plugins/auth/basic.go rename to adapter/plugins/auth/basic.go index 7a9cd326..578a16d9 100644 --- a/pkg/adapter/plugins/auth/basic.go +++ b/adapter/plugins/auth/basic.go @@ -38,10 +38,10 @@ package auth import ( "net/http" - beego "github.com/astaxie/beego/pkg/adapter" - "github.com/astaxie/beego/pkg/adapter/context" - beecontext "github.com/astaxie/beego/pkg/server/web/context" - "github.com/astaxie/beego/pkg/server/web/filter/auth" + beego "github.com/astaxie/beego/adapter" + "github.com/astaxie/beego/adapter/context" + beecontext "github.com/astaxie/beego/server/web/context" + "github.com/astaxie/beego/server/web/filter/auth" ) // Basic is the http basic auth diff --git a/pkg/adapter/plugins/authz/authz.go b/adapter/plugins/authz/authz.go similarity index 91% rename from pkg/adapter/plugins/authz/authz.go rename to adapter/plugins/authz/authz.go index c38be9cb..3f84467e 100644 --- a/pkg/adapter/plugins/authz/authz.go +++ b/adapter/plugins/authz/authz.go @@ -44,10 +44,10 @@ import ( "github.com/casbin/casbin" - beego "github.com/astaxie/beego/pkg/adapter" - "github.com/astaxie/beego/pkg/adapter/context" - beecontext "github.com/astaxie/beego/pkg/server/web/context" - "github.com/astaxie/beego/pkg/server/web/filter/authz" + beego "github.com/astaxie/beego/adapter" + "github.com/astaxie/beego/adapter/context" + beecontext "github.com/astaxie/beego/server/web/context" + "github.com/astaxie/beego/server/web/filter/authz" ) // NewAuthorizer returns the authorizer. diff --git a/pkg/adapter/plugins/authz/authz_model.conf b/adapter/plugins/authz/authz_model.conf similarity index 100% rename from pkg/adapter/plugins/authz/authz_model.conf rename to adapter/plugins/authz/authz_model.conf diff --git a/pkg/adapter/plugins/authz/authz_policy.csv b/adapter/plugins/authz/authz_policy.csv similarity index 100% rename from pkg/adapter/plugins/authz/authz_policy.csv rename to adapter/plugins/authz/authz_policy.csv diff --git a/pkg/adapter/plugins/authz/authz_test.go b/adapter/plugins/authz/authz_test.go similarity index 96% rename from pkg/adapter/plugins/authz/authz_test.go rename to adapter/plugins/authz/authz_test.go index ddbda5f4..9b4f21c2 100644 --- a/pkg/adapter/plugins/authz/authz_test.go +++ b/adapter/plugins/authz/authz_test.go @@ -19,9 +19,9 @@ import ( "net/http/httptest" "testing" - beego "github.com/astaxie/beego/pkg/adapter" - "github.com/astaxie/beego/pkg/adapter/context" - "github.com/astaxie/beego/pkg/adapter/plugins/auth" + beego "github.com/astaxie/beego/adapter" + "github.com/astaxie/beego/adapter/context" + "github.com/astaxie/beego/adapter/plugins/auth" "github.com/casbin/casbin" ) diff --git a/pkg/adapter/plugins/cors/cors.go b/adapter/plugins/cors/cors.go similarity index 91% rename from pkg/adapter/plugins/cors/cors.go rename to adapter/plugins/cors/cors.go index 65af8b8f..a15d5417 100644 --- a/pkg/adapter/plugins/cors/cors.go +++ b/adapter/plugins/cors/cors.go @@ -36,11 +36,11 @@ package cors import ( - beego "github.com/astaxie/beego/pkg/adapter" - beecontext "github.com/astaxie/beego/pkg/server/web/context" - "github.com/astaxie/beego/pkg/server/web/filter/cors" + beego "github.com/astaxie/beego/adapter" + beecontext "github.com/astaxie/beego/server/web/context" + "github.com/astaxie/beego/server/web/filter/cors" - "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/adapter/context" ) // Options represents Access Control options. diff --git a/pkg/adapter/policy.go b/adapter/policy.go similarity index 91% rename from pkg/adapter/policy.go rename to adapter/policy.go index f3759c76..6f334d2d 100644 --- a/pkg/adapter/policy.go +++ b/adapter/policy.go @@ -15,9 +15,9 @@ package adapter import ( - "github.com/astaxie/beego/pkg/adapter/context" - "github.com/astaxie/beego/pkg/server/web" - beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/adapter/context" + "github.com/astaxie/beego/server/web" + beecontext "github.com/astaxie/beego/server/web/context" ) // PolicyFunc defines a policy function which is invoked before the controller handler is executed. diff --git a/pkg/adapter/router.go b/adapter/router.go similarity index 98% rename from pkg/adapter/router.go rename to adapter/router.go index 8e8d9fdb..c91a09f1 100644 --- a/pkg/adapter/router.go +++ b/adapter/router.go @@ -18,10 +18,10 @@ import ( "net/http" "time" - beecontext "github.com/astaxie/beego/pkg/adapter/context" - "github.com/astaxie/beego/pkg/server/web/context" + beecontext "github.com/astaxie/beego/adapter/context" + "github.com/astaxie/beego/server/web/context" - "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/server/web" ) // default filter execution points diff --git a/pkg/adapter/session/couchbase/sess_couchbase.go b/adapter/session/couchbase/sess_couchbase.go similarity index 97% rename from pkg/adapter/session/couchbase/sess_couchbase.go rename to adapter/session/couchbase/sess_couchbase.go index 2903dae5..b6afb612 100644 --- a/pkg/adapter/session/couchbase/sess_couchbase.go +++ b/adapter/session/couchbase/sess_couchbase.go @@ -36,8 +36,8 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/adapter/session" - beecb "github.com/astaxie/beego/pkg/server/web/session/couchbase" + "github.com/astaxie/beego/adapter/session" + beecb "github.com/astaxie/beego/server/web/session/couchbase" ) // SessionStore store each session diff --git a/pkg/adapter/session/ledis/ledis_session.go b/adapter/session/ledis/ledis_session.go similarity index 95% rename from pkg/adapter/session/ledis/ledis_session.go rename to adapter/session/ledis/ledis_session.go index 92d3c96a..350cbdaa 100644 --- a/pkg/adapter/session/ledis/ledis_session.go +++ b/adapter/session/ledis/ledis_session.go @@ -5,8 +5,8 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/adapter/session" - beeLedis "github.com/astaxie/beego/pkg/server/web/session/ledis" + "github.com/astaxie/beego/adapter/session" + beeLedis "github.com/astaxie/beego/server/web/session/ledis" ) // SessionStore ledis session store diff --git a/pkg/adapter/session/memcache/sess_memcache.go b/adapter/session/memcache/sess_memcache.go similarity index 97% rename from pkg/adapter/session/memcache/sess_memcache.go rename to adapter/session/memcache/sess_memcache.go index ae985881..772839cd 100644 --- a/pkg/adapter/session/memcache/sess_memcache.go +++ b/adapter/session/memcache/sess_memcache.go @@ -36,9 +36,9 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/adapter/session" + "github.com/astaxie/beego/adapter/session" - beemem "github.com/astaxie/beego/pkg/server/web/session/memcache" + beemem "github.com/astaxie/beego/server/web/session/memcache" ) // SessionStore memcache session store diff --git a/pkg/adapter/session/mysql/sess_mysql.go b/adapter/session/mysql/sess_mysql.go similarity index 97% rename from pkg/adapter/session/mysql/sess_mysql.go rename to adapter/session/mysql/sess_mysql.go index 73113792..5d7e1dac 100644 --- a/pkg/adapter/session/mysql/sess_mysql.go +++ b/adapter/session/mysql/sess_mysql.go @@ -44,8 +44,8 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/adapter/session" - "github.com/astaxie/beego/pkg/server/web/session/mysql" + "github.com/astaxie/beego/adapter/session" + "github.com/astaxie/beego/server/web/session/mysql" // import mysql driver _ "github.com/go-sql-driver/mysql" diff --git a/pkg/adapter/session/postgres/sess_postgresql.go b/adapter/session/postgres/sess_postgresql.go similarity index 97% rename from pkg/adapter/session/postgres/sess_postgresql.go rename to adapter/session/postgres/sess_postgresql.go index 21a360e8..879b2b83 100644 --- a/pkg/adapter/session/postgres/sess_postgresql.go +++ b/adapter/session/postgres/sess_postgresql.go @@ -54,11 +54,11 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/adapter/session" + "github.com/astaxie/beego/adapter/session" // import postgresql Driver _ "github.com/lib/pq" - "github.com/astaxie/beego/pkg/server/web/session/postgres" + "github.com/astaxie/beego/server/web/session/postgres" ) // SessionStore postgresql session store diff --git a/pkg/adapter/session/provider_adapter.go b/adapter/session/provider_adapter.go similarity index 98% rename from pkg/adapter/session/provider_adapter.go rename to adapter/session/provider_adapter.go index 84cb4c85..596bc6a6 100644 --- a/pkg/adapter/session/provider_adapter.go +++ b/adapter/session/provider_adapter.go @@ -17,7 +17,7 @@ package session import ( "context" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" ) type oldToNewProviderAdapter struct { diff --git a/pkg/adapter/session/redis/sess_redis.go b/adapter/session/redis/sess_redis.go similarity index 97% rename from pkg/adapter/session/redis/sess_redis.go rename to adapter/session/redis/sess_redis.go index 47c9d437..bb8e8be4 100644 --- a/pkg/adapter/session/redis/sess_redis.go +++ b/adapter/session/redis/sess_redis.go @@ -36,9 +36,9 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/adapter/session" + "github.com/astaxie/beego/adapter/session" - beeRedis "github.com/astaxie/beego/pkg/server/web/session/redis" + beeRedis "github.com/astaxie/beego/server/web/session/redis" ) // MaxPoolSize redis max pool size diff --git a/pkg/adapter/session/redis_cluster/redis_cluster.go b/adapter/session/redis_cluster/redis_cluster.go similarity index 97% rename from pkg/adapter/session/redis_cluster/redis_cluster.go rename to adapter/session/redis_cluster/redis_cluster.go index b741b9ff..1be22cd4 100644 --- a/pkg/adapter/session/redis_cluster/redis_cluster.go +++ b/adapter/session/redis_cluster/redis_cluster.go @@ -36,8 +36,8 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/adapter/session" - cluster "github.com/astaxie/beego/pkg/server/web/session/redis_cluster" + "github.com/astaxie/beego/adapter/session" + cluster "github.com/astaxie/beego/server/web/session/redis_cluster" ) // MaxPoolSize redis_cluster max pool size diff --git a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go b/adapter/session/redis_sentinel/sess_redis_sentinel.go similarity index 97% rename from pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go rename to adapter/session/redis_sentinel/sess_redis_sentinel.go index 99ff7898..7ab9e7c5 100644 --- a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go +++ b/adapter/session/redis_sentinel/sess_redis_sentinel.go @@ -36,9 +36,9 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/adapter/session" + "github.com/astaxie/beego/adapter/session" - sentinel "github.com/astaxie/beego/pkg/server/web/session/redis_sentinel" + sentinel "github.com/astaxie/beego/server/web/session/redis_sentinel" ) // DefaultPoolSize redis_sentinel default pool size diff --git a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go b/adapter/session/redis_sentinel/sess_redis_sentinel_test.go similarity index 97% rename from pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go rename to adapter/session/redis_sentinel/sess_redis_sentinel_test.go index 7c33985f..407d32ab 100644 --- a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/adapter/session/redis_sentinel/sess_redis_sentinel_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/astaxie/beego/pkg/adapter/session" + "github.com/astaxie/beego/adapter/session" ) func TestRedisSentinel(t *testing.T) { diff --git a/pkg/adapter/session/sess_cookie.go b/adapter/session/sess_cookie.go similarity index 98% rename from pkg/adapter/session/sess_cookie.go rename to adapter/session/sess_cookie.go index 8c6c1dc7..3fcbd28e 100644 --- a/pkg/adapter/session/sess_cookie.go +++ b/adapter/session/sess_cookie.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" ) // CookieSessionStore Cookie SessionStore diff --git a/pkg/adapter/session/sess_cookie_test.go b/adapter/session/sess_cookie_test.go similarity index 100% rename from pkg/adapter/session/sess_cookie_test.go rename to adapter/session/sess_cookie_test.go diff --git a/pkg/adapter/session/sess_file.go b/adapter/session/sess_file.go similarity index 98% rename from pkg/adapter/session/sess_file.go rename to adapter/session/sess_file.go index 870b62a6..2ba33e6d 100644 --- a/pkg/adapter/session/sess_file.go +++ b/adapter/session/sess_file.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" ) // FileSessionStore File session store diff --git a/pkg/adapter/session/sess_file_test.go b/adapter/session/sess_file_test.go similarity index 100% rename from pkg/adapter/session/sess_file_test.go rename to adapter/session/sess_file_test.go diff --git a/pkg/adapter/session/sess_mem.go b/adapter/session/sess_mem.go similarity index 98% rename from pkg/adapter/session/sess_mem.go rename to adapter/session/sess_mem.go index faaab548..febed719 100644 --- a/pkg/adapter/session/sess_mem.go +++ b/adapter/session/sess_mem.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" ) // MemSessionStore memory session store. diff --git a/pkg/adapter/session/sess_mem_test.go b/adapter/session/sess_mem_test.go similarity index 100% rename from pkg/adapter/session/sess_mem_test.go rename to adapter/session/sess_mem_test.go diff --git a/pkg/adapter/session/sess_test.go b/adapter/session/sess_test.go similarity index 100% rename from pkg/adapter/session/sess_test.go rename to adapter/session/sess_test.go diff --git a/pkg/adapter/session/sess_utils.go b/adapter/session/sess_utils.go similarity index 94% rename from pkg/adapter/session/sess_utils.go rename to adapter/session/sess_utils.go index 8cf036e4..4cfdc760 100644 --- a/pkg/adapter/session/sess_utils.go +++ b/adapter/session/sess_utils.go @@ -15,7 +15,7 @@ package session import ( - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" ) // EncodeGob encode the obj to gob diff --git a/pkg/adapter/session/session.go b/adapter/session/session.go similarity index 99% rename from pkg/adapter/session/session.go rename to adapter/session/session.go index 24f587b6..d8b151b7 100644 --- a/pkg/adapter/session/session.go +++ b/adapter/session/session.go @@ -32,7 +32,7 @@ import ( "net/http" "os" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" ) // Store contains all data for one session process with specific id. diff --git a/pkg/adapter/session/ssdb/sess_ssdb.go b/adapter/session/ssdb/sess_ssdb.go similarity index 95% rename from pkg/adapter/session/ssdb/sess_ssdb.go rename to adapter/session/ssdb/sess_ssdb.go index 3f2d08d9..cd9c4a24 100644 --- a/pkg/adapter/session/ssdb/sess_ssdb.go +++ b/adapter/session/ssdb/sess_ssdb.go @@ -4,9 +4,9 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/adapter/session" + "github.com/astaxie/beego/adapter/session" - beeSsdb "github.com/astaxie/beego/pkg/server/web/session/ssdb" + beeSsdb "github.com/astaxie/beego/server/web/session/ssdb" ) // Provider holds ssdb client and configs diff --git a/pkg/adapter/session/store_adapter.go b/adapter/session/store_adapter.go similarity index 97% rename from pkg/adapter/session/store_adapter.go rename to adapter/session/store_adapter.go index c0de6ac3..70ad83e2 100644 --- a/pkg/adapter/session/store_adapter.go +++ b/adapter/session/store_adapter.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" ) type NewToOldStoreAdapter struct { diff --git a/pkg/adapter/swagger/swagger.go b/adapter/swagger/swagger.go similarity index 98% rename from pkg/adapter/swagger/swagger.go rename to adapter/swagger/swagger.go index 214959d9..7a44b770 100644 --- a/pkg/adapter/swagger/swagger.go +++ b/adapter/swagger/swagger.go @@ -21,7 +21,7 @@ package swagger import ( - "github.com/astaxie/beego/pkg/server/web/swagger" + "github.com/astaxie/beego/server/web/swagger" ) // Swagger list the resource diff --git a/pkg/adapter/template.go b/adapter/template.go similarity index 98% rename from pkg/adapter/template.go rename to adapter/template.go index 1f943caf..67f5a33b 100644 --- a/pkg/adapter/template.go +++ b/adapter/template.go @@ -19,7 +19,7 @@ import ( "io" "net/http" - "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/server/web" ) // ExecuteTemplate applies the template with name to the specified data object, diff --git a/pkg/adapter/templatefunc.go b/adapter/templatefunc.go similarity index 98% rename from pkg/adapter/templatefunc.go rename to adapter/templatefunc.go index 5130d590..0c805393 100644 --- a/pkg/adapter/templatefunc.go +++ b/adapter/templatefunc.go @@ -19,7 +19,7 @@ import ( "net/url" "time" - "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/server/web" ) const ( diff --git a/pkg/adapter/templatefunc_test.go b/adapter/templatefunc_test.go similarity index 100% rename from pkg/adapter/templatefunc_test.go rename to adapter/templatefunc_test.go diff --git a/pkg/adapter/testing/client.go b/adapter/testing/client.go similarity index 96% rename from pkg/adapter/testing/client.go rename to adapter/testing/client.go index 688aa6f3..5c138167 100644 --- a/pkg/adapter/testing/client.go +++ b/adapter/testing/client.go @@ -15,7 +15,7 @@ package testing import ( - "github.com/astaxie/beego/pkg/client/httplib/testing" + "github.com/astaxie/beego/client/httplib/testing" ) var port = "" diff --git a/pkg/adapter/toolbox/healthcheck.go b/adapter/toolbox/healthcheck.go similarity index 96% rename from pkg/adapter/toolbox/healthcheck.go rename to adapter/toolbox/healthcheck.go index 42b9e7d0..7d89c2fb 100644 --- a/pkg/adapter/toolbox/healthcheck.go +++ b/adapter/toolbox/healthcheck.go @@ -31,7 +31,7 @@ package toolbox import ( - "github.com/astaxie/beego/pkg/core/governor" + "github.com/astaxie/beego/core/governor" ) // AdminCheckList holds health checker map diff --git a/pkg/adapter/toolbox/profile.go b/adapter/toolbox/profile.go similarity index 96% rename from pkg/adapter/toolbox/profile.go rename to adapter/toolbox/profile.go index 97da05ac..a5434360 100644 --- a/pkg/adapter/toolbox/profile.go +++ b/adapter/toolbox/profile.go @@ -19,7 +19,7 @@ import ( "os" "time" - "github.com/astaxie/beego/pkg/core/governor" + "github.com/astaxie/beego/core/governor" ) var startTime = time.Now() diff --git a/pkg/adapter/toolbox/profile_test.go b/adapter/toolbox/profile_test.go similarity index 100% rename from pkg/adapter/toolbox/profile_test.go rename to adapter/toolbox/profile_test.go diff --git a/pkg/adapter/toolbox/statistics.go b/adapter/toolbox/statistics.go similarity index 97% rename from pkg/adapter/toolbox/statistics.go rename to adapter/toolbox/statistics.go index b7d3bda9..7c8cd75e 100644 --- a/pkg/adapter/toolbox/statistics.go +++ b/adapter/toolbox/statistics.go @@ -17,7 +17,7 @@ package toolbox import ( "time" - "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/server/web" ) // Statistics struct diff --git a/pkg/adapter/toolbox/statistics_test.go b/adapter/toolbox/statistics_test.go similarity index 100% rename from pkg/adapter/toolbox/statistics_test.go rename to adapter/toolbox/statistics_test.go diff --git a/pkg/adapter/toolbox/task.go b/adapter/toolbox/task.go similarity index 99% rename from pkg/adapter/toolbox/task.go rename to adapter/toolbox/task.go index 5b2fa14c..7f1bfc45 100644 --- a/pkg/adapter/toolbox/task.go +++ b/adapter/toolbox/task.go @@ -19,7 +19,7 @@ import ( "sort" "time" - "github.com/astaxie/beego/pkg/task" + "github.com/astaxie/beego/task" ) // The bounds for each field. diff --git a/pkg/adapter/toolbox/task_test.go b/adapter/toolbox/task_test.go similarity index 100% rename from pkg/adapter/toolbox/task_test.go rename to adapter/toolbox/task_test.go diff --git a/pkg/adapter/tree.go b/adapter/tree.go similarity index 90% rename from pkg/adapter/tree.go rename to adapter/tree.go index 2e3cd0d0..36f763ea 100644 --- a/pkg/adapter/tree.go +++ b/adapter/tree.go @@ -15,10 +15,10 @@ package adapter import ( - "github.com/astaxie/beego/pkg/adapter/context" - beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/adapter/context" + beecontext "github.com/astaxie/beego/server/web/context" - "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/server/web" ) // Tree has three elements: FixRouter/wildcard/leaves diff --git a/pkg/adapter/tree_test.go b/adapter/tree_test.go similarity index 99% rename from pkg/adapter/tree_test.go rename to adapter/tree_test.go index 309ed072..2315d829 100644 --- a/pkg/adapter/tree_test.go +++ b/adapter/tree_test.go @@ -17,8 +17,8 @@ package adapter import ( "testing" - "github.com/astaxie/beego/pkg/adapter/context" - beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/adapter/context" + beecontext "github.com/astaxie/beego/server/web/context" ) type testinfo struct { diff --git a/pkg/adapter/utils/caller.go b/adapter/utils/caller.go similarity index 94% rename from pkg/adapter/utils/caller.go rename to adapter/utils/caller.go index 124c68df..419f11d6 100644 --- a/pkg/adapter/utils/caller.go +++ b/adapter/utils/caller.go @@ -15,7 +15,7 @@ package utils import ( - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) // GetFuncName get function name diff --git a/pkg/adapter/utils/caller_test.go b/adapter/utils/caller_test.go similarity index 100% rename from pkg/adapter/utils/caller_test.go rename to adapter/utils/caller_test.go diff --git a/pkg/adapter/utils/captcha/LICENSE b/adapter/utils/captcha/LICENSE similarity index 100% rename from pkg/adapter/utils/captcha/LICENSE rename to adapter/utils/captcha/LICENSE diff --git a/pkg/adapter/utils/captcha/README.md b/adapter/utils/captcha/README.md similarity index 100% rename from pkg/adapter/utils/captcha/README.md rename to adapter/utils/captcha/README.md diff --git a/pkg/adapter/utils/captcha/captcha.go b/adapter/utils/captcha/captcha.go similarity index 93% rename from pkg/adapter/utils/captcha/captcha.go rename to adapter/utils/captcha/captcha.go index aad3994b..71aad0f2 100644 --- a/pkg/adapter/utils/captcha/captcha.go +++ b/adapter/utils/captcha/captcha.go @@ -63,11 +63,11 @@ import ( "net/http" "time" - "github.com/astaxie/beego/pkg/server/web/captcha" - beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/captcha" + beecontext "github.com/astaxie/beego/server/web/context" - "github.com/astaxie/beego/pkg/adapter/cache" - "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/adapter/cache" + "github.com/astaxie/beego/adapter/context" ) var ( diff --git a/pkg/adapter/utils/captcha/image.go b/adapter/utils/captcha/image.go similarity index 95% rename from pkg/adapter/utils/captcha/image.go rename to adapter/utils/captcha/image.go index 9979db84..6a1b696b 100644 --- a/pkg/adapter/utils/captcha/image.go +++ b/adapter/utils/captcha/image.go @@ -17,7 +17,7 @@ package captcha import ( "io" - "github.com/astaxie/beego/pkg/server/web/captcha" + "github.com/astaxie/beego/server/web/captcha" ) // Image struct diff --git a/pkg/adapter/utils/captcha/image_test.go b/adapter/utils/captcha/image_test.go similarity index 96% rename from pkg/adapter/utils/captcha/image_test.go rename to adapter/utils/captcha/image_test.go index bce2134a..5d298573 100644 --- a/pkg/adapter/utils/captcha/image_test.go +++ b/adapter/utils/captcha/image_test.go @@ -17,7 +17,7 @@ package captcha import ( "testing" - "github.com/astaxie/beego/pkg/adapter/utils" + "github.com/astaxie/beego/adapter/utils" ) const ( diff --git a/pkg/adapter/utils/debug.go b/adapter/utils/debug.go similarity index 95% rename from pkg/adapter/utils/debug.go rename to adapter/utils/debug.go index 6bb381a1..3f4d2759 100644 --- a/pkg/adapter/utils/debug.go +++ b/adapter/utils/debug.go @@ -15,7 +15,7 @@ package utils import ( - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) // Display print the data in console diff --git a/pkg/adapter/utils/debug_test.go b/adapter/utils/debug_test.go similarity index 100% rename from pkg/adapter/utils/debug_test.go rename to adapter/utils/debug_test.go diff --git a/pkg/adapter/utils/file.go b/adapter/utils/file.go similarity index 97% rename from pkg/adapter/utils/file.go rename to adapter/utils/file.go index 4ed2a8e3..aa9ac316 100644 --- a/pkg/adapter/utils/file.go +++ b/adapter/utils/file.go @@ -15,7 +15,7 @@ package utils import ( - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) // SelfPath gets compiled executable file absolute path diff --git a/pkg/adapter/utils/mail.go b/adapter/utils/mail.go similarity index 97% rename from pkg/adapter/utils/mail.go rename to adapter/utils/mail.go index 74ebe25c..74a8f403 100644 --- a/pkg/adapter/utils/mail.go +++ b/adapter/utils/mail.go @@ -17,7 +17,7 @@ package utils import ( "io" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) // Email is the type used for email messages diff --git a/pkg/adapter/utils/mail_test.go b/adapter/utils/mail_test.go similarity index 100% rename from pkg/adapter/utils/mail_test.go rename to adapter/utils/mail_test.go diff --git a/pkg/adapter/utils/pagination/controller.go b/adapter/utils/pagination/controller.go similarity index 84% rename from pkg/adapter/utils/pagination/controller.go rename to adapter/utils/pagination/controller.go index a908d8b0..c82c54f9 100644 --- a/pkg/adapter/utils/pagination/controller.go +++ b/adapter/utils/pagination/controller.go @@ -15,9 +15,9 @@ package pagination import ( - "github.com/astaxie/beego/pkg/adapter/context" - beecontext "github.com/astaxie/beego/pkg/server/web/context" - "github.com/astaxie/beego/pkg/server/web/pagination" + "github.com/astaxie/beego/adapter/context" + beecontext "github.com/astaxie/beego/server/web/context" + "github.com/astaxie/beego/server/web/pagination" ) // SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). diff --git a/pkg/adapter/utils/pagination/doc.go b/adapter/utils/pagination/doc.go similarity index 100% rename from pkg/adapter/utils/pagination/doc.go rename to adapter/utils/pagination/doc.go diff --git a/pkg/adapter/utils/pagination/paginator.go b/adapter/utils/pagination/paginator.go similarity index 98% rename from pkg/adapter/utils/pagination/paginator.go rename to adapter/utils/pagination/paginator.go index 1fefb9e0..73d9157f 100644 --- a/pkg/adapter/utils/pagination/paginator.go +++ b/adapter/utils/pagination/paginator.go @@ -17,7 +17,7 @@ package pagination import ( "net/http" - "github.com/astaxie/beego/pkg/core/utils/pagination" + "github.com/astaxie/beego/core/utils/pagination" ) // Paginator within the state of a http request. diff --git a/pkg/adapter/utils/rand.go b/adapter/utils/rand.go similarity index 94% rename from pkg/adapter/utils/rand.go rename to adapter/utils/rand.go index c31b633e..0fcca580 100644 --- a/pkg/adapter/utils/rand.go +++ b/adapter/utils/rand.go @@ -15,7 +15,7 @@ package utils import ( - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) // RandomCreateBytes generate random []byte by specify chars. diff --git a/pkg/adapter/utils/rand_test.go b/adapter/utils/rand_test.go similarity index 100% rename from pkg/adapter/utils/rand_test.go rename to adapter/utils/rand_test.go diff --git a/pkg/adapter/utils/safemap.go b/adapter/utils/safemap.go similarity index 97% rename from pkg/adapter/utils/safemap.go rename to adapter/utils/safemap.go index 6771aca4..bb50f3cd 100644 --- a/pkg/adapter/utils/safemap.go +++ b/adapter/utils/safemap.go @@ -15,7 +15,7 @@ package utils import ( - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) // BeeMap is a map with lock diff --git a/pkg/adapter/utils/safemap_test.go b/adapter/utils/safemap_test.go similarity index 100% rename from pkg/adapter/utils/safemap_test.go rename to adapter/utils/safemap_test.go diff --git a/pkg/adapter/utils/slice.go b/adapter/utils/slice.go similarity index 98% rename from pkg/adapter/utils/slice.go rename to adapter/utils/slice.go index a5b852b9..44b782b4 100644 --- a/pkg/adapter/utils/slice.go +++ b/adapter/utils/slice.go @@ -15,7 +15,7 @@ package utils import ( - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) type reducetype func(interface{}) interface{} diff --git a/pkg/adapter/utils/slice_test.go b/adapter/utils/slice_test.go similarity index 100% rename from pkg/adapter/utils/slice_test.go rename to adapter/utils/slice_test.go diff --git a/pkg/adapter/utils/utils.go b/adapter/utils/utils.go similarity index 76% rename from pkg/adapter/utils/utils.go rename to adapter/utils/utils.go index 21ed49dc..8ba21bc4 100644 --- a/pkg/adapter/utils/utils.go +++ b/adapter/utils/utils.go @@ -1,7 +1,7 @@ package utils import ( - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) // GetGOPATHs returns all paths in GOPATH variable. diff --git a/pkg/adapter/validation/util.go b/adapter/validation/util.go similarity index 97% rename from pkg/adapter/validation/util.go rename to adapter/validation/util.go index fb5370bf..431ce80d 100644 --- a/pkg/adapter/validation/util.go +++ b/adapter/validation/util.go @@ -17,7 +17,7 @@ package validation import ( "reflect" - "github.com/astaxie/beego/pkg/core/validation" + "github.com/astaxie/beego/core/validation" ) const ( diff --git a/pkg/adapter/validation/validation.go b/adapter/validation/validation.go similarity index 99% rename from pkg/adapter/validation/validation.go rename to adapter/validation/validation.go index e95fd408..e90c9f5b 100644 --- a/pkg/adapter/validation/validation.go +++ b/adapter/validation/validation.go @@ -50,7 +50,7 @@ import ( "fmt" "regexp" - "github.com/astaxie/beego/pkg/core/validation" + "github.com/astaxie/beego/core/validation" ) // ValidFormer valid interface diff --git a/pkg/adapter/validation/validation_test.go b/adapter/validation/validation_test.go similarity index 100% rename from pkg/adapter/validation/validation_test.go rename to adapter/validation/validation_test.go diff --git a/pkg/adapter/validation/validators.go b/adapter/validation/validators.go similarity index 99% rename from pkg/adapter/validation/validators.go rename to adapter/validation/validators.go index 152e8aef..5cd5d286 100644 --- a/pkg/adapter/validation/validators.go +++ b/adapter/validation/validators.go @@ -17,7 +17,7 @@ package validation import ( "sync" - "github.com/astaxie/beego/pkg/core/validation" + "github.com/astaxie/beego/core/validation" ) // CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty diff --git a/build/gobuild-sample.sh b/build/gobuild-sample.sh deleted file mode 100755 index 031eafc2..00000000 --- a/build/gobuild-sample.sh +++ /dev/null @@ -1,112 +0,0 @@ -#!/bin/bash - -# WARNING: DO NOT EDIT, THIS FILE IS PROBABLY A COPY -# -# The original version of this file is located in the https://github.com/istio/common-files repo. -# If you're looking at this file in a different repo and want to make a change, please go to the -# common-files repo, make the change there and check it in. Then come back to this repo and run -# "make update-common". - -# Copyright Istio Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script builds and version stamps the output - -# adatp to beego - -VERBOSE=${VERBOSE:-"0"} -V="" -if [[ "${VERBOSE}" == "1" ]];then - V="-x" - set -x -fi - -SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" - -OUT=${1:?"output path"} -shift - -set -e - -BUILD_GOOS=${GOOS:-linux} -BUILD_GOARCH=${GOARCH:-amd64} -GOBINARY=${GOBINARY:-go} -GOPKG="$GOPATH/pkg" -BUILDINFO=${BUILDINFO:-""} -STATIC=${STATIC:-1} -LDFLAGS=${LDFLAGS:--extldflags -static} -GOBUILDFLAGS=${GOBUILDFLAGS:-""} -# Split GOBUILDFLAGS by spaces into an array called GOBUILDFLAGS_ARRAY. -IFS=' ' read -r -a GOBUILDFLAGS_ARRAY <<< "$GOBUILDFLAGS" - -GCFLAGS=${GCFLAGS:-} -export CGO_ENABLED=0 - -if [[ "${STATIC}" != "1" ]];then - LDFLAGS="" -fi - -# gather buildinfo if not already provided -# For a release build BUILDINFO should be produced -# at the beginning of the build and used throughout -if [[ -z ${BUILDINFO} ]];then - BUILDINFO=$(mktemp) - "${SCRIPTPATH}/report_build_info.sh" > "${BUILDINFO}" -fi - - -# BUILD LD_EXTRAFLAGS -LD_EXTRAFLAGS="" - -while read -r line; do - LD_EXTRAFLAGS="${LD_EXTRAFLAGS} -X ${line}" -done < "${BUILDINFO}" - -# verify go version before build -# NB. this was copied verbatim from Kubernetes hack -minimum_go_version=go1.13 # supported patterns: go1.x, go1.x.x (x should be a number) -IFS=" " read -ra go_version <<< "$(${GOBINARY} version)" -if [[ "${minimum_go_version}" != $(echo -e "${minimum_go_version}\n${go_version[2]}" | sort -s -t. -k 1,1 -k 2,2n -k 3,3n | head -n1) && "${go_version[2]}" != "devel" ]]; then - echo "Warning: Detected that you are using an older version of the Go compiler. Beego requires ${minimum_go_version} or greater." -fi - -CURRENT_BRANCH=$(git branch | grep '*') -CURRENT_BRANCH=${CURRENT_BRANCH:2} - -BUILD_TIME=$(date +%Y-%m-%d--%T) - -LD_EXTRAFLAGS="${LD_EXTRAFLAGS} -X github.com/astaxie/beego.GoVersion=${go_version[2]:2}" -LD_EXTRAFLAGS="${LD_EXTRAFLAGS} -X github.com/astaxie/beego.GitBranch=${CURRENT_BRANCH}" -LD_EXTRAFLAGS="${LD_EXTRAFLAGS} -X github.com/astaxie/beego.BuildTime=$BUILD_TIME" - -OPTIMIZATION_FLAGS="-trimpath" -if [ "${DEBUG}" == "1" ]; then - OPTIMIZATION_FLAGS="" -fi - - - -echo "BUILD_GOARCH: $BUILD_GOARCH" -echo "GOPKG: $GOPKG" -echo "LD_EXTRAFLAGS: $LD_EXTRAFLAGS" -echo "GO_VERSION: ${go_version[2]}" -echo "BRANCH: $CURRENT_BRANCH" -echo "BUILD_TIME: $BUILD_TIME" - -time GOOS=${BUILD_GOOS} GOARCH=${BUILD_GOARCH} ${GOBINARY} build \ - ${V} "${GOBUILDFLAGS_ARRAY[@]}" ${GCFLAGS:+-gcflags "${GCFLAGS}"} \ - -o "${OUT}" \ - ${OPTIMIZATION_FLAGS} \ - -pkgdir="${GOPKG}/${BUILD_GOOS}_${BUILD_GOARCH}" \ - -ldflags "${LDFLAGS} ${LD_EXTRAFLAGS}" "${@}" \ No newline at end of file diff --git a/build/report_build_info.sh b/build/report_build_info.sh deleted file mode 100755 index 65ba3748..00000000 --- a/build/report_build_info.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/bin/bash - -# WARNING: DO NOT EDIT, THIS FILE IS PROBABLY A COPY -# -# The original version of this file is located in the https://github.com/istio/common-files repo. -# If you're looking at this file in a different repo and want to make a change, please go to the -# common-files repo, make the change there and check it in. Then come back to this repo and run -# "make update-common". - -# Copyright Istio Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# adapt to beego - -if BUILD_GIT_REVISION=$(git rev-parse HEAD 2> /dev/null); then - if [[ -n "$(git status --porcelain 2>/dev/null)" ]]; then - BUILD_GIT_REVISION=${BUILD_GIT_REVISION}"-dirty" - fi -else - BUILD_GIT_REVISION=unknown -fi - -# Check for local changes -if git diff-index --quiet HEAD --; then - tree_status="Clean" -else - tree_status="Modified" -fi - -# security wanted VERSION='unknown' -VERSION="${BUILD_GIT_REVISION}" -if [[ -n ${BEEGO_VERSION} ]]; then - VERSION="${BEEGO_VERSION}" -fi - -GIT_DESCRIBE_TAG=$(git describe --tags) - -echo "github.com/astaxie/beego.BuildVersion=${VERSION}" -echo "github.com/astaxie/beego.BuildGitRevision=${BUILD_GIT_REVISION}" -echo "github.com/astaxie/beego.BuildStatus=${tree_status}" -echo "github.com/astaxie/beego.BuildTag=${GIT_DESCRIBE_TAG}" \ No newline at end of file diff --git a/pkg/build_info.go b/build_info.go similarity index 98% rename from pkg/build_info.go rename to build_info.go index 778856c6..33287090 100644 --- a/pkg/build_info.go +++ b/build_info.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pkg +package beego var ( BuildVersion string diff --git a/pkg/client/cache/README.md b/client/cache/README.md similarity index 100% rename from pkg/client/cache/README.md rename to client/cache/README.md diff --git a/pkg/client/cache/cache.go b/client/cache/cache.go similarity index 100% rename from pkg/client/cache/cache.go rename to client/cache/cache.go diff --git a/pkg/client/cache/cache_test.go b/client/cache/cache_test.go similarity index 100% rename from pkg/client/cache/cache_test.go rename to client/cache/cache_test.go diff --git a/pkg/client/cache/conv.go b/client/cache/conv.go similarity index 100% rename from pkg/client/cache/conv.go rename to client/cache/conv.go diff --git a/pkg/client/cache/conv_test.go b/client/cache/conv_test.go similarity index 100% rename from pkg/client/cache/conv_test.go rename to client/cache/conv_test.go diff --git a/pkg/client/cache/file.go b/client/cache/file.go similarity index 100% rename from pkg/client/cache/file.go rename to client/cache/file.go diff --git a/pkg/client/cache/memcache/memcache.go b/client/cache/memcache/memcache.go similarity index 98% rename from pkg/client/cache/memcache/memcache.go rename to client/cache/memcache/memcache.go index d3b7e767..f3774571 100644 --- a/pkg/client/cache/memcache/memcache.go +++ b/client/cache/memcache/memcache.go @@ -38,7 +38,7 @@ import ( "github.com/bradfitz/gomemcache/memcache" - "github.com/astaxie/beego/pkg/client/cache" + "github.com/astaxie/beego/client/cache" ) // Cache Memcache adapter. diff --git a/pkg/client/cache/memcache/memcache_test.go b/client/cache/memcache/memcache_test.go similarity index 98% rename from pkg/client/cache/memcache/memcache_test.go rename to client/cache/memcache/memcache_test.go index 64679671..bc8936a7 100644 --- a/pkg/client/cache/memcache/memcache_test.go +++ b/client/cache/memcache/memcache_test.go @@ -24,7 +24,7 @@ import ( _ "github.com/bradfitz/gomemcache/memcache" - "github.com/astaxie/beego/pkg/client/cache" + "github.com/astaxie/beego/client/cache" ) func TestMemcacheCache(t *testing.T) { diff --git a/pkg/client/cache/memory.go b/client/cache/memory.go similarity index 100% rename from pkg/client/cache/memory.go rename to client/cache/memory.go diff --git a/pkg/client/cache/redis/redis.go b/client/cache/redis/redis.go similarity index 99% rename from pkg/client/cache/redis/redis.go rename to client/cache/redis/redis.go index e2785297..34059835 100644 --- a/pkg/client/cache/redis/redis.go +++ b/client/cache/redis/redis.go @@ -40,7 +40,7 @@ import ( "github.com/gomodule/redigo/redis" - "github.com/astaxie/beego/pkg/client/cache" + "github.com/astaxie/beego/client/cache" ) var ( diff --git a/pkg/client/cache/redis/redis_test.go b/client/cache/redis/redis_test.go similarity index 98% rename from pkg/client/cache/redis/redis_test.go rename to client/cache/redis/redis_test.go index f7308365..f82b2c40 100644 --- a/pkg/client/cache/redis/redis_test.go +++ b/client/cache/redis/redis_test.go @@ -24,7 +24,7 @@ import ( "github.com/gomodule/redigo/redis" "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/client/cache" + "github.com/astaxie/beego/client/cache" ) func TestRedisCache(t *testing.T) { diff --git a/pkg/client/cache/ssdb/ssdb.go b/client/cache/ssdb/ssdb.go similarity index 99% rename from pkg/client/cache/ssdb/ssdb.go rename to client/cache/ssdb/ssdb.go index 2e4f2815..1acee861 100644 --- a/pkg/client/cache/ssdb/ssdb.go +++ b/client/cache/ssdb/ssdb.go @@ -10,7 +10,7 @@ import ( "github.com/ssdb/gossdb/ssdb" - "github.com/astaxie/beego/pkg/client/cache" + "github.com/astaxie/beego/client/cache" ) // Cache SSDB adapter diff --git a/pkg/client/cache/ssdb/ssdb_test.go b/client/cache/ssdb/ssdb_test.go similarity index 98% rename from pkg/client/cache/ssdb/ssdb_test.go rename to client/cache/ssdb/ssdb_test.go index f675d1ab..cebaa975 100644 --- a/pkg/client/cache/ssdb/ssdb_test.go +++ b/client/cache/ssdb/ssdb_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/astaxie/beego/pkg/client/cache" + "github.com/astaxie/beego/client/cache" ) func TestSsdbcacheCache(t *testing.T) { diff --git a/pkg/client/httplib/README.md b/client/httplib/README.md similarity index 100% rename from pkg/client/httplib/README.md rename to client/httplib/README.md diff --git a/pkg/client/httplib/filter.go b/client/httplib/filter.go similarity index 100% rename from pkg/client/httplib/filter.go rename to client/httplib/filter.go diff --git a/pkg/client/httplib/filter/opentracing/filter.go b/client/httplib/filter/opentracing/filter.go similarity index 98% rename from pkg/client/httplib/filter/opentracing/filter.go rename to client/httplib/filter/opentracing/filter.go index 93376843..765a82a9 100644 --- a/pkg/client/httplib/filter/opentracing/filter.go +++ b/client/httplib/filter/opentracing/filter.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/client/httplib" + "github.com/astaxie/beego/client/httplib" logKit "github.com/go-kit/kit/log" opentracingKit "github.com/go-kit/kit/tracing/opentracing" "github.com/opentracing/opentracing-go" diff --git a/pkg/client/httplib/filter/opentracing/filter_test.go b/client/httplib/filter/opentracing/filter_test.go similarity index 96% rename from pkg/client/httplib/filter/opentracing/filter_test.go rename to client/httplib/filter/opentracing/filter_test.go index 46937803..7281f93f 100644 --- a/pkg/client/httplib/filter/opentracing/filter_test.go +++ b/client/httplib/filter/opentracing/filter_test.go @@ -23,7 +23,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/client/httplib" + "github.com/astaxie/beego/client/httplib" ) func TestFilterChainBuilder_FilterChain(t *testing.T) { diff --git a/pkg/client/httplib/filter/prometheus/filter.go b/client/httplib/filter/prometheus/filter.go similarity index 97% rename from pkg/client/httplib/filter/prometheus/filter.go rename to client/httplib/filter/prometheus/filter.go index b4a418e0..ce88b70e 100644 --- a/pkg/client/httplib/filter/prometheus/filter.go +++ b/client/httplib/filter/prometheus/filter.go @@ -22,7 +22,7 @@ import ( "github.com/prometheus/client_golang/prometheus" - "github.com/astaxie/beego/pkg/client/httplib" + "github.com/astaxie/beego/client/httplib" ) type FilterChainBuilder struct { diff --git a/pkg/client/httplib/filter/prometheus/filter_test.go b/client/httplib/filter/prometheus/filter_test.go similarity index 96% rename from pkg/client/httplib/filter/prometheus/filter_test.go rename to client/httplib/filter/prometheus/filter_test.go index 2964a4a2..46edc3d2 100644 --- a/pkg/client/httplib/filter/prometheus/filter_test.go +++ b/client/httplib/filter/prometheus/filter_test.go @@ -22,7 +22,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/client/httplib" + "github.com/astaxie/beego/client/httplib" ) func TestFilterChainBuilder_FilterChain(t *testing.T) { diff --git a/pkg/client/httplib/httplib.go b/client/httplib/httplib.go similarity index 100% rename from pkg/client/httplib/httplib.go rename to client/httplib/httplib.go diff --git a/pkg/client/httplib/httplib_test.go b/client/httplib/httplib_test.go similarity index 100% rename from pkg/client/httplib/httplib_test.go rename to client/httplib/httplib_test.go diff --git a/pkg/client/httplib/testing/client.go b/client/httplib/testing/client.go similarity index 97% rename from pkg/client/httplib/testing/client.go rename to client/httplib/testing/client.go index 107b28cc..34e49f2d 100644 --- a/pkg/client/httplib/testing/client.go +++ b/client/httplib/testing/client.go @@ -15,7 +15,7 @@ package testing import ( - "github.com/astaxie/beego/pkg/client/httplib" + "github.com/astaxie/beego/client/httplib" ) var port = "" diff --git a/pkg/client/orm/README.md b/client/orm/README.md similarity index 100% rename from pkg/client/orm/README.md rename to client/orm/README.md diff --git a/pkg/client/orm/cmd.go b/client/orm/cmd.go similarity index 100% rename from pkg/client/orm/cmd.go rename to client/orm/cmd.go diff --git a/pkg/client/orm/cmd_utils.go b/client/orm/cmd_utils.go similarity index 100% rename from pkg/client/orm/cmd_utils.go rename to client/orm/cmd_utils.go diff --git a/pkg/client/orm/db.go b/client/orm/db.go similarity index 99% rename from pkg/client/orm/db.go rename to client/orm/db.go index 2bd1308f..b103d218 100644 --- a/pkg/client/orm/db.go +++ b/client/orm/db.go @@ -22,7 +22,7 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/client/orm/hints" + "github.com/astaxie/beego/client/orm/hints" ) const ( diff --git a/pkg/client/orm/db_alias.go b/client/orm/db_alias.go similarity index 100% rename from pkg/client/orm/db_alias.go rename to client/orm/db_alias.go diff --git a/pkg/client/orm/db_alias_test.go b/client/orm/db_alias_test.go similarity index 100% rename from pkg/client/orm/db_alias_test.go rename to client/orm/db_alias_test.go diff --git a/pkg/client/orm/db_mysql.go b/client/orm/db_mysql.go similarity index 100% rename from pkg/client/orm/db_mysql.go rename to client/orm/db_mysql.go diff --git a/pkg/client/orm/db_oracle.go b/client/orm/db_oracle.go similarity index 98% rename from pkg/client/orm/db_oracle.go rename to client/orm/db_oracle.go index 7c1bf1b3..cb0d5052 100644 --- a/pkg/client/orm/db_oracle.go +++ b/client/orm/db_oracle.go @@ -18,7 +18,7 @@ import ( "fmt" "strings" - "github.com/astaxie/beego/pkg/client/orm/hints" + "github.com/astaxie/beego/client/orm/hints" ) // oracle operators. diff --git a/pkg/client/orm/db_postgres.go b/client/orm/db_postgres.go similarity index 100% rename from pkg/client/orm/db_postgres.go rename to client/orm/db_postgres.go diff --git a/pkg/client/orm/db_sqlite.go b/client/orm/db_sqlite.go similarity index 99% rename from pkg/client/orm/db_sqlite.go rename to client/orm/db_sqlite.go index 6d7a5617..961f2535 100644 --- a/pkg/client/orm/db_sqlite.go +++ b/client/orm/db_sqlite.go @@ -21,7 +21,7 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/client/orm/hints" + "github.com/astaxie/beego/client/orm/hints" ) // sqlite operators. diff --git a/pkg/client/orm/db_tables.go b/client/orm/db_tables.go similarity index 100% rename from pkg/client/orm/db_tables.go rename to client/orm/db_tables.go diff --git a/pkg/client/orm/db_tidb.go b/client/orm/db_tidb.go similarity index 100% rename from pkg/client/orm/db_tidb.go rename to client/orm/db_tidb.go diff --git a/pkg/client/orm/db_utils.go b/client/orm/db_utils.go similarity index 100% rename from pkg/client/orm/db_utils.go rename to client/orm/db_utils.go diff --git a/pkg/client/orm/do_nothing_orm.go b/client/orm/do_nothing_orm.go similarity index 99% rename from pkg/client/orm/do_nothing_orm.go rename to client/orm/do_nothing_orm.go index 42775b54..fc5b2159 100644 --- a/pkg/client/orm/do_nothing_orm.go +++ b/client/orm/do_nothing_orm.go @@ -18,7 +18,7 @@ import ( "context" "database/sql" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) // DoNothingOrm won't do anything, usually you use this to custom your mock Ormer implementation diff --git a/pkg/client/orm/do_nothing_orm_test.go b/client/orm/do_nothing_orm_test.go similarity index 100% rename from pkg/client/orm/do_nothing_orm_test.go rename to client/orm/do_nothing_orm_test.go diff --git a/pkg/client/orm/filter.go b/client/orm/filter.go similarity index 100% rename from pkg/client/orm/filter.go rename to client/orm/filter.go diff --git a/pkg/client/orm/filter/bean/default_value_filter.go b/client/orm/filter/bean/default_value_filter.go similarity index 97% rename from pkg/client/orm/filter/bean/default_value_filter.go rename to client/orm/filter/bean/default_value_filter.go index 7da9408d..3dac5c74 100644 --- a/pkg/client/orm/filter/bean/default_value_filter.go +++ b/client/orm/filter/bean/default_value_filter.go @@ -19,10 +19,10 @@ import ( "reflect" "strings" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" - "github.com/astaxie/beego/pkg/client/orm" - "github.com/astaxie/beego/pkg/core/bean" + "github.com/astaxie/beego/client/orm" + "github.com/astaxie/beego/core/bean" ) // DefaultValueFilterChainBuilder only works for InsertXXX method, diff --git a/pkg/client/orm/filter/bean/default_value_filter_test.go b/client/orm/filter/bean/default_value_filter_test.go similarity index 97% rename from pkg/client/orm/filter/bean/default_value_filter_test.go rename to client/orm/filter/bean/default_value_filter_test.go index fde7abf8..2a6ed1f4 100644 --- a/pkg/client/orm/filter/bean/default_value_filter_test.go +++ b/client/orm/filter/bean/default_value_filter_test.go @@ -19,7 +19,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) func TestDefaultValueFilterChainBuilder_FilterChain(t *testing.T) { diff --git a/pkg/client/orm/filter/opentracing/filter.go b/client/orm/filter/opentracing/filter.go similarity index 98% rename from pkg/client/orm/filter/opentracing/filter.go rename to client/orm/filter/opentracing/filter.go index 1a2ee541..7f9658b4 100644 --- a/pkg/client/orm/filter/opentracing/filter.go +++ b/client/orm/filter/opentracing/filter.go @@ -20,7 +20,7 @@ import ( "github.com/opentracing/opentracing-go" - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) // FilterChainBuilder provides an extension point diff --git a/pkg/client/orm/filter/opentracing/filter_test.go b/client/orm/filter/opentracing/filter_test.go similarity index 96% rename from pkg/client/orm/filter/opentracing/filter_test.go rename to client/orm/filter/opentracing/filter_test.go index 9e89e5da..428dacda 100644 --- a/pkg/client/orm/filter/opentracing/filter_test.go +++ b/client/orm/filter/opentracing/filter_test.go @@ -21,7 +21,7 @@ import ( "github.com/opentracing/opentracing-go" - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) func TestFilterChainBuilder_FilterChain(t *testing.T) { diff --git a/pkg/client/orm/filter/prometheus/filter.go b/client/orm/filter/prometheus/filter.go similarity index 98% rename from pkg/client/orm/filter/prometheus/filter.go rename to client/orm/filter/prometheus/filter.go index 2d819ef7..e74e946a 100644 --- a/pkg/client/orm/filter/prometheus/filter.go +++ b/client/orm/filter/prometheus/filter.go @@ -22,7 +22,7 @@ import ( "github.com/prometheus/client_golang/prometheus" - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) // FilterChainBuilder is an extension point, diff --git a/pkg/client/orm/filter/prometheus/filter_test.go b/client/orm/filter/prometheus/filter_test.go similarity index 97% rename from pkg/client/orm/filter/prometheus/filter_test.go rename to client/orm/filter/prometheus/filter_test.go index 0368d321..72b16038 100644 --- a/pkg/client/orm/filter/prometheus/filter_test.go +++ b/client/orm/filter/prometheus/filter_test.go @@ -21,7 +21,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/client/orm" ) func TestFilterChainBuilder_FilterChain1(t *testing.T) { diff --git a/pkg/client/orm/filter_orm_decorator.go b/client/orm/filter_orm_decorator.go similarity index 99% rename from pkg/client/orm/filter_orm_decorator.go rename to client/orm/filter_orm_decorator.go index 729c1698..9f837cba 100644 --- a/pkg/client/orm/filter_orm_decorator.go +++ b/client/orm/filter_orm_decorator.go @@ -20,7 +20,7 @@ import ( "reflect" "time" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) const ( diff --git a/pkg/client/orm/filter_orm_decorator_test.go b/client/orm/filter_orm_decorator_test.go similarity index 99% rename from pkg/client/orm/filter_orm_decorator_test.go rename to client/orm/filter_orm_decorator_test.go index 40a3fa2e..671ca001 100644 --- a/pkg/client/orm/filter_orm_decorator_test.go +++ b/client/orm/filter_orm_decorator_test.go @@ -21,7 +21,7 @@ import ( "sync" "testing" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" "github.com/stretchr/testify/assert" ) diff --git a/pkg/client/orm/filter_test.go b/client/orm/filter_test.go similarity index 100% rename from pkg/client/orm/filter_test.go rename to client/orm/filter_test.go diff --git a/pkg/client/orm/hints/db_hints.go b/client/orm/hints/db_hints.go similarity index 98% rename from pkg/client/orm/hints/db_hints.go rename to client/orm/hints/db_hints.go index c6180529..7bfe8eb0 100644 --- a/pkg/client/orm/hints/db_hints.go +++ b/client/orm/hints/db_hints.go @@ -15,7 +15,7 @@ package hints import ( - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) const ( diff --git a/pkg/client/orm/hints/db_hints_test.go b/client/orm/hints/db_hints_test.go similarity index 100% rename from pkg/client/orm/hints/db_hints_test.go rename to client/orm/hints/db_hints_test.go diff --git a/pkg/client/orm/invocation.go b/client/orm/invocation.go similarity index 100% rename from pkg/client/orm/invocation.go rename to client/orm/invocation.go diff --git a/pkg/client/orm/migration/ddl.go b/client/orm/migration/ddl.go similarity index 99% rename from pkg/client/orm/migration/ddl.go rename to client/orm/migration/ddl.go index e351b4cc..a396d39a 100644 --- a/pkg/client/orm/migration/ddl.go +++ b/client/orm/migration/ddl.go @@ -17,7 +17,7 @@ package migration import ( "fmt" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" ) // Index struct defines the structure of Index Columns diff --git a/pkg/client/orm/migration/doc.go b/client/orm/migration/doc.go similarity index 100% rename from pkg/client/orm/migration/doc.go rename to client/orm/migration/doc.go diff --git a/pkg/client/orm/migration/migration.go b/client/orm/migration/migration.go similarity index 98% rename from pkg/client/orm/migration/migration.go rename to client/orm/migration/migration.go index 4a56be58..aeea12c6 100644 --- a/pkg/client/orm/migration/migration.go +++ b/client/orm/migration/migration.go @@ -33,8 +33,8 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/client/orm" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/client/orm" + "github.com/astaxie/beego/core/logs" ) // const the data format for the bee generate migration datatype diff --git a/pkg/client/orm/model_utils_test.go b/client/orm/model_utils_test.go similarity index 100% rename from pkg/client/orm/model_utils_test.go rename to client/orm/model_utils_test.go diff --git a/pkg/client/orm/models.go b/client/orm/models.go similarity index 100% rename from pkg/client/orm/models.go rename to client/orm/models.go diff --git a/pkg/client/orm/models_boot.go b/client/orm/models_boot.go similarity index 100% rename from pkg/client/orm/models_boot.go rename to client/orm/models_boot.go diff --git a/pkg/client/orm/models_fields.go b/client/orm/models_fields.go similarity index 100% rename from pkg/client/orm/models_fields.go rename to client/orm/models_fields.go diff --git a/pkg/client/orm/models_info_f.go b/client/orm/models_info_f.go similarity index 100% rename from pkg/client/orm/models_info_f.go rename to client/orm/models_info_f.go diff --git a/pkg/client/orm/models_info_m.go b/client/orm/models_info_m.go similarity index 100% rename from pkg/client/orm/models_info_m.go rename to client/orm/models_info_m.go diff --git a/pkg/client/orm/models_test.go b/client/orm/models_test.go similarity index 97% rename from pkg/client/orm/models_test.go rename to client/orm/models_test.go index f0044f6d..d5aa2fa0 100644 --- a/pkg/client/orm/models_test.go +++ b/client/orm/models_test.go @@ -318,7 +318,7 @@ type Post struct { Created time.Time `orm:"auto_now_add"` Updated time.Time `orm:"auto_now"` UpdatedPrecision time.Time `orm:"auto_now;type(datetime);precision(4)"` - Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/pkg/client/orm.PostTags)"` + Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/client/orm.PostTags)"` } func (u *Post) TableIndex() [][]string { @@ -376,7 +376,7 @@ type Group struct { type Permission struct { ID int `orm:"column(id)"` Name string - Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/pkg/client/orm.GroupPermissions)"` + Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/client/orm.GroupPermissions)"` } type GroupPermissions struct { @@ -485,7 +485,7 @@ var ( usage: - go get -u github.com/astaxie/beego/pkg/client/orm + go get -u github.com/astaxie/beego/client/orm go get -u github.com/go-sql-driver/mysql go get -u github.com/mattn/go-sqlite3 go get -u github.com/lib/pq @@ -495,20 +495,20 @@ var ( mysql -u root -e 'create database orm_test;' export ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" - go test -v github.com/astaxie/beego/pkg/client/orm + go test -v github.com/astaxie/beego/client/orm #### Sqlite3 export ORM_DRIVER=sqlite3 export ORM_SOURCE='file:memory_test?mode=memory' - go test -v github.com/astaxie/beego/pkg/client/orm + go test -v github.com/astaxie/beego/client/orm #### PostgreSQL psql -c 'create database orm_test;' -U postgres export ORM_DRIVER=postgres export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - go test -v github.com/astaxie/beego/pkg/client/orm + go test -v github.com/astaxie/beego/client/orm #### TiDB export ORM_DRIVER=tidb diff --git a/pkg/client/orm/models_utils.go b/client/orm/models_utils.go similarity index 100% rename from pkg/client/orm/models_utils.go rename to client/orm/models_utils.go diff --git a/pkg/client/orm/models_utils_test.go b/client/orm/models_utils_test.go similarity index 100% rename from pkg/client/orm/models_utils_test.go rename to client/orm/models_utils_test.go diff --git a/pkg/client/orm/orm.go b/client/orm/orm.go similarity index 98% rename from pkg/client/orm/orm.go rename to client/orm/orm.go index 7d1aace0..a83faeb2 100644 --- a/pkg/client/orm/orm.go +++ b/client/orm/orm.go @@ -21,7 +21,7 @@ // // import ( // "fmt" -// "github.com/astaxie/beego/pkg/client/orm" +// "github.com/astaxie/beego/client/orm" // _ "github.com/go-sql-driver/mysql" // import your used driver // ) // @@ -62,10 +62,10 @@ import ( "reflect" "time" - "github.com/astaxie/beego/pkg/client/orm/hints" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/client/orm/hints" + "github.com/astaxie/beego/core/utils" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" ) // DebugQueries define the debug diff --git a/pkg/client/orm/orm_conds.go b/client/orm/orm_conds.go similarity index 100% rename from pkg/client/orm/orm_conds.go rename to client/orm/orm_conds.go diff --git a/pkg/client/orm/orm_log.go b/client/orm/orm_log.go similarity index 100% rename from pkg/client/orm/orm_log.go rename to client/orm/orm_log.go diff --git a/pkg/client/orm/orm_object.go b/client/orm/orm_object.go similarity index 100% rename from pkg/client/orm/orm_object.go rename to client/orm/orm_object.go diff --git a/pkg/client/orm/orm_querym2m.go b/client/orm/orm_querym2m.go similarity index 100% rename from pkg/client/orm/orm_querym2m.go rename to client/orm/orm_querym2m.go diff --git a/pkg/client/orm/orm_queryset.go b/client/orm/orm_queryset.go similarity index 99% rename from pkg/client/orm/orm_queryset.go rename to client/orm/orm_queryset.go index 906505de..ed223e24 100644 --- a/pkg/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -18,7 +18,7 @@ import ( "context" "fmt" - "github.com/astaxie/beego/pkg/client/orm/hints" + "github.com/astaxie/beego/client/orm/hints" ) type colValue struct { diff --git a/pkg/client/orm/orm_raw.go b/client/orm/orm_raw.go similarity index 100% rename from pkg/client/orm/orm_raw.go rename to client/orm/orm_raw.go diff --git a/pkg/client/orm/orm_test.go b/client/orm/orm_test.go similarity index 99% rename from pkg/client/orm/orm_test.go rename to client/orm/orm_test.go index bd92f46f..565f6c60 100644 --- a/pkg/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -31,7 +31,7 @@ import ( "testing" "time" - "github.com/astaxie/beego/pkg/client/orm/hints" + "github.com/astaxie/beego/client/orm/hints" "github.com/stretchr/testify/assert" ) diff --git a/pkg/client/orm/qb.go b/client/orm/qb.go similarity index 100% rename from pkg/client/orm/qb.go rename to client/orm/qb.go diff --git a/pkg/client/orm/qb_mysql.go b/client/orm/qb_mysql.go similarity index 100% rename from pkg/client/orm/qb_mysql.go rename to client/orm/qb_mysql.go diff --git a/pkg/client/orm/qb_postgres.go b/client/orm/qb_postgres.go similarity index 100% rename from pkg/client/orm/qb_postgres.go rename to client/orm/qb_postgres.go diff --git a/pkg/client/orm/qb_tidb.go b/client/orm/qb_tidb.go similarity index 100% rename from pkg/client/orm/qb_tidb.go rename to client/orm/qb_tidb.go diff --git a/pkg/client/orm/types.go b/client/orm/types.go similarity index 99% rename from pkg/client/orm/types.go rename to client/orm/types.go index e43bfd2c..34c61d51 100644 --- a/pkg/client/orm/types.go +++ b/client/orm/types.go @@ -20,7 +20,7 @@ import ( "reflect" "time" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) // TableNaming is usually used by model diff --git a/pkg/client/orm/utils.go b/client/orm/utils.go similarity index 100% rename from pkg/client/orm/utils.go rename to client/orm/utils.go diff --git a/pkg/client/orm/utils_test.go b/client/orm/utils_test.go similarity index 100% rename from pkg/client/orm/utils_test.go rename to client/orm/utils_test.go diff --git a/pkg/core/bean/context.go b/core/bean/context.go similarity index 100% rename from pkg/core/bean/context.go rename to core/bean/context.go diff --git a/pkg/core/bean/doc.go b/core/bean/doc.go similarity index 100% rename from pkg/core/bean/doc.go rename to core/bean/doc.go diff --git a/pkg/core/bean/factory.go b/core/bean/factory.go similarity index 100% rename from pkg/core/bean/factory.go rename to core/bean/factory.go diff --git a/pkg/core/bean/metadata.go b/core/bean/metadata.go similarity index 100% rename from pkg/core/bean/metadata.go rename to core/bean/metadata.go diff --git a/pkg/core/bean/tag_auto_wire_bean_factory.go b/core/bean/tag_auto_wire_bean_factory.go similarity index 99% rename from pkg/core/bean/tag_auto_wire_bean_factory.go rename to core/bean/tag_auto_wire_bean_factory.go index 595b3a02..b88a42ff 100644 --- a/pkg/core/bean/tag_auto_wire_bean_factory.go +++ b/core/bean/tag_auto_wire_bean_factory.go @@ -22,7 +22,7 @@ import ( "github.com/pkg/errors" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" ) const DefaultValueTagKey = "default" diff --git a/pkg/core/bean/tag_auto_wire_bean_factory_test.go b/core/bean/tag_auto_wire_bean_factory_test.go similarity index 100% rename from pkg/core/bean/tag_auto_wire_bean_factory_test.go rename to core/bean/tag_auto_wire_bean_factory_test.go diff --git a/pkg/core/bean/time_type_adapter.go b/core/bean/time_type_adapter.go similarity index 100% rename from pkg/core/bean/time_type_adapter.go rename to core/bean/time_type_adapter.go diff --git a/pkg/core/bean/time_type_adapter_test.go b/core/bean/time_type_adapter_test.go similarity index 100% rename from pkg/core/bean/time_type_adapter_test.go rename to core/bean/time_type_adapter_test.go diff --git a/pkg/core/bean/type_adapter.go b/core/bean/type_adapter.go similarity index 100% rename from pkg/core/bean/type_adapter.go rename to core/bean/type_adapter.go diff --git a/pkg/core/config/base_config_test.go b/core/config/base_config_test.go similarity index 100% rename from pkg/core/config/base_config_test.go rename to core/config/base_config_test.go diff --git a/pkg/core/config/config.go b/core/config/config.go similarity index 100% rename from pkg/core/config/config.go rename to core/config/config.go diff --git a/pkg/core/config/config_test.go b/core/config/config_test.go similarity index 100% rename from pkg/core/config/config_test.go rename to core/config/config_test.go diff --git a/pkg/core/config/env/env.go b/core/config/env/env.go similarity index 98% rename from pkg/core/config/env/env.go rename to core/config/env/env.go index 0cf1582b..d3903d74 100644 --- a/pkg/core/config/env/env.go +++ b/core/config/env/env.go @@ -21,7 +21,7 @@ import ( "os" "strings" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) var env *utils.BeeMap diff --git a/pkg/core/config/env/env_test.go b/core/config/env/env_test.go similarity index 100% rename from pkg/core/config/env/env_test.go rename to core/config/env/env_test.go diff --git a/pkg/core/config/etcd/config.go b/core/config/etcd/config.go similarity index 98% rename from pkg/core/config/etcd/config.go rename to core/config/etcd/config.go index 278cbaa9..37dba9de 100644 --- a/pkg/core/config/etcd/config.go +++ b/core/config/etcd/config.go @@ -26,8 +26,8 @@ import ( "github.com/pkg/errors" "google.golang.org/grpc" - "github.com/astaxie/beego/pkg/core/config" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/config" + "github.com/astaxie/beego/core/logs" ) const etcdOpts = "etcdOpts" diff --git a/pkg/core/config/etcd/config_test.go b/core/config/etcd/config_test.go similarity index 100% rename from pkg/core/config/etcd/config_test.go rename to core/config/etcd/config_test.go diff --git a/pkg/core/config/fake.go b/core/config/fake.go similarity index 100% rename from pkg/core/config/fake.go rename to core/config/fake.go diff --git a/pkg/core/config/ini.go b/core/config/ini.go similarity index 100% rename from pkg/core/config/ini.go rename to core/config/ini.go diff --git a/pkg/core/config/ini_test.go b/core/config/ini_test.go similarity index 100% rename from pkg/core/config/ini_test.go rename to core/config/ini_test.go diff --git a/pkg/core/config/json/json.go b/core/config/json/json.go similarity index 98% rename from pkg/core/config/json/json.go rename to core/config/json/json.go index 66546d89..f58e70f5 100644 --- a/pkg/core/config/json/json.go +++ b/core/config/json/json.go @@ -27,8 +27,8 @@ import ( "github.com/mitchellh/mapstructure" - "github.com/astaxie/beego/pkg/core/config" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/config" + "github.com/astaxie/beego/core/logs" ) // JSONConfig is a json config parser and implements Config interface. diff --git a/pkg/core/config/json/json_test.go b/core/config/json/json_test.go similarity index 99% rename from pkg/core/config/json/json_test.go rename to core/config/json/json_test.go index d4601e39..b615c19a 100644 --- a/pkg/core/config/json/json_test.go +++ b/core/config/json/json_test.go @@ -22,7 +22,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/core/config" + "github.com/astaxie/beego/core/config" ) func TestJsonStartsWithArray(t *testing.T) { diff --git a/pkg/core/config/xml/xml.go b/core/config/xml/xml.go similarity index 98% rename from pkg/core/config/xml/xml.go rename to core/config/xml/xml.go index 47d4b8c8..3b1a7051 100644 --- a/pkg/core/config/xml/xml.go +++ b/core/config/xml/xml.go @@ -42,8 +42,8 @@ import ( "github.com/mitchellh/mapstructure" - "github.com/astaxie/beego/pkg/core/config" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/config" + "github.com/astaxie/beego/core/logs" "github.com/beego/x2j" ) diff --git a/pkg/core/config/xml/xml_test.go b/core/config/xml/xml_test.go similarity index 98% rename from pkg/core/config/xml/xml_test.go rename to core/config/xml/xml_test.go index b110f813..0266e270 100644 --- a/pkg/core/config/xml/xml_test.go +++ b/core/config/xml/xml_test.go @@ -22,7 +22,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/core/config" + "github.com/astaxie/beego/core/config" ) func TestXML(t *testing.T) { diff --git a/pkg/core/config/yaml/yaml.go b/core/config/yaml/yaml.go similarity index 99% rename from pkg/core/config/yaml/yaml.go rename to core/config/yaml/yaml.go index 5a4bfad0..6d9abb4e 100644 --- a/pkg/core/config/yaml/yaml.go +++ b/core/config/yaml/yaml.go @@ -41,11 +41,10 @@ import ( "strings" "sync" + "github.com/astaxie/beego/core/config" + "github.com/astaxie/beego/core/logs" "github.com/beego/goyaml2" "gopkg.in/yaml.v2" - - "github.com/astaxie/beego/pkg/core/config" - "github.com/astaxie/beego/pkg/core/logs" ) // Config is a yaml config parser and implements Config interface. diff --git a/pkg/core/config/yaml/yaml_test.go b/core/config/yaml/yaml_test.go similarity index 98% rename from pkg/core/config/yaml/yaml_test.go rename to core/config/yaml/yaml_test.go index 130ce6a2..a7c3a92e 100644 --- a/pkg/core/config/yaml/yaml_test.go +++ b/core/config/yaml/yaml_test.go @@ -22,7 +22,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/core/config" + "github.com/astaxie/beego/core/config" ) func TestYaml(t *testing.T) { diff --git a/pkg/core/governor/command.go b/core/governor/command.go similarity index 100% rename from pkg/core/governor/command.go rename to core/governor/command.go diff --git a/pkg/core/governor/healthcheck.go b/core/governor/healthcheck.go similarity index 100% rename from pkg/core/governor/healthcheck.go rename to core/governor/healthcheck.go diff --git a/pkg/core/governor/profile.go b/core/governor/profile.go similarity index 99% rename from pkg/core/governor/profile.go rename to core/governor/profile.go index de6e1995..17f1f375 100644 --- a/pkg/core/governor/profile.go +++ b/core/governor/profile.go @@ -26,7 +26,7 @@ import ( "strconv" "time" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) var startTime = time.Now() diff --git a/pkg/core/governor/profile_test.go b/core/governor/profile_test.go similarity index 100% rename from pkg/core/governor/profile_test.go rename to core/governor/profile_test.go diff --git a/pkg/core/logs/README.md b/core/logs/README.md similarity index 100% rename from pkg/core/logs/README.md rename to core/logs/README.md diff --git a/pkg/core/logs/access_log.go b/core/logs/access_log.go similarity index 100% rename from pkg/core/logs/access_log.go rename to core/logs/access_log.go diff --git a/pkg/core/logs/access_log_test.go b/core/logs/access_log_test.go similarity index 100% rename from pkg/core/logs/access_log_test.go rename to core/logs/access_log_test.go diff --git a/pkg/core/logs/alils/alils.go b/core/logs/alils/alils.go similarity index 98% rename from pkg/core/logs/alils/alils.go rename to core/logs/alils/alils.go index 812b1b3b..484d31e4 100644 --- a/pkg/core/logs/alils/alils.go +++ b/core/logs/alils/alils.go @@ -9,7 +9,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/pkg/errors" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" ) const ( diff --git a/pkg/core/logs/alils/config.go b/core/logs/alils/config.go similarity index 100% rename from pkg/core/logs/alils/config.go rename to core/logs/alils/config.go diff --git a/pkg/core/logs/alils/log.pb.go b/core/logs/alils/log.pb.go similarity index 100% rename from pkg/core/logs/alils/log.pb.go rename to core/logs/alils/log.pb.go diff --git a/pkg/core/logs/alils/log_config.go b/core/logs/alils/log_config.go similarity index 100% rename from pkg/core/logs/alils/log_config.go rename to core/logs/alils/log_config.go diff --git a/pkg/core/logs/alils/log_project.go b/core/logs/alils/log_project.go similarity index 100% rename from pkg/core/logs/alils/log_project.go rename to core/logs/alils/log_project.go diff --git a/pkg/core/logs/alils/log_store.go b/core/logs/alils/log_store.go similarity index 100% rename from pkg/core/logs/alils/log_store.go rename to core/logs/alils/log_store.go diff --git a/pkg/core/logs/alils/machine_group.go b/core/logs/alils/machine_group.go similarity index 100% rename from pkg/core/logs/alils/machine_group.go rename to core/logs/alils/machine_group.go diff --git a/pkg/core/logs/alils/request.go b/core/logs/alils/request.go similarity index 100% rename from pkg/core/logs/alils/request.go rename to core/logs/alils/request.go diff --git a/pkg/core/logs/alils/signature.go b/core/logs/alils/signature.go similarity index 100% rename from pkg/core/logs/alils/signature.go rename to core/logs/alils/signature.go diff --git a/pkg/core/logs/conn.go b/core/logs/conn.go similarity index 100% rename from pkg/core/logs/conn.go rename to core/logs/conn.go diff --git a/pkg/core/logs/conn_test.go b/core/logs/conn_test.go similarity index 100% rename from pkg/core/logs/conn_test.go rename to core/logs/conn_test.go diff --git a/pkg/core/logs/console.go b/core/logs/console.go similarity index 100% rename from pkg/core/logs/console.go rename to core/logs/console.go diff --git a/pkg/core/logs/console_test.go b/core/logs/console_test.go similarity index 100% rename from pkg/core/logs/console_test.go rename to core/logs/console_test.go diff --git a/pkg/core/logs/es/es.go b/core/logs/es/es.go similarity index 98% rename from pkg/core/logs/es/es.go rename to core/logs/es/es.go index a150c7b3..6175f253 100644 --- a/pkg/core/logs/es/es.go +++ b/core/logs/es/es.go @@ -12,7 +12,7 @@ import ( "github.com/elastic/go-elasticsearch/v6" "github.com/elastic/go-elasticsearch/v6/esapi" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" ) // NewES returns a LoggerInterface diff --git a/pkg/core/logs/es/index.go b/core/logs/es/index.go similarity index 96% rename from pkg/core/logs/es/index.go rename to core/logs/es/index.go index 5b2d3d59..0dafef4c 100644 --- a/pkg/core/logs/es/index.go +++ b/core/logs/es/index.go @@ -17,7 +17,7 @@ package es import ( "fmt" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" ) // IndexNaming generate the index name diff --git a/pkg/core/logs/es/index_test.go b/core/logs/es/index_test.go similarity index 95% rename from pkg/core/logs/es/index_test.go rename to core/logs/es/index_test.go index 25cfa5ed..03e7a911 100644 --- a/pkg/core/logs/es/index_test.go +++ b/core/logs/es/index_test.go @@ -20,7 +20,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" ) func TestDefaultIndexNaming_IndexName(t *testing.T) { diff --git a/pkg/core/logs/file.go b/core/logs/file.go similarity index 100% rename from pkg/core/logs/file.go rename to core/logs/file.go diff --git a/pkg/core/logs/file_test.go b/core/logs/file_test.go similarity index 100% rename from pkg/core/logs/file_test.go rename to core/logs/file_test.go diff --git a/pkg/core/logs/formatter.go b/core/logs/formatter.go similarity index 100% rename from pkg/core/logs/formatter.go rename to core/logs/formatter.go diff --git a/pkg/core/logs/formatter_test.go b/core/logs/formatter_test.go similarity index 100% rename from pkg/core/logs/formatter_test.go rename to core/logs/formatter_test.go diff --git a/pkg/core/logs/jianliao.go b/core/logs/jianliao.go similarity index 100% rename from pkg/core/logs/jianliao.go rename to core/logs/jianliao.go diff --git a/pkg/core/logs/jianliao_test.go b/core/logs/jianliao_test.go similarity index 100% rename from pkg/core/logs/jianliao_test.go rename to core/logs/jianliao_test.go diff --git a/pkg/core/logs/log.go b/core/logs/log.go similarity index 100% rename from pkg/core/logs/log.go rename to core/logs/log.go diff --git a/pkg/core/logs/log_msg.go b/core/logs/log_msg.go similarity index 100% rename from pkg/core/logs/log_msg.go rename to core/logs/log_msg.go diff --git a/pkg/core/logs/log_msg_test.go b/core/logs/log_msg_test.go similarity index 100% rename from pkg/core/logs/log_msg_test.go rename to core/logs/log_msg_test.go diff --git a/pkg/core/logs/log_test.go b/core/logs/log_test.go similarity index 100% rename from pkg/core/logs/log_test.go rename to core/logs/log_test.go diff --git a/pkg/core/logs/logger.go b/core/logs/logger.go similarity index 100% rename from pkg/core/logs/logger.go rename to core/logs/logger.go diff --git a/pkg/core/logs/logger_test.go b/core/logs/logger_test.go similarity index 100% rename from pkg/core/logs/logger_test.go rename to core/logs/logger_test.go diff --git a/pkg/core/logs/multifile.go b/core/logs/multifile.go similarity index 100% rename from pkg/core/logs/multifile.go rename to core/logs/multifile.go diff --git a/pkg/core/logs/multifile_test.go b/core/logs/multifile_test.go similarity index 100% rename from pkg/core/logs/multifile_test.go rename to core/logs/multifile_test.go diff --git a/pkg/core/logs/slack.go b/core/logs/slack.go similarity index 100% rename from pkg/core/logs/slack.go rename to core/logs/slack.go diff --git a/pkg/core/logs/smtp.go b/core/logs/smtp.go similarity index 100% rename from pkg/core/logs/smtp.go rename to core/logs/smtp.go diff --git a/pkg/core/logs/smtp_test.go b/core/logs/smtp_test.go similarity index 100% rename from pkg/core/logs/smtp_test.go rename to core/logs/smtp_test.go diff --git a/pkg/core/utils/caller.go b/core/utils/caller.go similarity index 100% rename from pkg/core/utils/caller.go rename to core/utils/caller.go diff --git a/pkg/core/utils/caller_test.go b/core/utils/caller_test.go similarity index 100% rename from pkg/core/utils/caller_test.go rename to core/utils/caller_test.go diff --git a/pkg/core/utils/debug.go b/core/utils/debug.go similarity index 100% rename from pkg/core/utils/debug.go rename to core/utils/debug.go diff --git a/pkg/core/utils/debug_test.go b/core/utils/debug_test.go similarity index 100% rename from pkg/core/utils/debug_test.go rename to core/utils/debug_test.go diff --git a/pkg/core/utils/file.go b/core/utils/file.go similarity index 100% rename from pkg/core/utils/file.go rename to core/utils/file.go diff --git a/pkg/core/utils/file_test.go b/core/utils/file_test.go similarity index 100% rename from pkg/core/utils/file_test.go rename to core/utils/file_test.go diff --git a/pkg/core/utils/kv.go b/core/utils/kv.go similarity index 100% rename from pkg/core/utils/kv.go rename to core/utils/kv.go diff --git a/pkg/core/utils/kv_test.go b/core/utils/kv_test.go similarity index 100% rename from pkg/core/utils/kv_test.go rename to core/utils/kv_test.go diff --git a/pkg/core/utils/mail.go b/core/utils/mail.go similarity index 100% rename from pkg/core/utils/mail.go rename to core/utils/mail.go diff --git a/pkg/core/utils/mail_test.go b/core/utils/mail_test.go similarity index 100% rename from pkg/core/utils/mail_test.go rename to core/utils/mail_test.go diff --git a/pkg/core/utils/pagination/doc.go b/core/utils/pagination/doc.go similarity index 96% rename from pkg/core/utils/pagination/doc.go rename to core/utils/pagination/doc.go index fb044ff9..b9c604b9 100644 --- a/pkg/core/utils/pagination/doc.go +++ b/core/utils/pagination/doc.go @@ -8,7 +8,7 @@ In your beego.Controller: package controllers - import "github.com/astaxie/beego/pkg/core/utils/pagination" + import "github.com/astaxie/beego/core/utils/pagination" type PostsController struct { beego.Controller diff --git a/pkg/core/utils/pagination/paginator.go b/core/utils/pagination/paginator.go similarity index 100% rename from pkg/core/utils/pagination/paginator.go rename to core/utils/pagination/paginator.go diff --git a/pkg/core/utils/pagination/utils.go b/core/utils/pagination/utils.go similarity index 100% rename from pkg/core/utils/pagination/utils.go rename to core/utils/pagination/utils.go diff --git a/pkg/core/utils/rand.go b/core/utils/rand.go similarity index 100% rename from pkg/core/utils/rand.go rename to core/utils/rand.go diff --git a/pkg/core/utils/rand_test.go b/core/utils/rand_test.go similarity index 100% rename from pkg/core/utils/rand_test.go rename to core/utils/rand_test.go diff --git a/pkg/core/utils/safemap.go b/core/utils/safemap.go similarity index 100% rename from pkg/core/utils/safemap.go rename to core/utils/safemap.go diff --git a/pkg/core/utils/safemap_test.go b/core/utils/safemap_test.go similarity index 100% rename from pkg/core/utils/safemap_test.go rename to core/utils/safemap_test.go diff --git a/pkg/core/utils/slice.go b/core/utils/slice.go similarity index 100% rename from pkg/core/utils/slice.go rename to core/utils/slice.go diff --git a/pkg/core/utils/slice_test.go b/core/utils/slice_test.go similarity index 100% rename from pkg/core/utils/slice_test.go rename to core/utils/slice_test.go diff --git a/pkg/core/utils/testdata/grepe.test b/core/utils/testdata/grepe.test similarity index 100% rename from pkg/core/utils/testdata/grepe.test rename to core/utils/testdata/grepe.test diff --git a/pkg/core/utils/time.go b/core/utils/time.go similarity index 100% rename from pkg/core/utils/time.go rename to core/utils/time.go diff --git a/pkg/core/utils/utils.go b/core/utils/utils.go similarity index 100% rename from pkg/core/utils/utils.go rename to core/utils/utils.go diff --git a/pkg/core/utils/utils_test.go b/core/utils/utils_test.go similarity index 100% rename from pkg/core/utils/utils_test.go rename to core/utils/utils_test.go diff --git a/pkg/core/validation/README.md b/core/validation/README.md similarity index 100% rename from pkg/core/validation/README.md rename to core/validation/README.md diff --git a/pkg/core/validation/util.go b/core/validation/util.go similarity index 100% rename from pkg/core/validation/util.go rename to core/validation/util.go diff --git a/pkg/core/validation/util_test.go b/core/validation/util_test.go similarity index 100% rename from pkg/core/validation/util_test.go rename to core/validation/util_test.go diff --git a/pkg/core/validation/validation.go b/core/validation/validation.go similarity index 100% rename from pkg/core/validation/validation.go rename to core/validation/validation.go diff --git a/pkg/core/validation/validation_test.go b/core/validation/validation_test.go similarity index 100% rename from pkg/core/validation/validation_test.go rename to core/validation/validation_test.go diff --git a/pkg/core/validation/validators.go b/core/validation/validators.go similarity index 99% rename from pkg/core/validation/validators.go rename to core/validation/validators.go index 1652ee2c..ec422d86 100644 --- a/pkg/core/validation/validators.go +++ b/core/validation/validators.go @@ -23,7 +23,7 @@ import ( "time" "unicode/utf8" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" ) // CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty diff --git a/pkg/doc.go b/doc.go similarity index 97% rename from pkg/doc.go rename to doc.go index 2d9c2bfe..6975885a 100644 --- a/pkg/doc.go +++ b/doc.go @@ -12,4 +12,4 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pkg +package beego diff --git a/githook/pre-commit b/githook/pre-commit deleted file mode 100755 index 95b1009b..00000000 --- a/githook/pre-commit +++ /dev/null @@ -1,7 +0,0 @@ - -goimports -w -format-only pkg -goimports -w -format-only examples - -ineffassign . - -staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" ./pkg \ No newline at end of file diff --git a/pkg/adapter/logs/alils/alils.go b/pkg/adapter/logs/alils/alils.go deleted file mode 100644 index 5abbc29f..00000000 --- a/pkg/adapter/logs/alils/alils.go +++ /dev/null @@ -1,5 +0,0 @@ -package alils - -import ( - _ "github.com/astaxie/beego/pkg/core/logs/alils" -) diff --git a/pkg/adapter/logs/es/es.go b/pkg/adapter/logs/es/es.go deleted file mode 100644 index e0759485..00000000 --- a/pkg/adapter/logs/es/es.go +++ /dev/null @@ -1,5 +0,0 @@ -package es - -import ( - _ "github.com/astaxie/beego/pkg/core/logs/es" -) diff --git a/pkg/server/web/LICENSE b/server/web/LICENSE similarity index 100% rename from pkg/server/web/LICENSE rename to server/web/LICENSE diff --git a/pkg/server/web/admin.go b/server/web/admin.go similarity index 98% rename from pkg/server/web/admin.go rename to server/web/admin.go index 4cac58ba..a1c47e0c 100644 --- a/pkg/server/web/admin.go +++ b/server/web/admin.go @@ -20,7 +20,7 @@ import ( "reflect" "time" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" ) // BeeAdminApp is the default adminApp used by admin module. diff --git a/pkg/server/web/admin_controller.go b/server/web/admin_controller.go similarity index 99% rename from pkg/server/web/admin_controller.go rename to server/web/admin_controller.go index 575362d7..2998c8d4 100644 --- a/pkg/server/web/admin_controller.go +++ b/server/web/admin_controller.go @@ -24,7 +24,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/astaxie/beego/pkg/core/governor" + "github.com/astaxie/beego/core/governor" ) type adminController struct { diff --git a/pkg/server/web/admin_test.go b/server/web/admin_test.go similarity index 99% rename from pkg/server/web/admin_test.go rename to server/web/admin_test.go index c33bbf2f..5ef57323 100644 --- a/pkg/server/web/admin_test.go +++ b/server/web/admin_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/core/governor" + "github.com/astaxie/beego/core/governor" ) type SampleDatabaseCheck struct { diff --git a/pkg/server/web/adminui.go b/server/web/adminui.go similarity index 100% rename from pkg/server/web/adminui.go rename to server/web/adminui.go diff --git a/pkg/server/web/beego.go b/server/web/beego.go similarity index 100% rename from pkg/server/web/beego.go rename to server/web/beego.go diff --git a/pkg/server/web/captcha/LICENSE b/server/web/captcha/LICENSE similarity index 100% rename from pkg/server/web/captcha/LICENSE rename to server/web/captcha/LICENSE diff --git a/pkg/server/web/captcha/README.md b/server/web/captcha/README.md similarity index 100% rename from pkg/server/web/captcha/README.md rename to server/web/captcha/README.md diff --git a/pkg/server/web/captcha/captcha.go b/server/web/captcha/captcha.go similarity index 97% rename from pkg/server/web/captcha/captcha.go rename to server/web/captcha/captcha.go index 876e6074..8ce832f7 100644 --- a/pkg/server/web/captcha/captcha.go +++ b/server/web/captcha/captcha.go @@ -67,11 +67,11 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" - "github.com/astaxie/beego/pkg/core/utils" - "github.com/astaxie/beego/pkg/server/web" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/core/utils" + "github.com/astaxie/beego/server/web" + "github.com/astaxie/beego/server/web/context" ) var ( diff --git a/pkg/server/web/captcha/image.go b/server/web/captcha/image.go similarity index 100% rename from pkg/server/web/captcha/image.go rename to server/web/captcha/image.go diff --git a/pkg/server/web/captcha/image_test.go b/server/web/captcha/image_test.go similarity index 96% rename from pkg/server/web/captcha/image_test.go rename to server/web/captcha/image_test.go index a6b82f56..4b4518a1 100644 --- a/pkg/server/web/captcha/image_test.go +++ b/server/web/captcha/image_test.go @@ -17,7 +17,7 @@ package captcha import ( "testing" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) type byteCounter struct { diff --git a/pkg/server/web/captcha/siprng.go b/server/web/captcha/siprng.go similarity index 100% rename from pkg/server/web/captcha/siprng.go rename to server/web/captcha/siprng.go diff --git a/pkg/server/web/captcha/siprng_test.go b/server/web/captcha/siprng_test.go similarity index 100% rename from pkg/server/web/captcha/siprng_test.go rename to server/web/captcha/siprng_test.go diff --git a/pkg/server/web/config.go b/server/web/config.go similarity index 97% rename from pkg/server/web/config.go rename to server/web/config.go index 9e2a2885..10dc9c97 100644 --- a/pkg/server/web/config.go +++ b/server/web/config.go @@ -24,13 +24,13 @@ import ( "runtime" "strings" - "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/core/config" - "github.com/astaxie/beego/pkg/core/logs" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego" + "github.com/astaxie/beego/core/config" + "github.com/astaxie/beego/core/logs" + "github.com/astaxie/beego/server/web/session" - "github.com/astaxie/beego/pkg/core/utils" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/core/utils" + "github.com/astaxie/beego/server/web/context" ) // Config is the main struct for BConfig @@ -210,7 +210,7 @@ func newBConfig() *Config { AppName: "beego", RunMode: PROD, RouterCaseSensitive: true, - ServerName: "beegoServer:" + pkg.VERSION, + ServerName: "beegoServer:" + beego.VERSION, RecoverPanic: true, CopyRequestBody: false, diff --git a/pkg/server/web/config_test.go b/server/web/config_test.go similarity index 98% rename from pkg/server/web/config_test.go rename to server/web/config_test.go index ce4a4492..88fa8b8c 100644 --- a/pkg/server/web/config_test.go +++ b/server/web/config_test.go @@ -19,7 +19,7 @@ import ( "reflect" "testing" - beeJson "github.com/astaxie/beego/pkg/core/config/json" + beeJson "github.com/astaxie/beego/core/config/json" ) func TestDefaults(t *testing.T) { diff --git a/pkg/server/web/context/acceptencoder.go b/server/web/context/acceptencoder.go similarity index 100% rename from pkg/server/web/context/acceptencoder.go rename to server/web/context/acceptencoder.go diff --git a/pkg/server/web/context/acceptencoder_test.go b/server/web/context/acceptencoder_test.go similarity index 100% rename from pkg/server/web/context/acceptencoder_test.go rename to server/web/context/acceptencoder_test.go diff --git a/pkg/server/web/context/context.go b/server/web/context/context.go similarity index 99% rename from pkg/server/web/context/context.go rename to server/web/context/context.go index 1a6c00a8..53ed3d01 100644 --- a/pkg/server/web/context/context.go +++ b/server/web/context/context.go @@ -35,7 +35,7 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) // Commonly used mime-types diff --git a/pkg/server/web/context/context_test.go b/server/web/context/context_test.go similarity index 100% rename from pkg/server/web/context/context_test.go rename to server/web/context/context_test.go diff --git a/pkg/server/web/context/input.go b/server/web/context/input.go similarity index 99% rename from pkg/server/web/context/input.go rename to server/web/context/input.go index 746822aa..641e15cc 100644 --- a/pkg/server/web/context/input.go +++ b/server/web/context/input.go @@ -29,7 +29,7 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" ) // Regexes for checking the accept headers diff --git a/pkg/server/web/context/input_test.go b/server/web/context/input_test.go similarity index 100% rename from pkg/server/web/context/input_test.go rename to server/web/context/input_test.go diff --git a/pkg/server/web/context/output.go b/server/web/context/output.go similarity index 100% rename from pkg/server/web/context/output.go rename to server/web/context/output.go diff --git a/pkg/server/web/context/param/conv.go b/server/web/context/param/conv.go similarity index 95% rename from pkg/server/web/context/param/conv.go rename to server/web/context/param/conv.go index 73e83e30..fe3388b6 100644 --- a/pkg/server/web/context/param/conv.go +++ b/server/web/context/param/conv.go @@ -4,8 +4,8 @@ import ( "fmt" "reflect" - "github.com/astaxie/beego/pkg/core/logs" - beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/core/logs" + beecontext "github.com/astaxie/beego/server/web/context" ) // ConvertParams converts http method params to values that will be passed to the method controller as arguments diff --git a/pkg/server/web/context/param/methodparams.go b/server/web/context/param/methodparams.go similarity index 100% rename from pkg/server/web/context/param/methodparams.go rename to server/web/context/param/methodparams.go diff --git a/pkg/server/web/context/param/options.go b/server/web/context/param/options.go similarity index 100% rename from pkg/server/web/context/param/options.go rename to server/web/context/param/options.go diff --git a/pkg/server/web/context/param/parsers.go b/server/web/context/param/parsers.go similarity index 100% rename from pkg/server/web/context/param/parsers.go rename to server/web/context/param/parsers.go diff --git a/pkg/server/web/context/param/parsers_test.go b/server/web/context/param/parsers_test.go similarity index 100% rename from pkg/server/web/context/param/parsers_test.go rename to server/web/context/param/parsers_test.go diff --git a/pkg/server/web/context/renderer.go b/server/web/context/renderer.go similarity index 100% rename from pkg/server/web/context/renderer.go rename to server/web/context/renderer.go diff --git a/pkg/server/web/context/response.go b/server/web/context/response.go similarity index 99% rename from pkg/server/web/context/response.go rename to server/web/context/response.go index d80cfe89..7bd9a7e8 100644 --- a/pkg/server/web/context/response.go +++ b/server/web/context/response.go @@ -1,9 +1,8 @@ package context import ( - "strconv" - "net/http" + "strconv" ) const ( diff --git a/pkg/server/web/controller.go b/server/web/controller.go similarity index 99% rename from pkg/server/web/controller.go rename to server/web/controller.go index a8e2ae63..3a1b9837 100644 --- a/pkg/server/web/controller.go +++ b/server/web/controller.go @@ -28,10 +28,10 @@ import ( "strconv" "strings" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" - "github.com/astaxie/beego/pkg/server/web/context" - "github.com/astaxie/beego/pkg/server/web/context/param" + "github.com/astaxie/beego/server/web/context" + "github.com/astaxie/beego/server/web/context/param" ) var ( diff --git a/pkg/server/web/controller_test.go b/server/web/controller_test.go similarity index 98% rename from pkg/server/web/controller_test.go rename to server/web/controller_test.go index 46da3629..0b711e0d 100644 --- a/pkg/server/web/controller_test.go +++ b/server/web/controller_test.go @@ -23,7 +23,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) func TestGetInt(t *testing.T) { diff --git a/pkg/server/web/doc.go b/server/web/doc.go similarity index 91% rename from pkg/server/web/doc.go rename to server/web/doc.go index 0ab10bfd..a32bc576 100644 --- a/pkg/server/web/doc.go +++ b/server/web/doc.go @@ -6,7 +6,7 @@ It is used for rapid development of RESTful APIs, web apps and backend services beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding. package main - import "github.com/astaxie/beego/pkg" + import "github.com/astaxie/beego" func main() { beego.Run() diff --git a/pkg/server/web/error.go b/server/web/error.go similarity index 98% rename from pkg/server/web/error.go rename to server/web/error.go index d0a8d778..b5ef1d2d 100644 --- a/pkg/server/web/error.go +++ b/server/web/error.go @@ -23,10 +23,10 @@ import ( "strconv" "strings" - "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego" + "github.com/astaxie/beego/core/utils" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) const ( @@ -92,7 +92,7 @@ func showErr(err interface{}, ctx *context.Context, stack string) { "RequestURL": ctx.Input.URI(), "RemoteAddr": ctx.Input.IP(), "Stack": stack, - "BeegoVersion": pkg.VERSION, + "BeegoVersion": beego.VERSION, "GoVersion": runtime.Version(), } t.Execute(ctx.ResponseWriter, data) @@ -379,7 +379,7 @@ func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errCont t, _ := template.New("beegoerrortemp").Parse(errtpl) data := M{ "Title": http.StatusText(errCode), - "BeegoVersion": pkg.VERSION, + "BeegoVersion": beego.VERSION, "Content": template.HTML(errContent), } t.Execute(rw, data) diff --git a/pkg/server/web/error_test.go b/server/web/error_test.go similarity index 100% rename from pkg/server/web/error_test.go rename to server/web/error_test.go diff --git a/pkg/server/web/filter.go b/server/web/filter.go similarity index 98% rename from pkg/server/web/filter.go rename to server/web/filter.go index 9aab48d6..967de8c9 100644 --- a/pkg/server/web/filter.go +++ b/server/web/filter.go @@ -17,7 +17,7 @@ package web import ( "strings" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) // FilterChain is different from pure FilterFunc diff --git a/pkg/server/web/filter/apiauth/apiauth.go b/server/web/filter/apiauth/apiauth.go similarity index 97% rename from pkg/server/web/filter/apiauth/apiauth.go rename to server/web/filter/apiauth/apiauth.go index 8944db63..58153f1d 100644 --- a/pkg/server/web/filter/apiauth/apiauth.go +++ b/server/web/filter/apiauth/apiauth.go @@ -65,8 +65,8 @@ import ( "sort" "time" - "github.com/astaxie/beego/pkg/server/web" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web" + "github.com/astaxie/beego/server/web/context" ) // AppIDToAppSecret gets appsecret through appid diff --git a/pkg/server/web/filter/apiauth/apiauth_test.go b/server/web/filter/apiauth/apiauth_test.go similarity index 100% rename from pkg/server/web/filter/apiauth/apiauth_test.go rename to server/web/filter/apiauth/apiauth_test.go diff --git a/pkg/server/web/filter/auth/basic.go b/server/web/filter/auth/basic.go similarity index 97% rename from pkg/server/web/filter/auth/basic.go rename to server/web/filter/auth/basic.go index 209cd97d..ee6af6c3 100644 --- a/pkg/server/web/filter/auth/basic.go +++ b/server/web/filter/auth/basic.go @@ -40,8 +40,8 @@ import ( "net/http" "strings" - "github.com/astaxie/beego/pkg/server/web" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web" + "github.com/astaxie/beego/server/web/context" ) var defaultRealm = "Authorization Required" diff --git a/pkg/server/web/filter/authz/authz.go b/server/web/filter/authz/authz.go similarity index 96% rename from pkg/server/web/filter/authz/authz.go rename to server/web/filter/authz/authz.go index a3a8dca6..857c52f2 100644 --- a/pkg/server/web/filter/authz/authz.go +++ b/server/web/filter/authz/authz.go @@ -44,8 +44,8 @@ import ( "github.com/casbin/casbin" - "github.com/astaxie/beego/pkg/server/web" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web" + "github.com/astaxie/beego/server/web/context" ) // NewAuthorizer returns the authorizer. diff --git a/pkg/server/web/filter/authz/authz_model.conf b/server/web/filter/authz/authz_model.conf similarity index 100% rename from pkg/server/web/filter/authz/authz_model.conf rename to server/web/filter/authz/authz_model.conf diff --git a/pkg/server/web/filter/authz/authz_policy.csv b/server/web/filter/authz/authz_policy.csv similarity index 100% rename from pkg/server/web/filter/authz/authz_policy.csv rename to server/web/filter/authz/authz_policy.csv diff --git a/pkg/server/web/filter/authz/authz_test.go b/server/web/filter/authz/authz_test.go similarity index 96% rename from pkg/server/web/filter/authz/authz_test.go rename to server/web/filter/authz/authz_test.go index e50596b4..c0d0dde5 100644 --- a/pkg/server/web/filter/authz/authz_test.go +++ b/server/web/filter/authz/authz_test.go @@ -21,9 +21,9 @@ import ( "github.com/casbin/casbin" - "github.com/astaxie/beego/pkg/server/web" - "github.com/astaxie/beego/pkg/server/web/context" - "github.com/astaxie/beego/pkg/server/web/filter/auth" + "github.com/astaxie/beego/server/web" + "github.com/astaxie/beego/server/web/context" + "github.com/astaxie/beego/server/web/filter/auth" ) func testRequest(t *testing.T, handler *web.ControllerRegister, user string, path string, method string, code int) { diff --git a/pkg/server/web/filter/cors/cors.go b/server/web/filter/cors/cors.go similarity index 98% rename from pkg/server/web/filter/cors/cors.go rename to server/web/filter/cors/cors.go index 800eeded..3a6905ea 100644 --- a/pkg/server/web/filter/cors/cors.go +++ b/server/web/filter/cors/cors.go @@ -42,8 +42,8 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/server/web" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web" + "github.com/astaxie/beego/server/web/context" ) const ( diff --git a/pkg/server/web/filter/cors/cors_test.go b/server/web/filter/cors/cors_test.go similarity index 98% rename from pkg/server/web/filter/cors/cors_test.go rename to server/web/filter/cors/cors_test.go index 60659fdd..7649de25 100644 --- a/pkg/server/web/filter/cors/cors_test.go +++ b/server/web/filter/cors/cors_test.go @@ -21,8 +21,8 @@ import ( "testing" "time" - "github.com/astaxie/beego/pkg/server/web" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web" + "github.com/astaxie/beego/server/web/context" ) // HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header diff --git a/pkg/server/web/filter/opentracing/filter.go b/server/web/filter/opentracing/filter.go similarity index 96% rename from pkg/server/web/filter/opentracing/filter.go rename to server/web/filter/opentracing/filter.go index dd5663f9..c2defa18 100644 --- a/pkg/server/web/filter/opentracing/filter.go +++ b/server/web/filter/opentracing/filter.go @@ -17,12 +17,11 @@ package opentracing import ( "context" + "github.com/astaxie/beego/server/web" + beegoCtx "github.com/astaxie/beego/server/web/context" logKit "github.com/go-kit/kit/log" opentracingKit "github.com/go-kit/kit/tracing/opentracing" "github.com/opentracing/opentracing-go" - - "github.com/astaxie/beego/pkg/server/web" - beegoCtx "github.com/astaxie/beego/pkg/server/web/context" ) // FilterChainBuilder provides an extension point that we can support more configurations if necessary diff --git a/pkg/server/web/filter/opentracing/filter_test.go b/server/web/filter/opentracing/filter_test.go similarity index 96% rename from pkg/server/web/filter/opentracing/filter_test.go rename to server/web/filter/opentracing/filter_test.go index 04f44324..d7222c37 100644 --- a/pkg/server/web/filter/opentracing/filter_test.go +++ b/server/web/filter/opentracing/filter_test.go @@ -22,7 +22,7 @@ import ( "github.com/opentracing/opentracing-go" "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) func TestFilterChainBuilder_FilterChain(t *testing.T) { diff --git a/pkg/server/web/filter/prometheus/filter.go b/server/web/filter/prometheus/filter.go similarity index 84% rename from pkg/server/web/filter/prometheus/filter.go rename to server/web/filter/prometheus/filter.go index eb5b0b78..7daabd5a 100644 --- a/pkg/server/web/filter/prometheus/filter.go +++ b/server/web/filter/prometheus/filter.go @@ -21,9 +21,9 @@ import ( "github.com/prometheus/client_golang/prometheus" - "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/server/web" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego" + "github.com/astaxie/beego/server/web" + "github.com/astaxie/beego/server/web/context" ) // FilterChainBuilder is an extension point, @@ -64,13 +64,13 @@ func registerBuildInfo() { Help: "The building information", ConstLabels: map[string]string{ "appname": web.BConfig.AppName, - "build_version": pkg.BuildVersion, - "build_revision": pkg.BuildGitRevision, - "build_status": pkg.BuildStatus, - "build_tag": pkg.BuildTag, - "build_time": strings.Replace(pkg.BuildTime, "--", " ", 1), - "go_version": pkg.GoVersion, - "git_branch": pkg.GitBranch, + "build_version": beego.BuildVersion, + "build_revision": beego.BuildGitRevision, + "build_status": beego.BuildStatus, + "build_tag": beego.BuildTag, + "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), + "go_version": beego.GoVersion, + "git_branch": beego.GitBranch, "start_time": time.Now().Format("2006-01-02 15:04:05"), }, }, []string{}) diff --git a/pkg/server/web/filter/prometheus/filter_test.go b/server/web/filter/prometheus/filter_test.go similarity index 95% rename from pkg/server/web/filter/prometheus/filter_test.go rename to server/web/filter/prometheus/filter_test.go index 08887839..cb133a64 100644 --- a/pkg/server/web/filter/prometheus/filter_test.go +++ b/server/web/filter/prometheus/filter_test.go @@ -21,7 +21,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) func TestFilterChain(t *testing.T) { diff --git a/pkg/server/web/filter_chain_test.go b/server/web/filter_chain_test.go similarity index 95% rename from pkg/server/web/filter_chain_test.go rename to server/web/filter_chain_test.go index 44d5f71e..e175ab29 100644 --- a/pkg/server/web/filter_chain_test.go +++ b/server/web/filter_chain_test.go @@ -21,7 +21,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) func TestControllerRegister_InsertFilterChain(t *testing.T) { diff --git a/pkg/server/web/filter_test.go b/server/web/filter_test.go similarity index 97% rename from pkg/server/web/filter_test.go rename to server/web/filter_test.go index eea50534..11f575d6 100644 --- a/pkg/server/web/filter_test.go +++ b/server/web/filter_test.go @@ -19,7 +19,7 @@ import ( "net/http/httptest" "testing" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) var FilterUser = func(ctx *context.Context) { diff --git a/pkg/server/web/flash.go b/server/web/flash.go similarity index 100% rename from pkg/server/web/flash.go rename to server/web/flash.go diff --git a/pkg/server/web/flash_test.go b/server/web/flash_test.go similarity index 100% rename from pkg/server/web/flash_test.go rename to server/web/flash_test.go diff --git a/pkg/server/web/fs.go b/server/web/fs.go similarity index 100% rename from pkg/server/web/fs.go rename to server/web/fs.go diff --git a/pkg/server/web/grace/grace.go b/server/web/grace/grace.go similarity index 100% rename from pkg/server/web/grace/grace.go rename to server/web/grace/grace.go diff --git a/pkg/server/web/grace/server.go b/server/web/grace/server.go similarity index 100% rename from pkg/server/web/grace/server.go rename to server/web/grace/server.go diff --git a/pkg/server/web/hooks.go b/server/web/hooks.go similarity index 95% rename from pkg/server/web/hooks.go rename to server/web/hooks.go index d7c6cf16..090e45d3 100644 --- a/pkg/server/web/hooks.go +++ b/server/web/hooks.go @@ -7,9 +7,9 @@ import ( "net/http" "path/filepath" - "github.com/astaxie/beego/pkg/core/logs" - "github.com/astaxie/beego/pkg/server/web/context" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/core/logs" + "github.com/astaxie/beego/server/web/context" + "github.com/astaxie/beego/server/web/session" ) // register MIME type with content type diff --git a/pkg/server/web/mime.go b/server/web/mime.go similarity index 100% rename from pkg/server/web/mime.go rename to server/web/mime.go diff --git a/pkg/server/web/namespace.go b/server/web/namespace.go similarity index 91% rename from pkg/server/web/namespace.go rename to server/web/namespace.go index a792aa60..58afb6c7 100644 --- a/pkg/server/web/namespace.go +++ b/server/web/namespace.go @@ -18,7 +18,7 @@ import ( "net/http" "strings" - beecontext "github.com/astaxie/beego/pkg/server/web/context" + beecontext "github.com/astaxie/beego/server/web/context" ) type namespaceCond func(*beecontext.Context) bool @@ -97,91 +97,91 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { } // Router same as beego.Rourer -// refer: https://godoc.org/github.com/astaxie/beego/pkg#Router +// refer: https://godoc.org/github.com/astaxie/beego#Router func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { n.handlers.Add(rootpath, c, mappingMethods...) return n } // AutoRouter same as beego.AutoRouter -// refer: https://godoc.org/github.com/astaxie/beego/pkg#AutoRouter +// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { n.handlers.AddAuto(c) return n } // AutoPrefix same as beego.AutoPrefix -// refer: https://godoc.org/github.com/astaxie/beego/pkg#AutoPrefix +// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { n.handlers.AddAutoPrefix(prefix, c) return n } // Get same as beego.Get -// refer: https://godoc.org/github.com/astaxie/beego/pkg#Get +// refer: https://godoc.org/github.com/astaxie/beego#Get func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { n.handlers.Get(rootpath, f) return n } // Post same as beego.Post -// refer: https://godoc.org/github.com/astaxie/beego/pkg#Post +// refer: https://godoc.org/github.com/astaxie/beego#Post func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { n.handlers.Post(rootpath, f) return n } // Delete same as beego.Delete -// refer: https://godoc.org/github.com/astaxie/beego/pkg#Delete +// refer: https://godoc.org/github.com/astaxie/beego#Delete func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { n.handlers.Delete(rootpath, f) return n } // Put same as beego.Put -// refer: https://godoc.org/github.com/astaxie/beego/pkg#Put +// refer: https://godoc.org/github.com/astaxie/beego#Put func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { n.handlers.Put(rootpath, f) return n } // Head same as beego.Head -// refer: https://godoc.org/github.com/astaxie/beego/pkg#Head +// refer: https://godoc.org/github.com/astaxie/beego#Head func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { n.handlers.Head(rootpath, f) return n } // Options same as beego.Options -// refer: https://godoc.org/github.com/astaxie/beego/pkg#Options +// refer: https://godoc.org/github.com/astaxie/beego#Options func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { n.handlers.Options(rootpath, f) return n } // Patch same as beego.Patch -// refer: https://godoc.org/github.com/astaxie/beego/pkg#Patch +// refer: https://godoc.org/github.com/astaxie/beego#Patch func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { n.handlers.Patch(rootpath, f) return n } // Any same as beego.Any -// refer: https://godoc.org/github.com/astaxie/beego/pkg#Any +// refer: https://godoc.org/github.com/astaxie/beego#Any func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { n.handlers.Any(rootpath, f) return n } // Handler same as beego.Handler -// refer: https://godoc.org/github.com/astaxie/beego/pkg#Handler +// refer: https://godoc.org/github.com/astaxie/beego#Handler func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { n.handlers.Handler(rootpath, h) return n } // Include add include class -// refer: https://godoc.org/github.com/astaxie/beego/pkg#Include +// refer: https://godoc.org/github.com/astaxie/beego#Include func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { n.handlers.Include(cList...) return n diff --git a/pkg/server/web/namespace_test.go b/server/web/namespace_test.go similarity index 98% rename from pkg/server/web/namespace_test.go rename to server/web/namespace_test.go index 39d60041..a6f87bba 100644 --- a/pkg/server/web/namespace_test.go +++ b/server/web/namespace_test.go @@ -20,7 +20,7 @@ import ( "strconv" "testing" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) func TestNamespaceGet(t *testing.T) { diff --git a/pkg/server/web/pagination/controller.go b/server/web/pagination/controller.go similarity index 90% rename from pkg/server/web/pagination/controller.go rename to server/web/pagination/controller.go index 675437f8..f6b2f73d 100644 --- a/pkg/server/web/pagination/controller.go +++ b/server/web/pagination/controller.go @@ -15,8 +15,8 @@ package pagination import ( - "github.com/astaxie/beego/pkg/core/utils/pagination" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/core/utils/pagination" + "github.com/astaxie/beego/server/web/context" ) // SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). diff --git a/pkg/server/web/parser.go b/server/web/parser.go similarity index 98% rename from pkg/server/web/parser.go rename to server/web/parser.go index 9dfeca56..1a8a33df 100644 --- a/pkg/server/web/parser.go +++ b/server/web/parser.go @@ -31,17 +31,17 @@ import ( "golang.org/x/tools/go/packages" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" - "github.com/astaxie/beego/pkg/core/utils" - "github.com/astaxie/beego/pkg/server/web/context/param" + "github.com/astaxie/beego/core/utils" + "github.com/astaxie/beego/server/web/context/param" ) var globalRouterTemplate = `package {{.routersDir}} import ( - "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/server/web/context/param"{{.globalimport}} + "github.com/astaxie/beego" + "github.com/astaxie/beego/server/web/context/param"{{.globalimport}} ) func init() { diff --git a/pkg/server/web/policy.go b/server/web/policy.go similarity index 98% rename from pkg/server/web/policy.go rename to server/web/policy.go index 2099f99d..14673422 100644 --- a/pkg/server/web/policy.go +++ b/server/web/policy.go @@ -17,7 +17,7 @@ package web import ( "strings" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) // PolicyFunc defines a policy function which is invoked before the controller handler is executed. diff --git a/pkg/server/web/router.go b/server/web/router.go similarity index 99% rename from pkg/server/web/router.go rename to server/web/router.go index 07007a2b..0f383f99 100644 --- a/pkg/server/web/router.go +++ b/server/web/router.go @@ -25,11 +25,11 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" - "github.com/astaxie/beego/pkg/core/utils" - beecontext "github.com/astaxie/beego/pkg/server/web/context" - "github.com/astaxie/beego/pkg/server/web/context/param" + "github.com/astaxie/beego/core/utils" + beecontext "github.com/astaxie/beego/server/web/context" + "github.com/astaxie/beego/server/web/context/param" ) // default filter execution points diff --git a/pkg/server/web/router_test.go b/server/web/router_test.go similarity index 99% rename from pkg/server/web/router_test.go rename to server/web/router_test.go index 2bc7990c..59ccd1fc 100644 --- a/pkg/server/web/router_test.go +++ b/server/web/router_test.go @@ -21,9 +21,9 @@ import ( "strings" "testing" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) type TestController struct { diff --git a/pkg/server/web/server.go b/server/web/server.go similarity index 99% rename from pkg/server/web/server.go rename to server/web/server.go index 75523d0c..f289fd9b 100644 --- a/pkg/server/web/server.go +++ b/server/web/server.go @@ -31,11 +31,11 @@ import ( "golang.org/x/crypto/acme/autocert" - "github.com/astaxie/beego/pkg/core/logs" - beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/core/logs" + beecontext "github.com/astaxie/beego/server/web/context" - "github.com/astaxie/beego/pkg/core/utils" - "github.com/astaxie/beego/pkg/server/web/grace" + "github.com/astaxie/beego/core/utils" + "github.com/astaxie/beego/server/web/grace" ) var ( diff --git a/pkg/server/web/server_test.go b/server/web/server_test.go similarity index 100% rename from pkg/server/web/server_test.go rename to server/web/server_test.go diff --git a/pkg/server/web/session/README.md b/server/web/session/README.md similarity index 100% rename from pkg/server/web/session/README.md rename to server/web/session/README.md diff --git a/pkg/server/web/session/couchbase/sess_couchbase.go b/server/web/session/couchbase/sess_couchbase.go similarity index 99% rename from pkg/server/web/session/couchbase/sess_couchbase.go rename to server/web/session/couchbase/sess_couchbase.go index dc616539..ddd46401 100644 --- a/pkg/server/web/session/couchbase/sess_couchbase.go +++ b/server/web/session/couchbase/sess_couchbase.go @@ -40,7 +40,7 @@ import ( couchbase "github.com/couchbase/go-couchbase" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" ) var couchbpder = &Provider{} diff --git a/pkg/server/web/session/ledis/ledis_session.go b/server/web/session/ledis/ledis_session.go similarity index 98% rename from pkg/server/web/session/ledis/ledis_session.go rename to server/web/session/ledis/ledis_session.go index e4061a39..a920ff7c 100644 --- a/pkg/server/web/session/ledis/ledis_session.go +++ b/server/web/session/ledis/ledis_session.go @@ -11,7 +11,7 @@ import ( "github.com/ledisdb/ledisdb/config" "github.com/ledisdb/ledisdb/ledis" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" ) var ( diff --git a/pkg/server/web/session/memcache/sess_memcache.go b/server/web/session/memcache/sess_memcache.go similarity index 99% rename from pkg/server/web/session/memcache/sess_memcache.go rename to server/web/session/memcache/sess_memcache.go index 731df01e..168116ef 100644 --- a/pkg/server/web/session/memcache/sess_memcache.go +++ b/server/web/session/memcache/sess_memcache.go @@ -38,7 +38,7 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" "github.com/bradfitz/gomemcache/memcache" ) diff --git a/pkg/server/web/session/mysql/sess_mysql.go b/server/web/session/mysql/sess_mysql.go similarity index 99% rename from pkg/server/web/session/mysql/sess_mysql.go rename to server/web/session/mysql/sess_mysql.go index d9b0f6b4..89da361d 100644 --- a/pkg/server/web/session/mysql/sess_mysql.go +++ b/server/web/session/mysql/sess_mysql.go @@ -47,7 +47,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" // import mysql driver _ "github.com/go-sql-driver/mysql" ) diff --git a/pkg/server/web/session/postgres/sess_postgresql.go b/server/web/session/postgres/sess_postgresql.go similarity index 99% rename from pkg/server/web/session/postgres/sess_postgresql.go rename to server/web/session/postgres/sess_postgresql.go index c5f8f3fa..a83ac083 100644 --- a/pkg/server/web/session/postgres/sess_postgresql.go +++ b/server/web/session/postgres/sess_postgresql.go @@ -57,7 +57,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" // import postgresql Driver _ "github.com/lib/pq" ) diff --git a/pkg/server/web/session/redis/sess_redis.go b/server/web/session/redis/sess_redis.go similarity index 99% rename from pkg/server/web/session/redis/sess_redis.go rename to server/web/session/redis/sess_redis.go index 3b25ae65..6ee28e2f 100644 --- a/pkg/server/web/session/redis/sess_redis.go +++ b/server/web/session/redis/sess_redis.go @@ -40,9 +40,9 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/server/web/session" - "github.com/go-redis/redis/v7" + + "github.com/astaxie/beego/server/web/session" ) var redispder = &Provider{} diff --git a/pkg/server/web/session/redis/sess_redis_test.go b/server/web/session/redis/sess_redis_test.go similarity index 97% rename from pkg/server/web/session/redis/sess_redis_test.go rename to server/web/session/redis/sess_redis_test.go index 7e63361a..19c8c025 100644 --- a/pkg/server/web/session/redis/sess_redis_test.go +++ b/server/web/session/redis/sess_redis_test.go @@ -7,7 +7,7 @@ import ( "os" "testing" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" ) func TestRedis(t *testing.T) { diff --git a/pkg/server/web/session/redis_cluster/redis_cluster.go b/server/web/session/redis_cluster/redis_cluster.go similarity index 99% rename from pkg/server/web/session/redis_cluster/redis_cluster.go rename to server/web/session/redis_cluster/redis_cluster.go index 635ba915..17653d56 100644 --- a/pkg/server/web/session/redis_cluster/redis_cluster.go +++ b/server/web/session/redis_cluster/redis_cluster.go @@ -40,7 +40,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" rediss "github.com/go-redis/redis/v7" ) diff --git a/pkg/server/web/session/redis_sentinel/sess_redis_sentinel.go b/server/web/session/redis_sentinel/sess_redis_sentinel.go similarity index 99% rename from pkg/server/web/session/redis_sentinel/sess_redis_sentinel.go rename to server/web/session/redis_sentinel/sess_redis_sentinel.go index 4b21242d..d68b8767 100644 --- a/pkg/server/web/session/redis_sentinel/sess_redis_sentinel.go +++ b/server/web/session/redis_sentinel/sess_redis_sentinel.go @@ -40,8 +40,9 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/server/web/session" "github.com/go-redis/redis/v7" + + "github.com/astaxie/beego/server/web/session" ) var redispder = &Provider{} diff --git a/pkg/server/web/session/redis_sentinel/sess_redis_sentinel_test.go b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go similarity index 97% rename from pkg/server/web/session/redis_sentinel/sess_redis_sentinel_test.go rename to server/web/session/redis_sentinel/sess_redis_sentinel_test.go index 8d565bf9..e4822d11 100644 --- a/pkg/server/web/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" ) func TestRedisSentinel(t *testing.T) { diff --git a/pkg/server/web/session/sess_cookie.go b/server/web/session/sess_cookie.go similarity index 100% rename from pkg/server/web/session/sess_cookie.go rename to server/web/session/sess_cookie.go diff --git a/pkg/server/web/session/sess_cookie_test.go b/server/web/session/sess_cookie_test.go similarity index 100% rename from pkg/server/web/session/sess_cookie_test.go rename to server/web/session/sess_cookie_test.go diff --git a/pkg/server/web/session/sess_file.go b/server/web/session/sess_file.go similarity index 100% rename from pkg/server/web/session/sess_file.go rename to server/web/session/sess_file.go diff --git a/pkg/server/web/session/sess_file_test.go b/server/web/session/sess_file_test.go similarity index 100% rename from pkg/server/web/session/sess_file_test.go rename to server/web/session/sess_file_test.go diff --git a/pkg/server/web/session/sess_mem.go b/server/web/session/sess_mem.go similarity index 100% rename from pkg/server/web/session/sess_mem.go rename to server/web/session/sess_mem.go diff --git a/pkg/server/web/session/sess_mem_test.go b/server/web/session/sess_mem_test.go similarity index 100% rename from pkg/server/web/session/sess_mem_test.go rename to server/web/session/sess_mem_test.go diff --git a/pkg/server/web/session/sess_test.go b/server/web/session/sess_test.go similarity index 100% rename from pkg/server/web/session/sess_test.go rename to server/web/session/sess_test.go diff --git a/pkg/server/web/session/sess_utils.go b/server/web/session/sess_utils.go similarity index 99% rename from pkg/server/web/session/sess_utils.go rename to server/web/session/sess_utils.go index 5f97d1a4..8a031dd5 100644 --- a/pkg/server/web/session/sess_utils.go +++ b/server/web/session/sess_utils.go @@ -29,7 +29,7 @@ import ( "strconv" "time" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) func init() { diff --git a/pkg/server/web/session/session.go b/server/web/session/session.go similarity index 100% rename from pkg/server/web/session/session.go rename to server/web/session/session.go diff --git a/pkg/server/web/session/ssdb/sess_ssdb.go b/server/web/session/ssdb/sess_ssdb.go similarity index 98% rename from pkg/server/web/session/ssdb/sess_ssdb.go rename to server/web/session/ssdb/sess_ssdb.go index d15f2171..9d1230d0 100644 --- a/pkg/server/web/session/ssdb/sess_ssdb.go +++ b/server/web/session/ssdb/sess_ssdb.go @@ -10,7 +10,7 @@ import ( "github.com/ssdb/gossdb/ssdb" - "github.com/astaxie/beego/pkg/server/web/session" + "github.com/astaxie/beego/server/web/session" ) var ssdbProvider = &Provider{} diff --git a/pkg/server/web/staticfile.go b/server/web/staticfile.go similarity index 98% rename from pkg/server/web/staticfile.go rename to server/web/staticfile.go index 4aabbc60..aa3f35d8 100644 --- a/pkg/server/web/staticfile.go +++ b/server/web/staticfile.go @@ -26,10 +26,10 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/core/logs" + "github.com/astaxie/beego/core/logs" lru "github.com/hashicorp/golang-lru" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) var errNotStaticRequest = errors.New("request not a static file request") diff --git a/pkg/server/web/staticfile_test.go b/server/web/staticfile_test.go similarity index 100% rename from pkg/server/web/staticfile_test.go rename to server/web/staticfile_test.go diff --git a/pkg/server/web/statistics.go b/server/web/statistics.go similarity index 99% rename from pkg/server/web/statistics.go rename to server/web/statistics.go index 7d5d5800..98f85e96 100644 --- a/pkg/server/web/statistics.go +++ b/server/web/statistics.go @@ -19,7 +19,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" ) // Statistics struct diff --git a/pkg/server/web/statistics_test.go b/server/web/statistics_test.go similarity index 100% rename from pkg/server/web/statistics_test.go rename to server/web/statistics_test.go diff --git a/pkg/server/web/swagger/swagger.go b/server/web/swagger/swagger.go similarity index 100% rename from pkg/server/web/swagger/swagger.go rename to server/web/swagger/swagger.go diff --git a/pkg/server/web/template.go b/server/web/template.go similarity index 99% rename from pkg/server/web/template.go rename to server/web/template.go index 4a07b0da..d582dcda 100644 --- a/pkg/server/web/template.go +++ b/server/web/template.go @@ -27,8 +27,8 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/core/logs" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/logs" + "github.com/astaxie/beego/core/utils" ) var ( diff --git a/pkg/server/web/template_test.go b/server/web/template_test.go similarity index 100% rename from pkg/server/web/template_test.go rename to server/web/template_test.go diff --git a/pkg/server/web/templatefunc.go b/server/web/templatefunc.go similarity index 100% rename from pkg/server/web/templatefunc.go rename to server/web/templatefunc.go diff --git a/pkg/server/web/templatefunc_test.go b/server/web/templatefunc_test.go similarity index 100% rename from pkg/server/web/templatefunc_test.go rename to server/web/templatefunc_test.go diff --git a/pkg/server/web/tree.go b/server/web/tree.go similarity index 99% rename from pkg/server/web/tree.go rename to server/web/tree.go index 9038c010..fc5a11a2 100644 --- a/pkg/server/web/tree.go +++ b/server/web/tree.go @@ -19,9 +19,9 @@ import ( "regexp" "strings" - "github.com/astaxie/beego/pkg/core/utils" + "github.com/astaxie/beego/core/utils" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) var ( diff --git a/pkg/server/web/tree_test.go b/server/web/tree_test.go similarity index 99% rename from pkg/server/web/tree_test.go rename to server/web/tree_test.go index d3091de0..e72bc1f9 100644 --- a/pkg/server/web/tree_test.go +++ b/server/web/tree_test.go @@ -18,7 +18,7 @@ import ( "strings" "testing" - "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/server/web/context" ) type testinfo struct { diff --git a/pkg/server/web/unregroute_test.go b/server/web/unregroute_test.go similarity index 100% rename from pkg/server/web/unregroute_test.go rename to server/web/unregroute_test.go diff --git a/pkg/task/govenor_command.go b/task/govenor_command.go similarity index 97% rename from pkg/task/govenor_command.go rename to task/govenor_command.go index a9583970..15e25e43 100644 --- a/pkg/task/govenor_command.go +++ b/task/govenor_command.go @@ -21,7 +21,7 @@ import ( "github.com/pkg/errors" - "github.com/astaxie/beego/pkg/core/governor" + "github.com/astaxie/beego/core/governor" ) type listTaskCommand struct { diff --git a/pkg/task/governor_command_test.go b/task/governor_command_test.go similarity index 100% rename from pkg/task/governor_command_test.go rename to task/governor_command_test.go diff --git a/pkg/task/task.go b/task/task.go similarity index 100% rename from pkg/task/task.go rename to task/task.go diff --git a/pkg/task/task_test.go b/task/task_test.go similarity index 100% rename from pkg/task/task_test.go rename to task/task_test.go From d41abdb5e4ef7bc789c74dad58cebe315811ca09 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Thu, 8 Oct 2020 23:17:38 +0800 Subject: [PATCH 193/207] Remove scripts directory; update readme --- CONTRIBUTING.md | 13 +- README.md | 209 ++++++++++++++++++++++++++++++- build_info.go | 2 +- core/config/config.go | 3 - scripts/adapter.sh | 6 - scripts/prepare_etcd.sh | 8 -- scripts/test.sh | 14 --- scripts/test_docker_compose.yaml | 55 -------- 8 files changed, 214 insertions(+), 96 deletions(-) delete mode 100644 scripts/adapter.sh delete mode 100644 scripts/prepare_etcd.sh delete mode 100644 scripts/test.sh delete mode 100644 scripts/test_docker_compose.yaml diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5035ae94..cb279cbb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -17,11 +17,14 @@ go get -u golang.org/x/tools/cmd/goimports go get -u github.com/gordonklaus/ineffassign ``` -And the go into project directory, run : +Put those lines into your pre-commit githook script: ```shell script -cp ./githook/pre-commit ./.git/hooks/pre-commit +goimports -w -format-only ./ + +ineffassign . + +staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" ./ ``` -This will add git hooks into .git/hooks. Or you can add it manually. ## Prepare middleware @@ -33,7 +36,7 @@ You can run: ```shell script docker-compose -f scripts/test_docker_compose.yml up -d ``` -Unit tests read addressed from environment, here is an example: +Unit tests read addresses from environment, here is an example: ```shell script export ORM_DRIVER=mysql export ORM_SOURCE="beego:test@tcp(192.168.0.105:13306)/orm_test?charset=utf8" @@ -86,5 +89,5 @@ documenting your bug report or improvement proposal. If it does, it never hurts to add a quick "+1" or "I have this problem too". This will help prioritize the most common problems and requests. -Also if you don't know how to use it. please make sure you have read though +Also, if you don't know how to use it. please make sure you have read through the docs in http://beego.me/docs \ No newline at end of file diff --git a/README.md b/README.md index de8f0063..934fc429 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,12 @@ It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific feature ## Quick Start +###### Please see [Documentation](http://beego.me/docs) for more. + +###### [beego-example](https://github.com/beego-dev/beego-example) + +### Web Application + #### Create `hello` directory, cd `hello` directory mkdir hello @@ -25,10 +31,10 @@ It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific feature ```go package main -import "github.com/astaxie/beego" +import "github.com/astaxie/beego/server/web" func main(){ - beego.Run() + web.Run() } ``` #### Build and run @@ -40,9 +46,204 @@ func main(){ Congratulations! You've just built your first **beego** app. -###### Please see [Documentation](http://beego.me/docs) for more. +### Using ORM module + +```go + +package main + +import ( + "github.com/astaxie/beego/client/orm" + "github.com/astaxie/beego/core/logs" + _ "github.com/go-sql-driver/mysql" +) + +// User - +type User struct { + ID int `orm:"column(id)"` + Name string `orm:"column(name)"` +} + +func init() { + // need to register models in init + orm.RegisterModel(new(User)) + + // need to register db driver + orm.RegisterDriver("mysql", orm.DRMySQL) + + // need to register default database + orm.RegisterDataBase("default", "mysql", "beego:test@tcp(192.168.0.105:13306)/orm_test?charset=utf8") +} + +func main() { + // automatically build table + orm.RunSyncdb("default", false, true) + + // create orm object, and it will use `default` database + o := orm.NewOrm() + + // data + user := new(User) + user.Name = "mike" + + // insert data + id, err := o.Insert(user) + if err != nil { + logs.Info(err) + } + + // ... +} +``` + +### Using httplib as http client +```go +package main + +import ( + "github.com/astaxie/beego/client/httplib" + "github.com/astaxie/beego/core/logs" +) + +func main() { + // Get, more methods please read docs + req := httplib.Get("http://beego.me/") + str, err := req.String() + if err != nil { + logs.Error(err) + } + logs.Info(str) +} + +``` + +### Using config module + +```go +package main + +import ( + "context" + + "github.com/astaxie/beego/core/config" + "github.com/astaxie/beego/core/logs" +) + +var ( + ConfigFile = "./app.conf" +) + +func main() { + cfg, err := config.NewConfig("ini", ConfigFile) + if err != nil { + logs.Critical("An error occurred:", err) + panic(err) + } + res, _ := cfg.String(context.Background(), "name") + logs.Info("load config name is", res) +} +``` +### Using logs module +```go +package main + +import ( + "github.com/astaxie/beego/core/logs" +) + +func main() { + err := logs.SetLogger(logs.AdapterFile, `{"filename":"project.log","level":7,"maxlines":0,"maxsize":0,"daily":true,"maxdays":10,"color":true}`) + if err != nil { + panic(err) + } + logs.Info("hello beego") +} +``` +### Using timed task + +```go +package main + +import ( + "context" + "time" + + "github.com/astaxie/beego/core/logs" + "github.com/astaxie/beego/task" +) + +func main() { + // create a task + tk1 := task.NewTask("tk1", "0/3 * * * * *", func(ctx context.Context) error { logs.Info("tk1"); return nil }) + + // check task + err := tk1.Run(context.Background()) + if err != nil { + logs.Error(err) + } + + // add task to global todolist + task.AddTask("tk1", tk1) + + // start tasks + task.StartTask() + + // wait 12 second + time.Sleep(12 * time.Second) + defer task.StopTask() +} +``` + +### Using cache module + +```go +package main + +import ( + "context" + "time" + + "github.com/astaxie/beego/client/cache" + + // don't forget this + _ "github.com/astaxie/beego/client/cache/redis" + + "github.com/astaxie/beego/core/logs" +) + +func main() { + // create cache + bm, err := cache.NewCache("redis", `{"key":"default", "conn":":6379", "password":"123456", "dbNum":"0"}`) + if err != nil { + logs.Error(err) + } + + // put + isPut := bm.Put(context.Background(), "astaxie", 1, time.Second*10) + logs.Info(isPut) + + isPut = bm.Put(context.Background(), "hello", "world", time.Second*10) + logs.Info(isPut) + + // get + result, _ := bm.Get(context.Background(),"astaxie") + logs.Info(string(result.([]byte))) + + multiResult, _ := bm.GetMulti(context.Background(), []string{"astaxie", "hello"}) + for i := range multiResult { + logs.Info(string(multiResult[i].([]byte))) + } + + // isExist + isExist, _ := bm.IsExist(context.Background(), "astaxie") + logs.Info(isExist) + + // delete + isDelete := bm.Delete(context.Background(), "astaxie") + logs.Info(isDelete) +} +``` -###### [beego-example](https://github.com/beego-dev/beego-example) ## Features diff --git a/build_info.go b/build_info.go index 33287090..42f42c28 100644 --- a/build_info.go +++ b/build_info.go @@ -28,5 +28,5 @@ var ( const ( // VERSION represent beego web framework version. - VERSION = "1.12.2" + VERSION = "2.0.0-alpha" ) diff --git a/core/config/config.go b/core/config/config.go index 0891e571..deac8e3e 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -184,17 +184,14 @@ func (c *BaseConfiger) Strings(ctx context.Context, key string) ([]string, error return strings.Split(res, ";"), nil } -// TODO remove this before release v2.0.0 func (c *BaseConfiger) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...DecodeOption) error { return errors.New("unsupported operation") } -// TODO remove this before release v2.0.0 func (c *BaseConfiger) Sub(ctx context.Context, key string) (Configer, error) { return nil, errors.New("unsupported operation") } -// TODO remove this before release v2.0.0 func (c *BaseConfiger) OnChange(ctx context.Context, key string, fn func(value string)) { // do nothing } diff --git a/scripts/adapter.sh b/scripts/adapter.sh deleted file mode 100644 index ce2d319a..00000000 --- a/scripts/adapter.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/sh - -# using pkg/adapter. Usually you want to migrate to V2 smoothly, you could running this script - -find ./ -name '*.go' -type f -exec sed -i '' -e 's/github.com\/astaxie\/beego/github.com\/astaxie\/beego\/pkg\/adapter/g' {} \; -find ./ -name '*.go' -type f -exec sed -i '' -e 's/"github.com\/astaxie\/beego\/pkg\/adapter"/beego "github.com\/astaxie\/beego\/pkg\/adapter"/g' {} \; diff --git a/scripts/prepare_etcd.sh b/scripts/prepare_etcd.sh deleted file mode 100644 index d34c05a3..00000000 --- a/scripts/prepare_etcd.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -etcdctl put current.float 1.23 -etcdctl put current.bool true -etcdctl put current.int 11 -etcdctl put current.string hello -etcdctl put current.serialize.name test -etcdctl put sub.sub.key1 sub.sub.key \ No newline at end of file diff --git a/scripts/test.sh b/scripts/test.sh deleted file mode 100644 index 473a7066..00000000 --- a/scripts/test.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -docker-compose -f "$(pwd)/scripts/test_docker_compose.yaml" up -d - -export ORM_DRIVER=mysql -export TZ=UTC -export ORM_SOURCE="beego:test@tcp(localhost:13306)/orm_test?charset=utf8" - -go test "$(pwd)/..." - -# clear all container -docker-compose -f "$(pwd)/scripts/test_docker_compose.yaml" down - - diff --git a/scripts/test_docker_compose.yaml b/scripts/test_docker_compose.yaml deleted file mode 100644 index f22b6deb..00000000 --- a/scripts/test_docker_compose.yaml +++ /dev/null @@ -1,55 +0,0 @@ -version: "3.8" -services: - redis: - container_name: "beego-redis" - image: redis - environment: - - ALLOW_EMPTY_PASSWORD=yes - ports: - - "6379:6379" - - mysql: - container_name: "beego-mysql" - image: mysql:5.7.30 - ports: - - "13306:3306" - environment: - - MYSQL_ROOT_PASSWORD=1q2w3e - - MYSQL_DATABASE=orm_test - - MYSQL_USER=beego - - MYSQL_PASSWORD=test - - postgresql: - container_name: "beego-postgresql" - image: bitnami/postgresql:latest - ports: - - "5432:5432" - environment: - - ALLOW_EMPTY_PASSWORD=yes - ssdb: - container_name: "beego-ssdb" - image: wendal/ssdb - ports: - - "8888:8888" - memcache: - container_name: "beego-memcache" - image: memcached - ports: - - "11211:11211" - etcd: - command: > - sh -c " - etcdctl put current.float 1.23 - && etcdctl put current.bool true - && etcdctl put current.int 11 - && etcdctl put current.string hello - && etcdctl put current.serialize.name test - " - container_name: "beego-etcd" - environment: - - ALLOW_NONE_AUTHENTICATION=yes -# - ETCD_ADVERTISE_CLIENT_URLS=http://etcd:2379 - image: bitnami/etcd - ports: - - "2379:2379" - - "2380:2380" \ No newline at end of file From 34d6a733e9be883779fd0a5d70784a3bce2eafe9 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 11 Oct 2020 23:22:19 +0800 Subject: [PATCH 194/207] Support toml config --- core/config/config.go | 2 + core/config/error.go | 25 +++ core/config/toml/toml.go | 358 ++++++++++++++++++++++++++++++++ core/config/toml/toml_test.go | 380 ++++++++++++++++++++++++++++++++++ go.mod | 2 +- 5 files changed, 766 insertions(+), 1 deletion(-) create mode 100644 core/config/error.go create mode 100644 core/config/toml/toml.go create mode 100644 core/config/toml/toml_test.go diff --git a/core/config/config.go b/core/config/config.go index deac8e3e..cfbe5724 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -72,6 +72,8 @@ type Configer interface { DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 DefaultBool(ctx context.Context, key string, defaultVal bool) bool DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 + + // DIY return the original value DIY(ctx context.Context, key string) (interface{}, error) GetSection(ctx context.Context, section string) (map[string]string, error) diff --git a/core/config/error.go b/core/config/error.go new file mode 100644 index 00000000..e4636c45 --- /dev/null +++ b/core/config/error.go @@ -0,0 +1,25 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "github.com/pkg/errors" +) + +// now not all implementation return those error codes +var ( + KeyNotFoundError = errors.New("the key is not found") + InvalidValueTypeError = errors.New("the value is not expected type") +) diff --git a/core/config/toml/toml.go b/core/config/toml/toml.go new file mode 100644 index 00000000..47ea6a25 --- /dev/null +++ b/core/config/toml/toml.go @@ -0,0 +1,358 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toml + +import ( + "context" + "io/ioutil" + "os" + "strings" + + "github.com/pelletier/go-toml" + + "github.com/astaxie/beego/core/config" +) + +const keySeparator = "." + +type Config struct { + tree *toml.Tree +} + +// Parse accepts filename as the parameter +func (c *Config) Parse(filename string) (config.Configer, error) { + ctx, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + return c.ParseData(ctx) +} + +func (c *Config) ParseData(data []byte) (config.Configer, error) { + t, err := toml.LoadBytes(data) + if err != nil { + return nil, err + } + return &configContainer{ + t: t, + }, nil + +} + +// configContainer support key looks like "a.b.c" +type configContainer struct { + t *toml.Tree +} + +// Set put key, val +func (c *configContainer) Set(ctx context.Context, key, val string) error { + path := strings.Split(key, keySeparator) + sub, err := subTree(c.t, path[0:len(path)-1]) + if err != nil { + return err + } + sub.Set(path[len(path)-1], val) + return nil +} + +// String return the value. +// return error if key not found or value is invalid type +func (c *configContainer) String(ctx context.Context, key string) (string, error) { + res, err := c.get(key) + + if err != nil { + return "", err + } + + if res == nil { + return "", config.KeyNotFoundError + } + + if str, ok := res.(string); ok { + return str, nil + } else { + return "", config.InvalidValueTypeError + } +} + +// Strings return []string +// return error if key not found or value is invalid type +func (c *configContainer) Strings(ctx context.Context, key string) ([]string, error) { + val, err := c.get(key) + + if err != nil { + return []string{}, err + } + if val == nil { + return []string{}, config.KeyNotFoundError + } + if arr, ok := val.([]interface{}); ok { + res := make([]string, 0, len(arr)) + for _, ele := range arr { + if str, ok := ele.(string); ok { + res = append(res, str) + } else { + return []string{}, config.InvalidValueTypeError + } + } + return res, nil + } else { + return []string{}, config.InvalidValueTypeError + } +} + +// Int return int value +// return error if key not found or value is invalid type +func (c *configContainer) Int(ctx context.Context, key string) (int, error) { + val, err := c.Int64(ctx, key) + return int(val), err +} + +// Int64 return int64 value +// return error if key not found or value is invalid type +func (c *configContainer) Int64(ctx context.Context, key string) (int64, error) { + res, err := c.get(key) + if err != nil { + return 0, err + } + if res == nil { + return 0, config.KeyNotFoundError + } + if i, ok := res.(int); ok { + return int64(i), nil + } else if i64, ok := res.(int64); ok { + return i64, nil + } else { + return 0, config.InvalidValueTypeError + } +} + +// bool return bool value +// return error if key not found or value is invalid type +func (c *configContainer) Bool(ctx context.Context, key string) (bool, error) { + + res, err := c.get(key) + + if err != nil { + return false, err + } + + if res == nil { + return false, config.KeyNotFoundError + } + if b, ok := res.(bool); ok { + return b, nil + } else { + return false, config.InvalidValueTypeError + } +} + +// Float return float value +// return error if key not found or value is invalid type +func (c *configContainer) Float(ctx context.Context, key string) (float64, error) { + res, err := c.get(key) + if err != nil { + return 0, err + } + + if res == nil { + return 0, config.KeyNotFoundError + } + + if f, ok := res.(float64); ok { + return f, nil + } else { + return 0, config.InvalidValueTypeError + } +} + +// DefaultString return string value +// return default value if key not found or value is invalid type +func (c *configContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { + res, err := c.get(key) + if err != nil { + return defaultVal + } + if str, ok := res.(string); ok { + return str + } else { + return defaultVal + } +} + +// DefaultStrings return []string +// return default value if key not found or value is invalid type +func (c *configContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + val, err := c.get(key) + if err != nil { + return defaultVal + } + if arr, ok := val.([]interface{}); ok { + res := make([]string, 0, len(arr)) + for _, ele := range arr { + if str, ok := ele.(string); ok { + res = append(res, str) + } else { + return defaultVal + } + } + return res + } else { + return defaultVal + } +} + +// DefaultInt return int value +// return default value if key not found or value is invalid type +func (c *configContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { + return int(c.DefaultInt64(ctx, key, int64(defaultVal))) +} + +// DefaultInt64 return int64 value +// return default value if key not found or value is invalid type +func (c *configContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + res, err := c.get(key) + if err != nil { + return defaultVal + } + if i, ok := res.(int); ok { + return int64(i) + } else if i64, ok := res.(int64); ok { + return i64 + } else { + return defaultVal + } +} + +// DefaultBool return bool value +// return default value if key not found or value is invalid type +func (c *configContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + res, err := c.get(key) + if err != nil { + return defaultVal + } + if b, ok := res.(bool); ok { + return b + } else { + return defaultVal + } +} + +// DefaultFloat return float value +// return default value if key not found or value is invalid type +func (c *configContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + res, err := c.get(key) + if err != nil { + return defaultVal + } + if f, ok := res.(float64); ok { + return f + } else { + return defaultVal + } +} + +// DIY returns the original value +func (c *configContainer) DIY(ctx context.Context, key string) (interface{}, error) { + return c.get(key) +} + +// GetSection return error if the value is not valid toml doc +func (c *configContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { + val, err := subTree(c.t, strings.Split(section, keySeparator)) + if err != nil { + return map[string]string{}, err + } + m := val.ToMap() + res := make(map[string]string, len(m)) + for k, v := range m { + res[k] = config.ToString(v) + } + return res, nil +} + +func (c *configContainer) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { + if len(prefix) > 0 { + t, err := subTree(c.t, strings.Split(prefix, keySeparator)) + if err != nil { + return err + } + return t.Unmarshal(obj) + } + return c.t.Unmarshal(obj) +} + +// Sub return sub configer +// return error if key not found or the value is not a sub doc +func (c *configContainer) Sub(ctx context.Context, key string) (config.Configer, error) { + val, err := subTree(c.t, strings.Split(key, keySeparator)) + if err != nil { + return nil, err + } + return &configContainer{ + t: val, + }, nil +} + +// OnChange do nothing +func (c *configContainer) OnChange(ctx context.Context, key string, fn func(value string)) { + // do nothing +} + +// SaveConfigFile create or override the file +func (c *configContainer) SaveConfigFile(ctx context.Context, filename string) error { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + + _, err = c.t.WriteTo(f) + return err +} + +func (c *configContainer) get(key string) (interface{}, error) { + if len(key) == 0 { + return nil, config.KeyNotFoundError + } + + segs := strings.Split(key, keySeparator) + t, err := subTree(c.t, segs[0:len(segs)-1]) + + if err != nil { + return nil, err + } + return t.Get(segs[len(segs)-1]), nil +} + +func subTree(t *toml.Tree, path []string) (*toml.Tree, error) { + res := t + for i := 0; i < len(path); i++ { + if subTree, ok := res.Get(path[i]).(*toml.Tree); ok { + res = subTree + } else { + return nil, config.InvalidValueTypeError + } + } + if res == nil { + return nil, config.KeyNotFoundError + } + return res, nil +} + +func init() { + config.Register("toml", &Config{}) +} diff --git a/core/config/toml/toml_test.go b/core/config/toml/toml_test.go new file mode 100644 index 00000000..2af15596 --- /dev/null +++ b/core/config/toml/toml_test.go @@ -0,0 +1,380 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toml + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/core/config" +) + +func TestConfig_Parse(t *testing.T) { + // file not found + cfg := &Config{} + _, err := cfg.Parse("invalid_file_name.txt") + assert.NotNil(t, err) +} + +func TestConfig_ParseData(t *testing.T) { + data := ` +name="Tom" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) +} + +func TestConfigContainer_Bool(t *testing.T) { + data := ` +Man=true +Woman="true" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val, err := c.Bool(context.Background(), "Man") + assert.Nil(t, err) + assert.True(t, val) + + _, err = c.Bool(context.Background(), "Woman") + assert.NotNil(t, err) + assert.Equal(t, config.InvalidValueTypeError, err) +} + +func TestConfigContainer_DefaultBool(t *testing.T) { + data := ` +Man=true +Woman="false" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val := c.DefaultBool(context.Background(), "Man11", true) + assert.True(t, val) + + val = c.DefaultBool(context.Background(), "Man", false) + assert.True(t, val) + + val = c.DefaultBool(context.Background(), "Woman", true) + assert.True(t, val) +} + +func TestConfigContainer_DefaultFloat(t *testing.T) { + data := ` +Price=12.3 +PriceInvalid="12.3" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val := c.DefaultFloat(context.Background(), "Price", 11.2) + assert.Equal(t, 12.3, val) + + val = c.DefaultFloat(context.Background(), "Price11", 11.2) + assert.Equal(t, 11.2, val) + + val = c.DefaultFloat(context.Background(), "PriceInvalid", 11.2) + assert.Equal(t, 11.2, val) +} + +func TestConfigContainer_DefaultInt(t *testing.T) { + data := ` +Age=12 +AgeInvalid="13" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val := c.DefaultInt(context.Background(), "Age", 11) + assert.Equal(t, 12, val) + + val = c.DefaultInt(context.Background(), "Price11", 11) + assert.Equal(t, 11, val) + + val = c.DefaultInt(context.Background(), "PriceInvalid", 11) + assert.Equal(t, 11, val) +} + +func TestConfigContainer_DefaultString(t *testing.T) { + data := ` +Name="Tom" +NameInvalid=13 +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val := c.DefaultString(context.Background(), "Name", "Jerry") + assert.Equal(t, "Tom", val) + + val = c.DefaultString(context.Background(), "Name11", "Jerry") + assert.Equal(t, "Jerry", val) + + val = c.DefaultString(context.Background(), "NameInvalid", "Jerry") + assert.Equal(t, "Jerry", val) +} + +func TestConfigContainer_DefaultStrings(t *testing.T) { + data := ` +Name=["Tom", "Jerry"] +NameInvalid="Tom" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val := c.DefaultStrings(context.Background(), "Name", []string{"Jerry"}) + assert.Equal(t, []string{"Tom", "Jerry"}, val) + + val = c.DefaultStrings(context.Background(), "Name11", []string{"Jerry"}) + assert.Equal(t, []string{"Jerry"}, val) + + val = c.DefaultStrings(context.Background(), "NameInvalid", []string{"Jerry"}) + assert.Equal(t, []string{"Jerry"}, val) +} + +func TestConfigContainer_DIY(t *testing.T) { + data := ` +Name=["Tom", "Jerry"] +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + _, err = c.DIY(context.Background(), "Name") + assert.Nil(t, err) +} + +func TestConfigContainer_Float(t *testing.T) { + data := ` +Price=12.3 +PriceInvalid="12.3" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val, err := c.Float(context.Background(), "Price") + assert.Nil(t, err) + assert.Equal(t, 12.3, val) + + _, err = c.Float(context.Background(), "Price11") + assert.Equal(t, config.KeyNotFoundError, err) + + _, err = c.Float(context.Background(), "PriceInvalid") + assert.Equal(t, config.InvalidValueTypeError, err) +} + +func TestConfigContainer_Int(t *testing.T) { + data := ` +Age=12 +AgeInvalid="13" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val, err := c.Int(context.Background(), "Age") + assert.Nil(t, err) + assert.Equal(t, 12, val) + + _, err = c.Int(context.Background(), "Age11") + assert.Equal(t, config.KeyNotFoundError, err) + + _, err = c.Int(context.Background(), "AgeInvalid") + assert.Equal(t, config.InvalidValueTypeError, err) +} + +func TestConfigContainer_GetSection(t *testing.T) { + data := ` +[servers] + + # You can indent as you please. Tabs or spaces. TOML don't care. + [servers.alpha] + ip = "10.0.0.1" + dc = "eqdc10" + + [servers.beta] + ip = "10.0.0.2" + dc = "eqdc10" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + m, err := c.GetSection(context.Background(), "servers") + assert.Nil(t, err) + assert.NotNil(t, m) + assert.Equal(t, 2, len(m)) +} + +func TestConfigContainer_String(t *testing.T) { + data := ` +Name="Tom" +NameInvalid=13 +[Person] +Name="Jerry" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val, err := c.String(context.Background(), "Name") + assert.Nil(t, err) + assert.Equal(t, "Tom", val) + + _, err = c.String(context.Background(), "Name11") + assert.Equal(t, config.KeyNotFoundError, err) + + _, err = c.String(context.Background(), "NameInvalid") + assert.Equal(t, config.InvalidValueTypeError, err) + + val, err = c.String(context.Background(), "Person.Name") + assert.Nil(t, err) + assert.Equal(t, "Jerry", val) +} + +func TestConfigContainer_Strings(t *testing.T) { + data := ` +Name=["Tom", "Jerry"] +NameInvalid="Tom" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val, err := c.Strings(context.Background(), "Name") + assert.Nil(t, err) + assert.Equal(t, []string{"Tom", "Jerry"}, val) + + _, err = c.Strings(context.Background(), "Name11") + assert.Equal(t, config.KeyNotFoundError, err) + + _, err = c.Strings(context.Background(), "NameInvalid") + assert.Equal(t, config.InvalidValueTypeError, err) +} + +func TestConfigContainer_Set(t *testing.T) { + data := ` +Name=["Tom", "Jerry"] +NameInvalid="Tom" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + err = c.Set(context.Background(), "Age", "11") + assert.Nil(t, err) + age, err := c.String(context.Background(), "Age") + assert.Nil(t, err) + assert.Equal(t, "11", age) +} + +func TestConfigContainer_SubAndMushall(t *testing.T) { + data := ` +[servers] + + # You can indent as you please. Tabs or spaces. TOML don't care. + [servers.alpha] + ip = "10.0.0.1" + dc = "eqdc10" + + [servers.beta] + ip = "10.0.0.2" + dc = "eqdc10" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + sub, err := c.Sub(context.Background(), "servers") + assert.Nil(t, err) + assert.NotNil(t, sub) + + sub, err = sub.Sub(context.Background(), "alpha") + assert.Nil(t, err) + assert.NotNil(t, sub) + ip, err := sub.String(context.Background(), "ip") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1", ip) + + svr := &Server{} + err = sub.Unmarshaler(context.Background(), "", svr) + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1", svr.Ip) + + svr = &Server{} + err = c.Unmarshaler(context.Background(), "servers.alpha", svr) + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1", svr.Ip) +} + +func TestConfigContainer_SaveConfigFile(t *testing.T) { + filename := "test_config.toml" + path := os.TempDir() + string(os.PathSeparator) + filename + data := ` +[servers] + + # You can indent as you please. Tabs or spaces. TOML don't care. + [servers.alpha] + ip = "10.0.0.1" + dc = "eqdc10" + + [servers.beta] + ip = "10.0.0.2" + dc = "eqdc10" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + + fmt.Println(path) + + assert.Nil(t, err) + assert.NotNil(t, c) + + sub, err := c.Sub(context.Background(), "servers") + assert.Nil(t, err) + + err = sub.SaveConfigFile(context.Background(), path) + assert.Nil(t, err) +} + +type Server struct { + Ip string `toml:"ip"` +} diff --git a/go.mod b/go.mod index ab7f5e39..697c9951 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/mattn/go-sqlite3 v2.0.3+incompatible github.com/mitchellh/mapstructure v1.3.3 github.com/opentracing/opentracing-go v1.2.0 - github.com/pelletier/go-toml v1.2.0 // indirect + github.com/pelletier/go-toml v1.2.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.7.0 github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644 From 2572094a8df1edb3e52a3d0ccd5be54c8b6cdefa Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 13 Oct 2020 22:32:41 +0800 Subject: [PATCH 195/207] remove config API's context parameter --- adapter/config.go | 36 ++++++------ adapter/config/adapter.go | 72 ++++++++++++------------ core/config/base_config_test.go | 24 ++++---- core/config/config.go | 82 ++++++++++++++-------------- core/config/etcd/config.go | 49 +++++------------ core/config/etcd/config_test.go | 36 +++++------- core/config/fake.go | 32 +++++------ core/config/ini.go | 46 ++++++++-------- core/config/ini_test.go | 20 +++---- core/config/json/json.go | 59 ++++++++++---------- core/config/json/json_test.go | 45 ++++++++------- core/config/toml/toml.go | 43 +++++++-------- core/config/toml/toml_test.go | 83 ++++++++++++++-------------- core/config/xml/xml.go | 59 ++++++++++---------- core/config/xml/xml_test.go | 31 +++++------ core/config/yaml/yaml.go | 64 +++++++++++----------- core/config/yaml/yaml_test.go | 31 +++++------ server/web/config.go | 97 ++++++++++++++++----------------- server/web/config_test.go | 12 ++-- server/web/hooks.go | 9 ++- server/web/parser.go | 11 ++-- server/web/templatefunc.go | 13 ++--- 22 files changed, 457 insertions(+), 497 deletions(-) diff --git a/adapter/config.go b/adapter/config.go index 46f965ee..6280b8f8 100644 --- a/adapter/config.go +++ b/adapter/config.go @@ -15,8 +15,6 @@ package adapter import ( - context2 "context" - "github.com/astaxie/beego/adapter/session" newCfg "github.com/astaxie/beego/core/config" "github.com/astaxie/beego/server/web" @@ -74,54 +72,54 @@ type beegoAppConfig struct { } func (b *beegoAppConfig) Set(key, val string) error { - if err := b.innerConfig.Set(context2.Background(), BConfig.RunMode+"::"+key, val); err != nil { - return b.innerConfig.Set(context2.Background(), key, val) + if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { + return b.innerConfig.Set(key, val) } return nil } func (b *beegoAppConfig) String(key string) string { - if v, err := b.innerConfig.String(context2.Background(), BConfig.RunMode+"::"+key); v != "" && err != nil { + if v, err := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" && err != nil { return v } - res, _ := b.innerConfig.String(context2.Background(), key) + res, _ := b.innerConfig.String(key) return res } func (b *beegoAppConfig) Strings(key string) []string { - if v, err := b.innerConfig.Strings(context2.Background(), BConfig.RunMode+"::"+key); len(v) > 0 && err != nil { + if v, err := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 && err != nil { return v } - res, _ := b.innerConfig.Strings(context2.Background(), key) + res, _ := b.innerConfig.Strings(key) return res } func (b *beegoAppConfig) Int(key string) (int, error) { - if v, err := b.innerConfig.Int(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil { return v, nil } - return b.innerConfig.Int(context2.Background(), key) + return b.innerConfig.Int(key) } func (b *beegoAppConfig) Int64(key string) (int64, error) { - if v, err := b.innerConfig.Int64(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil { return v, nil } - return b.innerConfig.Int64(context2.Background(), key) + return b.innerConfig.Int64(key) } func (b *beegoAppConfig) Bool(key string) (bool, error) { - if v, err := b.innerConfig.Bool(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil { return v, nil } - return b.innerConfig.Bool(context2.Background(), key) + return b.innerConfig.Bool(key) } func (b *beegoAppConfig) Float(key string) (float64, error) { - if v, err := b.innerConfig.Float(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil { return v, nil } - return b.innerConfig.Float(context2.Background(), key) + return b.innerConfig.Float(key) } func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { @@ -167,13 +165,13 @@ func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { } func (b *beegoAppConfig) DIY(key string) (interface{}, error) { - return b.innerConfig.DIY(context2.Background(), key) + return b.innerConfig.DIY(key) } func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { - return b.innerConfig.GetSection(context2.Background(), section) + return b.innerConfig.GetSection(section) } func (b *beegoAppConfig) SaveConfigFile(filename string) error { - return b.innerConfig.SaveConfigFile(context2.Background(), filename) + return b.innerConfig.SaveConfigFile(filename) } diff --git a/adapter/config/adapter.go b/adapter/config/adapter.go index 6dc538ea..0a9e1d0c 100644 --- a/adapter/config/adapter.go +++ b/adapter/config/adapter.go @@ -15,8 +15,6 @@ package config import ( - "context" - "github.com/pkg/errors" "github.com/astaxie/beego/core/config" @@ -27,148 +25,148 @@ type newToOldConfigerAdapter struct { } func (c *newToOldConfigerAdapter) Set(key, val string) error { - return c.delegate.Set(context.Background(), key, val) + return c.delegate.Set(key, val) } func (c *newToOldConfigerAdapter) String(key string) string { - res, _ := c.delegate.String(context.Background(), key) + res, _ := c.delegate.String(key) return res } func (c *newToOldConfigerAdapter) Strings(key string) []string { - res, _ := c.delegate.Strings(context.Background(), key) + res, _ := c.delegate.Strings(key) return res } func (c *newToOldConfigerAdapter) Int(key string) (int, error) { - return c.delegate.Int(context.Background(), key) + return c.delegate.Int(key) } func (c *newToOldConfigerAdapter) Int64(key string) (int64, error) { - return c.delegate.Int64(context.Background(), key) + return c.delegate.Int64(key) } func (c *newToOldConfigerAdapter) Bool(key string) (bool, error) { - return c.delegate.Bool(context.Background(), key) + return c.delegate.Bool(key) } func (c *newToOldConfigerAdapter) Float(key string) (float64, error) { - return c.delegate.Float(context.Background(), key) + return c.delegate.Float(key) } func (c *newToOldConfigerAdapter) DefaultString(key string, defaultVal string) string { - return c.delegate.DefaultString(context.Background(), key, defaultVal) + return c.delegate.DefaultString(key, defaultVal) } func (c *newToOldConfigerAdapter) DefaultStrings(key string, defaultVal []string) []string { - return c.delegate.DefaultStrings(context.Background(), key, defaultVal) + return c.delegate.DefaultStrings(key, defaultVal) } func (c *newToOldConfigerAdapter) DefaultInt(key string, defaultVal int) int { - return c.delegate.DefaultInt(context.Background(), key, defaultVal) + return c.delegate.DefaultInt(key, defaultVal) } func (c *newToOldConfigerAdapter) DefaultInt64(key string, defaultVal int64) int64 { - return c.delegate.DefaultInt64(context.Background(), key, defaultVal) + return c.delegate.DefaultInt64(key, defaultVal) } func (c *newToOldConfigerAdapter) DefaultBool(key string, defaultVal bool) bool { - return c.delegate.DefaultBool(context.Background(), key, defaultVal) + return c.delegate.DefaultBool(key, defaultVal) } func (c *newToOldConfigerAdapter) DefaultFloat(key string, defaultVal float64) float64 { - return c.delegate.DefaultFloat(context.Background(), key, defaultVal) + return c.delegate.DefaultFloat(key, defaultVal) } func (c *newToOldConfigerAdapter) DIY(key string) (interface{}, error) { - return c.delegate.DIY(context.Background(), key) + return c.delegate.DIY(key) } func (c *newToOldConfigerAdapter) GetSection(section string) (map[string]string, error) { - return c.delegate.GetSection(context.Background(), section) + return c.delegate.GetSection(section) } func (c *newToOldConfigerAdapter) SaveConfigFile(filename string) error { - return c.delegate.SaveConfigFile(context.Background(), filename) + return c.delegate.SaveConfigFile(filename) } type oldToNewConfigerAdapter struct { delegate Configer } -func (o *oldToNewConfigerAdapter) Set(ctx context.Context, key, val string) error { +func (o *oldToNewConfigerAdapter) Set(key, val string) error { return o.delegate.Set(key, val) } -func (o *oldToNewConfigerAdapter) String(ctx context.Context, key string) (string, error) { +func (o *oldToNewConfigerAdapter) String(key string) (string, error) { return o.delegate.String(key), nil } -func (o *oldToNewConfigerAdapter) Strings(ctx context.Context, key string) ([]string, error) { +func (o *oldToNewConfigerAdapter) Strings(key string) ([]string, error) { return o.delegate.Strings(key), nil } -func (o *oldToNewConfigerAdapter) Int(ctx context.Context, key string) (int, error) { +func (o *oldToNewConfigerAdapter) Int(key string) (int, error) { return o.delegate.Int(key) } -func (o *oldToNewConfigerAdapter) Int64(ctx context.Context, key string) (int64, error) { +func (o *oldToNewConfigerAdapter) Int64(key string) (int64, error) { return o.delegate.Int64(key) } -func (o *oldToNewConfigerAdapter) Bool(ctx context.Context, key string) (bool, error) { +func (o *oldToNewConfigerAdapter) Bool(key string) (bool, error) { return o.delegate.Bool(key) } -func (o *oldToNewConfigerAdapter) Float(ctx context.Context, key string) (float64, error) { +func (o *oldToNewConfigerAdapter) Float(key string) (float64, error) { return o.delegate.Float(key) } -func (o *oldToNewConfigerAdapter) DefaultString(ctx context.Context, key string, defaultVal string) string { +func (o *oldToNewConfigerAdapter) DefaultString(key string, defaultVal string) string { return o.delegate.DefaultString(key, defaultVal) } -func (o *oldToNewConfigerAdapter) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { +func (o *oldToNewConfigerAdapter) DefaultStrings(key string, defaultVal []string) []string { return o.delegate.DefaultStrings(key, defaultVal) } -func (o *oldToNewConfigerAdapter) DefaultInt(ctx context.Context, key string, defaultVal int) int { +func (o *oldToNewConfigerAdapter) DefaultInt(key string, defaultVal int) int { return o.delegate.DefaultInt(key, defaultVal) } -func (o *oldToNewConfigerAdapter) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { +func (o *oldToNewConfigerAdapter) DefaultInt64(key string, defaultVal int64) int64 { return o.delegate.DefaultInt64(key, defaultVal) } -func (o *oldToNewConfigerAdapter) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { +func (o *oldToNewConfigerAdapter) DefaultBool(key string, defaultVal bool) bool { return o.delegate.DefaultBool(key, defaultVal) } -func (o *oldToNewConfigerAdapter) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { +func (o *oldToNewConfigerAdapter) DefaultFloat(key string, defaultVal float64) float64 { return o.delegate.DefaultFloat(key, defaultVal) } -func (o *oldToNewConfigerAdapter) DIY(ctx context.Context, key string) (interface{}, error) { +func (o *oldToNewConfigerAdapter) DIY(key string) (interface{}, error) { return o.delegate.DIY(key) } -func (o *oldToNewConfigerAdapter) GetSection(ctx context.Context, section string) (map[string]string, error) { +func (o *oldToNewConfigerAdapter) GetSection(section string) (map[string]string, error) { return o.delegate.GetSection(section) } -func (o *oldToNewConfigerAdapter) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { +func (o *oldToNewConfigerAdapter) Unmarshaler(prefix string, obj interface{}, opt ...config.DecodeOption) error { return errors.New("unsupported operation, please use actual config.Configer") } -func (o *oldToNewConfigerAdapter) Sub(ctx context.Context, key string) (config.Configer, error) { +func (o *oldToNewConfigerAdapter) Sub(key string) (config.Configer, error) { return nil, errors.New("unsupported operation, please use actual config.Configer") } -func (o *oldToNewConfigerAdapter) OnChange(ctx context.Context, key string, fn func(value string)) { +func (o *oldToNewConfigerAdapter) OnChange(key string, fn func(value string)) { // do nothing } -func (o *oldToNewConfigerAdapter) SaveConfigFile(ctx context.Context, filename string) error { +func (o *oldToNewConfigerAdapter) SaveConfigFile(filename string) error { return o.delegate.SaveConfigFile(filename) } diff --git a/core/config/base_config_test.go b/core/config/base_config_test.go index 74cef184..74a669a7 100644 --- a/core/config/base_config_test.go +++ b/core/config/base_config_test.go @@ -24,38 +24,38 @@ import ( func TestBaseConfiger_DefaultBool(t *testing.T) { bc := newBaseConfier("true") - assert.True(t, bc.DefaultBool(context.Background(), "key1", false)) - assert.True(t, bc.DefaultBool(context.Background(), "key2", true)) + assert.True(t, bc.DefaultBool("key1", false)) + assert.True(t, bc.DefaultBool("key2", true)) } func TestBaseConfiger_DefaultFloat(t *testing.T) { bc := newBaseConfier("12.3") - assert.Equal(t, 12.3, bc.DefaultFloat(context.Background(), "key1", 0.1)) - assert.Equal(t, 0.1, bc.DefaultFloat(context.Background(), "key2", 0.1)) + assert.Equal(t, 12.3, bc.DefaultFloat("key1", 0.1)) + assert.Equal(t, 0.1, bc.DefaultFloat("key2", 0.1)) } func TestBaseConfiger_DefaultInt(t *testing.T) { bc := newBaseConfier("10") - assert.Equal(t, 10, bc.DefaultInt(context.Background(), "key1", 8)) - assert.Equal(t, 8, bc.DefaultInt(context.Background(), "key2", 8)) + assert.Equal(t, 10, bc.DefaultInt("key1", 8)) + assert.Equal(t, 8, bc.DefaultInt("key2", 8)) } func TestBaseConfiger_DefaultInt64(t *testing.T) { bc := newBaseConfier("64") - assert.Equal(t, int64(64), bc.DefaultInt64(context.Background(), "key1", int64(8))) - assert.Equal(t, int64(8), bc.DefaultInt64(context.Background(), "key2", int64(8))) + assert.Equal(t, int64(64), bc.DefaultInt64("key1", int64(8))) + assert.Equal(t, int64(8), bc.DefaultInt64("key2", int64(8))) } func TestBaseConfiger_DefaultString(t *testing.T) { bc := newBaseConfier("Hello") - assert.Equal(t, "Hello", bc.DefaultString(context.Background(), "key1", "world")) - assert.Equal(t, "world", bc.DefaultString(context.Background(), "key2", "world")) + assert.Equal(t, "Hello", bc.DefaultString("key1", "world")) + assert.Equal(t, "world", bc.DefaultString("key2", "world")) } func TestBaseConfiger_DefaultStrings(t *testing.T) { bc := newBaseConfier("Hello;world") - assert.Equal(t, []string{"Hello", "world"}, bc.DefaultStrings(context.Background(), "key1", []string{"world"})) - assert.Equal(t, []string{"world"}, bc.DefaultStrings(context.Background(), "key2", []string{"world"})) + assert.Equal(t, []string{"Hello", "world"}, bc.DefaultStrings("key1", []string{"world"})) + assert.Equal(t, []string{"world"}, bc.DefaultStrings("key2", []string{"world"})) } func newBaseConfier(str1 string) *BaseConfiger { diff --git a/core/config/config.go b/core/config/config.go index cfbe5724..908c65a5 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -54,34 +54,34 @@ import ( // Configer defines how to get and set value from configuration raw data. type Configer interface { // support section::key type in given key when using ini type. - Set(ctx context.Context, key, val string) error + Set(key, val string) error // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - String(ctx context.Context, key string) (string, error) + String(key string) (string, error) // get string slice - Strings(ctx context.Context, key string) ([]string, error) - Int(ctx context.Context, key string) (int, error) - Int64(ctx context.Context, key string) (int64, error) - Bool(ctx context.Context, key string) (bool, error) - Float(ctx context.Context, key string) (float64, error) + Strings(key string) ([]string, error) + Int(key string) (int, error) + Int64(key string) (int64, error) + Bool(key string) (bool, error) + Float(key string) (float64, error) // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - DefaultString(ctx context.Context, key string, defaultVal string) string + DefaultString(key string, defaultVal string) string // get string slice - DefaultStrings(ctx context.Context, key string, defaultVal []string) []string - DefaultInt(ctx context.Context, key string, defaultVal int) int - DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 - DefaultBool(ctx context.Context, key string, defaultVal bool) bool - DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 + DefaultStrings(key string, defaultVal []string) []string + DefaultInt(key string, defaultVal int) int + DefaultInt64(key string, defaultVal int64) int64 + DefaultBool(key string, defaultVal bool) bool + DefaultFloat(key string, defaultVal float64) float64 // DIY return the original value - DIY(ctx context.Context, key string) (interface{}, error) + DIY(key string) (interface{}, error) - GetSection(ctx context.Context, section string) (map[string]string, error) + GetSection(section string) (map[string]string, error) - Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...DecodeOption) error - Sub(ctx context.Context, key string) (Configer, error) - OnChange(ctx context.Context, key string, fn func(value string)) - SaveConfigFile(ctx context.Context, filename string) error + Unmarshaler(prefix string, obj interface{}, opt ...DecodeOption) error + Sub(key string) (Configer, error) + OnChange(key string, fn func(value string)) + SaveConfigFile(filename string) error } type BaseConfiger struct { @@ -95,7 +95,7 @@ func NewBaseConfiger(reader func(ctx context.Context, key string) (string, error } } -func (c *BaseConfiger) Int(ctx context.Context, key string) (int, error) { +func (c *BaseConfiger) Int(key string) (int, error) { res, err := c.reader(context.TODO(), key) if err != nil { return 0, err @@ -103,7 +103,7 @@ func (c *BaseConfiger) Int(ctx context.Context, key string) (int, error) { return strconv.Atoi(res) } -func (c *BaseConfiger) Int64(ctx context.Context, key string) (int64, error) { +func (c *BaseConfiger) Int64(key string) (int64, error) { res, err := c.reader(context.TODO(), key) if err != nil { return 0, err @@ -111,7 +111,7 @@ func (c *BaseConfiger) Int64(ctx context.Context, key string) (int64, error) { return strconv.ParseInt(res, 10, 64) } -func (c *BaseConfiger) Bool(ctx context.Context, key string) (bool, error) { +func (c *BaseConfiger) Bool(key string) (bool, error) { res, err := c.reader(context.TODO(), key) if err != nil { return false, err @@ -119,7 +119,7 @@ func (c *BaseConfiger) Bool(ctx context.Context, key string) (bool, error) { return ParseBool(res) } -func (c *BaseConfiger) Float(ctx context.Context, key string) (float64, error) { +func (c *BaseConfiger) Float(key string) (float64, error) { res, err := c.reader(context.TODO(), key) if err != nil { return 0, err @@ -129,8 +129,8 @@ func (c *BaseConfiger) Float(ctx context.Context, key string) (float64, error) { // DefaultString returns the string value for a given key. // if err != nil or value is empty return defaultval -func (c *BaseConfiger) DefaultString(ctx context.Context, key string, defaultVal string) string { - if res, err := c.String(ctx, key); res != "" && err == nil { +func (c *BaseConfiger) DefaultString(key string, defaultVal string) string { + if res, err := c.String(key); res != "" && err == nil { return res } return defaultVal @@ -138,63 +138,63 @@ func (c *BaseConfiger) DefaultString(ctx context.Context, key string, defaultVal // DefaultStrings returns the []string value for a given key. // if err != nil return defaultval -func (c *BaseConfiger) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { - if res, err := c.Strings(ctx, key); len(res) > 0 && err == nil { +func (c *BaseConfiger) DefaultStrings(key string, defaultVal []string) []string { + if res, err := c.Strings(key); len(res) > 0 && err == nil { return res } return defaultVal } -func (c *BaseConfiger) DefaultInt(ctx context.Context, key string, defaultVal int) int { - if res, err := c.Int(ctx, key); err == nil { +func (c *BaseConfiger) DefaultInt(key string, defaultVal int) int { + if res, err := c.Int(key); err == nil { return res } return defaultVal } -func (c *BaseConfiger) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { - if res, err := c.Int64(ctx, key); err == nil { +func (c *BaseConfiger) DefaultInt64(key string, defaultVal int64) int64 { + if res, err := c.Int64(key); err == nil { return res } return defaultVal } -func (c *BaseConfiger) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { - if res, err := c.Bool(ctx, key); err == nil { +func (c *BaseConfiger) DefaultBool(key string, defaultVal bool) bool { + if res, err := c.Bool(key); err == nil { return res } return defaultVal } -func (c *BaseConfiger) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { - if res, err := c.Float(ctx, key); err == nil { +func (c *BaseConfiger) DefaultFloat(key string, defaultVal float64) float64 { + if res, err := c.Float(key); err == nil { return res } return defaultVal } -func (c *BaseConfiger) String(ctx context.Context, key string) (string, error) { +func (c *BaseConfiger) String(key string) (string, error) { return c.reader(context.TODO(), key) } // Strings returns the []string value for a given key. // Return nil if config value does not exist or is empty. -func (c *BaseConfiger) Strings(ctx context.Context, key string) ([]string, error) { - res, err := c.String(nil, key) +func (c *BaseConfiger) Strings(key string) ([]string, error) { + res, err := c.String(key) if err != nil || res == "" { return nil, err } return strings.Split(res, ";"), nil } -func (c *BaseConfiger) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...DecodeOption) error { +func (c *BaseConfiger) Unmarshaler(prefix string, obj interface{}, opt ...DecodeOption) error { return errors.New("unsupported operation") } -func (c *BaseConfiger) Sub(ctx context.Context, key string) (Configer, error) { +func (c *BaseConfiger) Sub(key string) (Configer, error) { return nil, errors.New("unsupported operation") } -func (c *BaseConfiger) OnChange(ctx context.Context, key string, fn func(value string)) { +func (c *BaseConfiger) OnChange(key string, fn func(value string)) { // do nothing } diff --git a/core/config/etcd/config.go b/core/config/etcd/config.go index 37dba9de..6c3d33d4 100644 --- a/core/config/etcd/config.go +++ b/core/config/etcd/config.go @@ -30,8 +30,6 @@ import ( "github.com/astaxie/beego/core/logs" ) -const etcdOpts = "etcdOpts" - type EtcdConfiger struct { prefix string client *clientv3.Client @@ -50,7 +48,7 @@ func newEtcdConfiger(client *clientv3.Client, prefix string) *EtcdConfiger { // reader is an general implementation that read config from etcd. func (e *EtcdConfiger) reader(ctx context.Context, key string) (string, error) { - resp, err := get(e.client, ctx, e.prefix+key) + resp, err := get(e.client, e.prefix+key) if err != nil { return "", err } @@ -64,29 +62,24 @@ func (e *EtcdConfiger) reader(ctx context.Context, key string) (string, error) { // Set do nothing and return an error // I think write data to remote config center is not a good practice -func (e *EtcdConfiger) Set(ctx context.Context, key, val string) error { +func (e *EtcdConfiger) Set(key, val string) error { return errors.New("Unsupported operation") } // DIY return the original response from etcd // be careful when you decide to use this -func (e *EtcdConfiger) DIY(ctx context.Context, key string) (interface{}, error) { - return get(e.client, context.TODO(), key) +func (e *EtcdConfiger) DIY(key string) (interface{}, error) { + return get(e.client, key) } // GetSection in this implementation, we use section as prefix -func (e *EtcdConfiger) GetSection(ctx context.Context, section string) (map[string]string, error) { +func (e *EtcdConfiger) GetSection(section string) (map[string]string, error) { var ( resp *clientv3.GetResponse err error ) - if opts, ok := ctx.Value(etcdOpts).([]clientv3.OpOption); ok { - opts = append(opts, clientv3.WithPrefix()) - resp, err = e.client.Get(context.TODO(), e.prefix+section, opts...) - } else { - resp, err = e.client.Get(context.TODO(), e.prefix+section, clientv3.WithPrefix()) - } + resp, err = e.client.Get(context.TODO(), e.prefix+section, clientv3.WithPrefix()) if err != nil { return nil, errors.WithMessage(err, "GetSection failed") @@ -98,15 +91,15 @@ func (e *EtcdConfiger) GetSection(ctx context.Context, section string) (map[stri return res, nil } -func (e *EtcdConfiger) SaveConfigFile(ctx context.Context, filename string) error { +func (e *EtcdConfiger) SaveConfigFile(filename string) error { return errors.New("Unsupported operation") } // Unmarshaler is not very powerful because we lost the type information when we get configuration from etcd // for example, when we got "5", we are not sure whether it's int 5, or it's string "5" // TODO(support more complicated decoder) -func (e *EtcdConfiger) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { - res, err := e.GetSection(ctx, prefix) +func (e *EtcdConfiger) Unmarshaler(prefix string, obj interface{}, opt ...config.DecodeOption) error { + res, err := e.GetSection(prefix) if err != nil { return errors.WithMessage(err, fmt.Sprintf("could not read config with prefix: %s", prefix)) } @@ -120,22 +113,18 @@ func (e *EtcdConfiger) Unmarshaler(ctx context.Context, prefix string, obj inter } // Sub return an sub configer. -func (e *EtcdConfiger) Sub(ctx context.Context, key string) (config.Configer, error) { +func (e *EtcdConfiger) Sub(key string) (config.Configer, error) { return newEtcdConfiger(e.client, e.prefix+key), nil } // TODO remove this before release v2.0.0 -func (e *EtcdConfiger) OnChange(ctx context.Context, key string, fn func(value string)) { +func (e *EtcdConfiger) OnChange(key string, fn func(value string)) { buildOptsFunc := func() []clientv3.OpOption { - if opts, ok := ctx.Value(etcdOpts).([]clientv3.OpOption); ok { - opts = append(opts, clientv3.WithCreatedNotify()) - return opts - } return []clientv3.OpOption{} } - rch := e.client.Watch(ctx, e.prefix+key, buildOptsFunc()...) + rch := e.client.Watch(context.Background(), e.prefix+key, buildOptsFunc()...) go func() { for { for resp := range rch { @@ -152,7 +141,7 @@ func (e *EtcdConfiger) OnChange(ctx context.Context, key string, fn func(value s } } time.Sleep(time.Second) - rch = e.client.Watch(ctx, e.prefix+key, buildOptsFunc()...) + rch = e.client.Watch(context.Background(), e.prefix+key, buildOptsFunc()...) } }() @@ -188,16 +177,12 @@ func (provider *EtcdConfigerProvider) ParseData(data []byte) (config.Configer, e return newEtcdConfiger(client, ""), nil } -func get(client *clientv3.Client, ctx context.Context, key string) (*clientv3.GetResponse, error) { +func get(client *clientv3.Client, key string) (*clientv3.GetResponse, error) { var ( resp *clientv3.GetResponse err error ) - if opts, ok := ctx.Value(etcdOpts).([]clientv3.OpOption); ok { - resp, err = client.Get(ctx, key, opts...) - } else { - resp, err = client.Get(ctx, key) - } + resp, err = client.Get(context.Background(), key) if err != nil { return nil, errors.WithMessage(err, fmt.Sprintf("read config from etcd with key %s failed", key)) @@ -205,10 +190,6 @@ func get(client *clientv3.Client, ctx context.Context, key string) (*clientv3.Ge return resp, err } -func WithEtcdOption(ctx context.Context, opts ...clientv3.OpOption) context.Context { - return context.WithValue(ctx, etcdOpts, opts) -} - func init() { config.Register("json", &EtcdConfigerProvider{}) } diff --git a/core/config/etcd/config_test.go b/core/config/etcd/config_test.go index 7ccf6b96..6d0bb793 100644 --- a/core/config/etcd/config_test.go +++ b/core/config/etcd/config_test.go @@ -15,7 +15,6 @@ package etcd import ( - "context" "encoding/json" "os" "testing" @@ -25,11 +24,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestWithEtcdOption(t *testing.T) { - ctx := WithEtcdOption(context.Background(), clientv3.WithPrefix()) - assert.NotNil(t, ctx.Value(etcdOpts)) -} - func TestEtcdConfigerProvider_Parse(t *testing.T) { provider := &EtcdConfigerProvider{} cfger, err := provider.Parse(readEtcdConfig()) @@ -42,59 +36,59 @@ func TestEtcdConfiger(t *testing.T) { provider := &EtcdConfigerProvider{} cfger, _ := provider.Parse(readEtcdConfig()) - subCfger, err := cfger.Sub(nil, "sub.") + subCfger, err := cfger.Sub("sub.") assert.Nil(t, err) assert.NotNil(t, subCfger) - subSubCfger, err := subCfger.Sub(nil, "sub.") + subSubCfger, err := subCfger.Sub("sub.") assert.NotNil(t, subSubCfger) assert.Nil(t, err) - str, err := subSubCfger.String(nil, "key1") + str, err := subSubCfger.String("key1") assert.Nil(t, err) assert.Equal(t, "sub.sub.key", str) // we cannot test it - subSubCfger.OnChange(context.Background(), "watch", func(value string) { + subSubCfger.OnChange("watch", func(value string) { // do nothing }) - defStr := cfger.DefaultString(nil, "not_exit", "default value") + defStr := cfger.DefaultString("not_exit", "default value") assert.Equal(t, "default value", defStr) - defInt64 := cfger.DefaultInt64(nil, "not_exit", -1) + defInt64 := cfger.DefaultInt64("not_exit", -1) assert.Equal(t, int64(-1), defInt64) - defInt := cfger.DefaultInt(nil, "not_exit", -2) + defInt := cfger.DefaultInt("not_exit", -2) assert.Equal(t, -2, defInt) - defFlt := cfger.DefaultFloat(nil, "not_exit", 12.3) + defFlt := cfger.DefaultFloat("not_exit", 12.3) assert.Equal(t, 12.3, defFlt) - defBl := cfger.DefaultBool(nil, "not_exit", true) + defBl := cfger.DefaultBool("not_exit", true) assert.True(t, defBl) - defStrs := cfger.DefaultStrings(nil, "not_exit", []string{"hello"}) + defStrs := cfger.DefaultStrings("not_exit", []string{"hello"}) assert.Equal(t, []string{"hello"}, defStrs) - fl, err := cfger.Float(nil, "current.float") + fl, err := cfger.Float("current.float") assert.Nil(t, err) assert.Equal(t, 1.23, fl) - bl, err := cfger.Bool(nil, "current.bool") + bl, err := cfger.Bool("current.bool") assert.Nil(t, err) assert.True(t, bl) - it, err := cfger.Int(nil, "current.int") + it, err := cfger.Int("current.int") assert.Nil(t, err) assert.Equal(t, 11, it) - str, err = cfger.String(nil, "current.string") + str, err = cfger.String("current.string") assert.Nil(t, err) assert.Equal(t, "hello", str) tn := &TestEntity{} - err = cfger.Unmarshaler(context.Background(), "current.serialize.", tn) + err = cfger.Unmarshaler("current.serialize.", tn) assert.Nil(t, err) assert.Equal(t, "test", tn.Name) } diff --git a/core/config/fake.go b/core/config/fake.go index b606be01..332eaf1e 100644 --- a/core/config/fake.go +++ b/core/config/fake.go @@ -30,71 +30,71 @@ func (c *fakeConfigContainer) getData(key string) string { return c.data[strings.ToLower(key)] } -func (c *fakeConfigContainer) Set(ctx context.Context, key, val string) error { +func (c *fakeConfigContainer) Set(key, val string) error { c.data[strings.ToLower(key)] = val return nil } -func (c *fakeConfigContainer) Int(ctx context.Context, key string) (int, error) { +func (c *fakeConfigContainer) Int(key string) (int, error) { return strconv.Atoi(c.getData(key)) } -func (c *fakeConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { - v, err := c.Int(ctx, key) +func (c *fakeConfigContainer) DefaultInt(key string, defaultVal int) int { + v, err := c.Int(key) if err != nil { return defaultVal } return v } -func (c *fakeConfigContainer) Int64(ctx context.Context, key string) (int64, error) { +func (c *fakeConfigContainer) Int64(key string) (int64, error) { return strconv.ParseInt(c.getData(key), 10, 64) } -func (c *fakeConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { - v, err := c.Int64(ctx, key) +func (c *fakeConfigContainer) DefaultInt64(key string, defaultVal int64) int64 { + v, err := c.Int64(key) if err != nil { return defaultVal } return v } -func (c *fakeConfigContainer) Bool(ctx context.Context, key string) (bool, error) { +func (c *fakeConfigContainer) Bool(key string) (bool, error) { return ParseBool(c.getData(key)) } -func (c *fakeConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { - v, err := c.Bool(ctx, key) +func (c *fakeConfigContainer) DefaultBool(key string, defaultVal bool) bool { + v, err := c.Bool(key) if err != nil { return defaultVal } return v } -func (c *fakeConfigContainer) Float(ctx context.Context, key string) (float64, error) { +func (c *fakeConfigContainer) Float(key string) (float64, error) { return strconv.ParseFloat(c.getData(key), 64) } -func (c *fakeConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { - v, err := c.Float(ctx, key) +func (c *fakeConfigContainer) DefaultFloat(key string, defaultVal float64) float64 { + v, err := c.Float(key) if err != nil { return defaultVal } return v } -func (c *fakeConfigContainer) DIY(ctx context.Context, key string) (interface{}, error) { +func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { if v, ok := c.data[strings.ToLower(key)]; ok { return v, nil } return nil, errors.New("key not find") } -func (c *fakeConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { +func (c *fakeConfigContainer) GetSection(section string) (map[string]string, error) { return nil, errors.New("not implement in the fakeConfigContainer") } -func (c *fakeConfigContainer) SaveConfigFile(ctx context.Context, filename string) error { +func (c *fakeConfigContainer) SaveConfigFile(filename string) error { return errors.New("not implement in the fakeConfigContainer") } diff --git a/core/config/ini.go b/core/config/ini.go index cc67e4cd..a78f0170 100644 --- a/core/config/ini.go +++ b/core/config/ini.go @@ -238,14 +238,14 @@ type IniConfigContainer struct { } // Bool returns the boolean value for a given key. -func (c *IniConfigContainer) Bool(ctx context.Context, key string) (bool, error) { +func (c *IniConfigContainer) Bool(key string) (bool, error) { return ParseBool(c.getdata(key)) } // DefaultBool returns the boolean value for a given key. // if err != nil return defaultVal -func (c *IniConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { - v, err := c.Bool(ctx, key) +func (c *IniConfigContainer) DefaultBool(key string, defaultVal bool) bool { + v, err := c.Bool(key) if err != nil { return defaultVal } @@ -253,14 +253,14 @@ func (c *IniConfigContainer) DefaultBool(ctx context.Context, key string, defaul } // Int returns the integer value for a given key. -func (c *IniConfigContainer) Int(ctx context.Context, key string) (int, error) { +func (c *IniConfigContainer) Int(key string) (int, error) { return strconv.Atoi(c.getdata(key)) } // DefaultInt returns the integer value for a given key. // if err != nil return defaultVal -func (c *IniConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { - v, err := c.Int(ctx, key) +func (c *IniConfigContainer) DefaultInt(key string, defaultVal int) int { + v, err := c.Int(key) if err != nil { return defaultVal } @@ -268,14 +268,14 @@ func (c *IniConfigContainer) DefaultInt(ctx context.Context, key string, default } // Int64 returns the int64 value for a given key. -func (c *IniConfigContainer) Int64(ctx context.Context, key string) (int64, error) { +func (c *IniConfigContainer) Int64(key string) (int64, error) { return strconv.ParseInt(c.getdata(key), 10, 64) } // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaultVal -func (c *IniConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { - v, err := c.Int64(ctx, key) +func (c *IniConfigContainer) DefaultInt64(key string, defaultVal int64) int64 { + v, err := c.Int64(key) if err != nil { return defaultVal } @@ -283,14 +283,14 @@ func (c *IniConfigContainer) DefaultInt64(ctx context.Context, key string, defau } // Float returns the float value for a given key. -func (c *IniConfigContainer) Float(ctx context.Context, key string) (float64, error) { +func (c *IniConfigContainer) Float(key string) (float64, error) { return strconv.ParseFloat(c.getdata(key), 64) } // DefaultFloat returns the float64 value for a given key. // if err != nil return defaultVal -func (c *IniConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { - v, err := c.Float(ctx, key) +func (c *IniConfigContainer) DefaultFloat(key string, defaultVal float64) float64 { + v, err := c.Float(key) if err != nil { return defaultVal } @@ -298,14 +298,14 @@ func (c *IniConfigContainer) DefaultFloat(ctx context.Context, key string, defau } // String returns the string value for a given key. -func (c *IniConfigContainer) String(ctx context.Context, key string) (string, error) { +func (c *IniConfigContainer) String(key string) (string, error) { return c.getdata(key), nil } // DefaultString returns the string value for a given key. // if err != nil return defaultVal -func (c *IniConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { - v, err := c.String(nil, key) +func (c *IniConfigContainer) DefaultString(key string, defaultVal string) string { + v, err := c.String(key) if v == "" || err != nil { return defaultVal } @@ -314,8 +314,8 @@ func (c *IniConfigContainer) DefaultString(ctx context.Context, key string, defa // Strings returns the []string value for a given key. // Return nil if config value does not exist or is empty. -func (c *IniConfigContainer) Strings(ctx context.Context, key string) ([]string, error) { - v, err := c.String(nil, key) +func (c *IniConfigContainer) Strings(key string) ([]string, error) { + v, err := c.String(key) if v == "" || err != nil { return nil, err } @@ -324,8 +324,8 @@ func (c *IniConfigContainer) Strings(ctx context.Context, key string) ([]string, // DefaultStrings returns the []string value for a given key. // if err != nil return defaultVal -func (c *IniConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { - v, err := c.Strings(ctx, key) +func (c *IniConfigContainer) DefaultStrings(key string, defaultVal []string) []string { + v, err := c.Strings(key) if v == nil || err != nil { return defaultVal } @@ -333,7 +333,7 @@ func (c *IniConfigContainer) DefaultStrings(ctx context.Context, key string, def } // GetSection returns map for the given section -func (c *IniConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { +func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) { if v, ok := c.data[section]; ok { return v, nil } @@ -343,7 +343,7 @@ func (c *IniConfigContainer) GetSection(ctx context.Context, section string) (ma // SaveConfigFile save the config into file. // // BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Function. -func (c *IniConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) { +func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -443,7 +443,7 @@ func (c *IniConfigContainer) SaveConfigFile(ctx context.Context, filename string // Set writes a new value for key. // if write to one section, the key need be "section::key". // if the section is not existed, it panics. -func (c *IniConfigContainer) Set(ctx context.Context, key, val string) error { +func (c *IniConfigContainer) Set(key, val string) error { c.Lock() defer c.Unlock() if len(key) == 0 { @@ -471,7 +471,7 @@ func (c *IniConfigContainer) Set(ctx context.Context, key, val string) error { } // DIY returns the raw value by a given key. -func (c *IniConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) { +func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) { if v, ok := c.data[strings.ToLower(key)]; ok { return v, nil } diff --git a/core/config/ini_test.go b/core/config/ini_test.go index d4972ddd..7daa0a6e 100644 --- a/core/config/ini_test.go +++ b/core/config/ini_test.go @@ -101,19 +101,19 @@ password = ${GOPATH} var value interface{} switch v.(type) { case int: - value, err = iniconf.Int(nil, k) + value, err = iniconf.Int(k) case int64: - value, err = iniconf.Int64(nil, k) + value, err = iniconf.Int64(k) case float64: - value, err = iniconf.Float(nil, k) + value, err = iniconf.Float(k) case bool: - value, err = iniconf.Bool(nil, k) + value, err = iniconf.Bool(k) case []string: - value, err = iniconf.Strings(nil, k) + value, err = iniconf.Strings(k) case string: - value, err = iniconf.String(nil, k) + value, err = iniconf.String(k) default: - value, err = iniconf.DIY(nil, k) + value, err = iniconf.DIY(k) } if err != nil { t.Fatalf("get key %q value fail,err %s", k, err) @@ -122,10 +122,10 @@ password = ${GOPATH} } } - if err = iniconf.Set(nil, "name", "astaxie"); err != nil { + if err = iniconf.Set("name", "astaxie"); err != nil { t.Fatal(err) } - res, _ := iniconf.String(nil, "name") + res, _ := iniconf.String("name") if res != "astaxie" { t.Fatal("get name error") } @@ -171,7 +171,7 @@ name=mysql t.Fatal(err) } name := "newIniConfig.ini" - if err := cfg.SaveConfigFile(nil, name); err != nil { + if err := cfg.SaveConfigFile(name); err != nil { t.Fatal(err) } defer os.Remove(name) diff --git a/core/config/json/json.go b/core/config/json/json.go index f58e70f5..672d2787 100644 --- a/core/config/json/json.go +++ b/core/config/json/json.go @@ -15,7 +15,6 @@ package json import ( - "context" "encoding/json" "errors" "fmt" @@ -77,16 +76,16 @@ type JSONConfigContainer struct { sync.RWMutex } -func (c *JSONConfigContainer) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { - sub, err := c.sub(ctx, prefix) +func (c *JSONConfigContainer) Unmarshaler(prefix string, obj interface{}, opt ...config.DecodeOption) error { + sub, err := c.sub(prefix) if err != nil { return err } return mapstructure.Decode(sub, obj) } -func (c *JSONConfigContainer) Sub(ctx context.Context, key string) (config.Configer, error) { - sub, err := c.sub(ctx, key) +func (c *JSONConfigContainer) Sub(key string) (config.Configer, error) { + sub, err := c.sub(key) if err != nil { return nil, err } @@ -95,7 +94,7 @@ func (c *JSONConfigContainer) Sub(ctx context.Context, key string) (config.Confi }, nil } -func (c *JSONConfigContainer) sub(ctx context.Context, key string) (map[string]interface{}, error) { +func (c *JSONConfigContainer) sub(key string) (map[string]interface{}, error) { if key == "" { return c.data, nil } @@ -111,12 +110,12 @@ func (c *JSONConfigContainer) sub(ctx context.Context, key string) (map[string]i return res, nil } -func (c *JSONConfigContainer) OnChange(ctx context.Context, key string, fn func(value string)) { +func (c *JSONConfigContainer) OnChange(key string, fn func(value string)) { logs.Warn("unsupported operation") } // Bool returns the boolean value for a given key. -func (c *JSONConfigContainer) Bool(ctx context.Context, key string) (bool, error) { +func (c *JSONConfigContainer) Bool(key string) (bool, error) { val := c.getData(key) if val != nil { return config.ParseBool(val) @@ -126,15 +125,15 @@ func (c *JSONConfigContainer) Bool(ctx context.Context, key string) (bool, error // DefaultBool return the bool value if has no error // otherwise return the defaultval -func (c *JSONConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { - if v, err := c.Bool(ctx, key); err == nil { +func (c *JSONConfigContainer) DefaultBool(key string, defaultVal bool) bool { + if v, err := c.Bool(key); err == nil { return v } return defaultVal } // Int returns the integer value for a given key. -func (c *JSONConfigContainer) Int(ctx context.Context, key string) (int, error) { +func (c *JSONConfigContainer) Int(key string) (int, error) { val := c.getData(key) if val != nil { if v, ok := val.(float64); ok { @@ -149,15 +148,15 @@ func (c *JSONConfigContainer) Int(ctx context.Context, key string) (int, error) // DefaultInt returns the integer value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { - if v, err := c.Int(ctx, key); err == nil { +func (c *JSONConfigContainer) DefaultInt(key string, defaultVal int) int { + if v, err := c.Int(key); err == nil { return v } return defaultVal } // Int64 returns the int64 value for a given key. -func (c *JSONConfigContainer) Int64(ctx context.Context, key string) (int64, error) { +func (c *JSONConfigContainer) Int64(key string) (int64, error) { val := c.getData(key) if val != nil { if v, ok := val.(float64); ok { @@ -170,15 +169,15 @@ func (c *JSONConfigContainer) Int64(ctx context.Context, key string) (int64, err // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { - if v, err := c.Int64(ctx, key); err == nil { +func (c *JSONConfigContainer) DefaultInt64(key string, defaultVal int64) int64 { + if v, err := c.Int64(key); err == nil { return v } return defaultVal } // Float returns the float value for a given key. -func (c *JSONConfigContainer) Float(ctx context.Context, key string) (float64, error) { +func (c *JSONConfigContainer) Float(key string) (float64, error) { val := c.getData(key) if val != nil { if v, ok := val.(float64); ok { @@ -191,15 +190,15 @@ func (c *JSONConfigContainer) Float(ctx context.Context, key string) (float64, e // DefaultFloat returns the float64 value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { - if v, err := c.Float(ctx, key); err == nil { +func (c *JSONConfigContainer) DefaultFloat(key string, defaultVal float64) float64 { + if v, err := c.Float(key); err == nil { return v } return defaultVal } // String returns the string value for a given key. -func (c *JSONConfigContainer) String(ctx context.Context, key string) (string, error) { +func (c *JSONConfigContainer) String(key string) (string, error) { val := c.getData(key) if val != nil { if v, ok := val.(string); ok { @@ -211,17 +210,17 @@ func (c *JSONConfigContainer) String(ctx context.Context, key string) (string, e // DefaultString returns the string value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { +func (c *JSONConfigContainer) DefaultString(key string, defaultVal string) string { // TODO FIXME should not use "" to replace non existence - if v, err := c.String(ctx, key); v != "" && err == nil { + if v, err := c.String(key); v != "" && err == nil { return v } return defaultVal } // Strings returns the []string value for a given key. -func (c *JSONConfigContainer) Strings(ctx context.Context, key string) ([]string, error) { - stringVal, err := c.String(nil, key) +func (c *JSONConfigContainer) Strings(key string) ([]string, error) { + stringVal, err := c.String(key) if stringVal == "" || err != nil { return nil, err } @@ -230,15 +229,15 @@ func (c *JSONConfigContainer) Strings(ctx context.Context, key string) ([]string // DefaultStrings returns the []string value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { - if v, err := c.Strings(ctx, key); v != nil && err == nil { +func (c *JSONConfigContainer) DefaultStrings(key string, defaultVal []string) []string { + if v, err := c.Strings(key); v != nil && err == nil { return v } return defaultVal } // GetSection returns map for the given section -func (c *JSONConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { +func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) { if v, ok := c.data[section]; ok { return v.(map[string]string), nil } @@ -246,7 +245,7 @@ func (c *JSONConfigContainer) GetSection(ctx context.Context, section string) (m } // SaveConfigFile save the config into file -func (c *JSONConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) { +func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -262,7 +261,7 @@ func (c *JSONConfigContainer) SaveConfigFile(ctx context.Context, filename strin } // Set writes a new value for key. -func (c *JSONConfigContainer) Set(ctx context.Context, key, val string) error { +func (c *JSONConfigContainer) Set(key, val string) error { c.Lock() defer c.Unlock() c.data[key] = val @@ -270,7 +269,7 @@ func (c *JSONConfigContainer) Set(ctx context.Context, key, val string) error { } // DIY returns the raw value by a given key. -func (c *JSONConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) { +func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) { val := c.getData(key) if val != nil { return val, nil diff --git a/core/config/json/json_test.go b/core/config/json/json_test.go index b615c19a..386cfdf1 100644 --- a/core/config/json/json_test.go +++ b/core/config/json/json_test.go @@ -15,7 +15,6 @@ package json import ( - "context" "fmt" "os" "testing" @@ -52,7 +51,7 @@ func TestJsonStartsWithArray(t *testing.T) { if err != nil { t.Fatal(err) } - rootArray, err := jsonconf.DIY(nil, "rootArray") + rootArray, err := jsonconf.DIY("rootArray") if err != nil { t.Error("array does not exist as element") } @@ -158,19 +157,19 @@ func TestJson(t *testing.T) { var value interface{} switch v.(type) { case int: - value, err = jsonconf.Int(nil, k) + value, err = jsonconf.Int(k) case int64: - value, err = jsonconf.Int64(nil, k) + value, err = jsonconf.Int64(k) case float64: - value, err = jsonconf.Float(nil, k) + value, err = jsonconf.Float(k) case bool: - value, err = jsonconf.Bool(nil, k) + value, err = jsonconf.Bool(k) case []string: - value, err = jsonconf.Strings(nil, k) + value, err = jsonconf.Strings(k) case string: - value, err = jsonconf.String(nil, k) + value, err = jsonconf.String(k) default: - value, err = jsonconf.DIY(nil, k) + value, err = jsonconf.DIY(k) } if err != nil { t.Fatalf("get key %q value fatal,%v err %s", k, v, err) @@ -179,16 +178,16 @@ func TestJson(t *testing.T) { } } - if err = jsonconf.Set(nil, "name", "astaxie"); err != nil { + if err = jsonconf.Set("name", "astaxie"); err != nil { t.Fatal(err) } - res, _ := jsonconf.String(nil, "name") + res, _ := jsonconf.String("name") if res != "astaxie" { t.Fatal("get name error") } - if db, err := jsonconf.DIY(nil, "database"); err != nil { + if db, err := jsonconf.DIY("database"); err != nil { t.Fatal(err) } else if m, ok := db.(map[string]interface{}); !ok { t.Log(db) @@ -199,46 +198,46 @@ func TestJson(t *testing.T) { } } - if _, err := jsonconf.Int(nil, "unknown"); err == nil { + if _, err := jsonconf.Int("unknown"); err == nil { t.Error("unknown keys should return an error when expecting an Int") } - if _, err := jsonconf.Int64(nil, "unknown"); err == nil { + if _, err := jsonconf.Int64("unknown"); err == nil { t.Error("unknown keys should return an error when expecting an Int64") } - if _, err := jsonconf.Float(nil, "unknown"); err == nil { + if _, err := jsonconf.Float("unknown"); err == nil { t.Error("unknown keys should return an error when expecting a Float") } - if _, err := jsonconf.DIY(nil, "unknown"); err == nil { + if _, err := jsonconf.DIY("unknown"); err == nil { t.Error("unknown keys should return an error when expecting an interface{}") } - if val, _ := jsonconf.String(nil, "unknown"); val != "" { + if val, _ := jsonconf.String("unknown"); val != "" { t.Error("unknown keys should return an empty string when expecting a String") } - if _, err := jsonconf.Bool(nil, "unknown"); err == nil { + if _, err := jsonconf.Bool("unknown"); err == nil { t.Error("unknown keys should return an error when expecting a Bool") } - if !jsonconf.DefaultBool(nil, "unknown", true) { + if !jsonconf.DefaultBool("unknown", true) { t.Error("unknown keys with default value wrong") } - sub, err := jsonconf.Sub(context.Background(), "database") + sub, err := jsonconf.Sub("database") assert.Nil(t, err) assert.NotNil(t, sub) - sub, err = sub.Sub(context.Background(), "conns") + sub, err = sub.Sub("conns") assert.Nil(t, err) - maxCon, _ := sub.Int(context.Background(), "maxconnection") + maxCon, _ := sub.Int("maxconnection") assert.Equal(t, 12, maxCon) dbCfg := &DatabaseConfig{} - err = sub.Unmarshaler(context.Background(), "", dbCfg) + err = sub.Unmarshaler("", dbCfg) assert.Nil(t, err) assert.Equal(t, 12, dbCfg.MaxConnection) assert.True(t, dbCfg.Autoconnect) diff --git a/core/config/toml/toml.go b/core/config/toml/toml.go index 47ea6a25..96e1a200 100644 --- a/core/config/toml/toml.go +++ b/core/config/toml/toml.go @@ -15,7 +15,6 @@ package toml import ( - "context" "io/ioutil" "os" "strings" @@ -57,7 +56,7 @@ type configContainer struct { } // Set put key, val -func (c *configContainer) Set(ctx context.Context, key, val string) error { +func (c *configContainer) Set(key, val string) error { path := strings.Split(key, keySeparator) sub, err := subTree(c.t, path[0:len(path)-1]) if err != nil { @@ -69,7 +68,7 @@ func (c *configContainer) Set(ctx context.Context, key, val string) error { // String return the value. // return error if key not found or value is invalid type -func (c *configContainer) String(ctx context.Context, key string) (string, error) { +func (c *configContainer) String(key string) (string, error) { res, err := c.get(key) if err != nil { @@ -89,7 +88,7 @@ func (c *configContainer) String(ctx context.Context, key string) (string, error // Strings return []string // return error if key not found or value is invalid type -func (c *configContainer) Strings(ctx context.Context, key string) ([]string, error) { +func (c *configContainer) Strings(key string) ([]string, error) { val, err := c.get(key) if err != nil { @@ -115,14 +114,14 @@ func (c *configContainer) Strings(ctx context.Context, key string) ([]string, er // Int return int value // return error if key not found or value is invalid type -func (c *configContainer) Int(ctx context.Context, key string) (int, error) { - val, err := c.Int64(ctx, key) +func (c *configContainer) Int(key string) (int, error) { + val, err := c.Int64(key) return int(val), err } // Int64 return int64 value // return error if key not found or value is invalid type -func (c *configContainer) Int64(ctx context.Context, key string) (int64, error) { +func (c *configContainer) Int64(key string) (int64, error) { res, err := c.get(key) if err != nil { return 0, err @@ -141,7 +140,7 @@ func (c *configContainer) Int64(ctx context.Context, key string) (int64, error) // bool return bool value // return error if key not found or value is invalid type -func (c *configContainer) Bool(ctx context.Context, key string) (bool, error) { +func (c *configContainer) Bool(key string) (bool, error) { res, err := c.get(key) @@ -161,7 +160,7 @@ func (c *configContainer) Bool(ctx context.Context, key string) (bool, error) { // Float return float value // return error if key not found or value is invalid type -func (c *configContainer) Float(ctx context.Context, key string) (float64, error) { +func (c *configContainer) Float(key string) (float64, error) { res, err := c.get(key) if err != nil { return 0, err @@ -180,7 +179,7 @@ func (c *configContainer) Float(ctx context.Context, key string) (float64, error // DefaultString return string value // return default value if key not found or value is invalid type -func (c *configContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { +func (c *configContainer) DefaultString(key string, defaultVal string) string { res, err := c.get(key) if err != nil { return defaultVal @@ -194,7 +193,7 @@ func (c *configContainer) DefaultString(ctx context.Context, key string, default // DefaultStrings return []string // return default value if key not found or value is invalid type -func (c *configContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { +func (c *configContainer) DefaultStrings(key string, defaultVal []string) []string { val, err := c.get(key) if err != nil { return defaultVal @@ -216,13 +215,13 @@ func (c *configContainer) DefaultStrings(ctx context.Context, key string, defaul // DefaultInt return int value // return default value if key not found or value is invalid type -func (c *configContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { - return int(c.DefaultInt64(ctx, key, int64(defaultVal))) +func (c *configContainer) DefaultInt(key string, defaultVal int) int { + return int(c.DefaultInt64(key, int64(defaultVal))) } // DefaultInt64 return int64 value // return default value if key not found or value is invalid type -func (c *configContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { +func (c *configContainer) DefaultInt64(key string, defaultVal int64) int64 { res, err := c.get(key) if err != nil { return defaultVal @@ -238,7 +237,7 @@ func (c *configContainer) DefaultInt64(ctx context.Context, key string, defaultV // DefaultBool return bool value // return default value if key not found or value is invalid type -func (c *configContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { +func (c *configContainer) DefaultBool(key string, defaultVal bool) bool { res, err := c.get(key) if err != nil { return defaultVal @@ -252,7 +251,7 @@ func (c *configContainer) DefaultBool(ctx context.Context, key string, defaultVa // DefaultFloat return float value // return default value if key not found or value is invalid type -func (c *configContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { +func (c *configContainer) DefaultFloat(key string, defaultVal float64) float64 { res, err := c.get(key) if err != nil { return defaultVal @@ -265,12 +264,12 @@ func (c *configContainer) DefaultFloat(ctx context.Context, key string, defaultV } // DIY returns the original value -func (c *configContainer) DIY(ctx context.Context, key string) (interface{}, error) { +func (c *configContainer) DIY(key string) (interface{}, error) { return c.get(key) } // GetSection return error if the value is not valid toml doc -func (c *configContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { +func (c *configContainer) GetSection(section string) (map[string]string, error) { val, err := subTree(c.t, strings.Split(section, keySeparator)) if err != nil { return map[string]string{}, err @@ -283,7 +282,7 @@ func (c *configContainer) GetSection(ctx context.Context, section string) (map[s return res, nil } -func (c *configContainer) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { +func (c *configContainer) Unmarshaler(prefix string, obj interface{}, opt ...config.DecodeOption) error { if len(prefix) > 0 { t, err := subTree(c.t, strings.Split(prefix, keySeparator)) if err != nil { @@ -296,7 +295,7 @@ func (c *configContainer) Unmarshaler(ctx context.Context, prefix string, obj in // Sub return sub configer // return error if key not found or the value is not a sub doc -func (c *configContainer) Sub(ctx context.Context, key string) (config.Configer, error) { +func (c *configContainer) Sub(key string) (config.Configer, error) { val, err := subTree(c.t, strings.Split(key, keySeparator)) if err != nil { return nil, err @@ -307,12 +306,12 @@ func (c *configContainer) Sub(ctx context.Context, key string) (config.Configer, } // OnChange do nothing -func (c *configContainer) OnChange(ctx context.Context, key string, fn func(value string)) { +func (c *configContainer) OnChange(key string, fn func(value string)) { // do nothing } // SaveConfigFile create or override the file -func (c *configContainer) SaveConfigFile(ctx context.Context, filename string) error { +func (c *configContainer) SaveConfigFile(filename string) error { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { diff --git a/core/config/toml/toml_test.go b/core/config/toml/toml_test.go index 2af15596..20726f0d 100644 --- a/core/config/toml/toml_test.go +++ b/core/config/toml/toml_test.go @@ -15,7 +15,6 @@ package toml import ( - "context" "fmt" "os" "testing" @@ -52,11 +51,11 @@ Woman="true" assert.Nil(t, err) assert.NotNil(t, c) - val, err := c.Bool(context.Background(), "Man") + val, err := c.Bool("Man") assert.Nil(t, err) assert.True(t, val) - _, err = c.Bool(context.Background(), "Woman") + _, err = c.Bool("Woman") assert.NotNil(t, err) assert.Equal(t, config.InvalidValueTypeError, err) } @@ -71,13 +70,13 @@ Woman="false" assert.Nil(t, err) assert.NotNil(t, c) - val := c.DefaultBool(context.Background(), "Man11", true) + val := c.DefaultBool("Man11", true) assert.True(t, val) - val = c.DefaultBool(context.Background(), "Man", false) + val = c.DefaultBool("Man", false) assert.True(t, val) - val = c.DefaultBool(context.Background(), "Woman", true) + val = c.DefaultBool("Woman", true) assert.True(t, val) } @@ -91,13 +90,13 @@ PriceInvalid="12.3" assert.Nil(t, err) assert.NotNil(t, c) - val := c.DefaultFloat(context.Background(), "Price", 11.2) + val := c.DefaultFloat("Price", 11.2) assert.Equal(t, 12.3, val) - val = c.DefaultFloat(context.Background(), "Price11", 11.2) + val = c.DefaultFloat("Price11", 11.2) assert.Equal(t, 11.2, val) - val = c.DefaultFloat(context.Background(), "PriceInvalid", 11.2) + val = c.DefaultFloat("PriceInvalid", 11.2) assert.Equal(t, 11.2, val) } @@ -111,13 +110,13 @@ AgeInvalid="13" assert.Nil(t, err) assert.NotNil(t, c) - val := c.DefaultInt(context.Background(), "Age", 11) + val := c.DefaultInt("Age", 11) assert.Equal(t, 12, val) - val = c.DefaultInt(context.Background(), "Price11", 11) + val = c.DefaultInt("Price11", 11) assert.Equal(t, 11, val) - val = c.DefaultInt(context.Background(), "PriceInvalid", 11) + val = c.DefaultInt("PriceInvalid", 11) assert.Equal(t, 11, val) } @@ -131,13 +130,13 @@ NameInvalid=13 assert.Nil(t, err) assert.NotNil(t, c) - val := c.DefaultString(context.Background(), "Name", "Jerry") + val := c.DefaultString("Name", "Jerry") assert.Equal(t, "Tom", val) - val = c.DefaultString(context.Background(), "Name11", "Jerry") + val = c.DefaultString("Name11", "Jerry") assert.Equal(t, "Jerry", val) - val = c.DefaultString(context.Background(), "NameInvalid", "Jerry") + val = c.DefaultString("NameInvalid", "Jerry") assert.Equal(t, "Jerry", val) } @@ -151,13 +150,13 @@ NameInvalid="Tom" assert.Nil(t, err) assert.NotNil(t, c) - val := c.DefaultStrings(context.Background(), "Name", []string{"Jerry"}) + val := c.DefaultStrings("Name", []string{"Jerry"}) assert.Equal(t, []string{"Tom", "Jerry"}, val) - val = c.DefaultStrings(context.Background(), "Name11", []string{"Jerry"}) + val = c.DefaultStrings("Name11", []string{"Jerry"}) assert.Equal(t, []string{"Jerry"}, val) - val = c.DefaultStrings(context.Background(), "NameInvalid", []string{"Jerry"}) + val = c.DefaultStrings("NameInvalid", []string{"Jerry"}) assert.Equal(t, []string{"Jerry"}, val) } @@ -170,7 +169,7 @@ Name=["Tom", "Jerry"] assert.Nil(t, err) assert.NotNil(t, c) - _, err = c.DIY(context.Background(), "Name") + _, err = c.DIY("Name") assert.Nil(t, err) } @@ -184,14 +183,14 @@ PriceInvalid="12.3" assert.Nil(t, err) assert.NotNil(t, c) - val, err := c.Float(context.Background(), "Price") + val, err := c.Float("Price") assert.Nil(t, err) assert.Equal(t, 12.3, val) - _, err = c.Float(context.Background(), "Price11") + _, err = c.Float("Price11") assert.Equal(t, config.KeyNotFoundError, err) - _, err = c.Float(context.Background(), "PriceInvalid") + _, err = c.Float("PriceInvalid") assert.Equal(t, config.InvalidValueTypeError, err) } @@ -205,14 +204,14 @@ AgeInvalid="13" assert.Nil(t, err) assert.NotNil(t, c) - val, err := c.Int(context.Background(), "Age") + val, err := c.Int("Age") assert.Nil(t, err) assert.Equal(t, 12, val) - _, err = c.Int(context.Background(), "Age11") + _, err = c.Int("Age11") assert.Equal(t, config.KeyNotFoundError, err) - _, err = c.Int(context.Background(), "AgeInvalid") + _, err = c.Int("AgeInvalid") assert.Equal(t, config.InvalidValueTypeError, err) } @@ -234,7 +233,7 @@ func TestConfigContainer_GetSection(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, c) - m, err := c.GetSection(context.Background(), "servers") + m, err := c.GetSection("servers") assert.Nil(t, err) assert.NotNil(t, m) assert.Equal(t, 2, len(m)) @@ -252,17 +251,17 @@ Name="Jerry" assert.Nil(t, err) assert.NotNil(t, c) - val, err := c.String(context.Background(), "Name") + val, err := c.String("Name") assert.Nil(t, err) assert.Equal(t, "Tom", val) - _, err = c.String(context.Background(), "Name11") + _, err = c.String("Name11") assert.Equal(t, config.KeyNotFoundError, err) - _, err = c.String(context.Background(), "NameInvalid") + _, err = c.String("NameInvalid") assert.Equal(t, config.InvalidValueTypeError, err) - val, err = c.String(context.Background(), "Person.Name") + val, err = c.String("Person.Name") assert.Nil(t, err) assert.Equal(t, "Jerry", val) } @@ -277,14 +276,14 @@ NameInvalid="Tom" assert.Nil(t, err) assert.NotNil(t, c) - val, err := c.Strings(context.Background(), "Name") + val, err := c.Strings("Name") assert.Nil(t, err) assert.Equal(t, []string{"Tom", "Jerry"}, val) - _, err = c.Strings(context.Background(), "Name11") + _, err = c.Strings("Name11") assert.Equal(t, config.KeyNotFoundError, err) - _, err = c.Strings(context.Background(), "NameInvalid") + _, err = c.Strings("NameInvalid") assert.Equal(t, config.InvalidValueTypeError, err) } @@ -298,9 +297,9 @@ NameInvalid="Tom" assert.Nil(t, err) assert.NotNil(t, c) - err = c.Set(context.Background(), "Age", "11") + err = c.Set("Age", "11") assert.Nil(t, err) - age, err := c.String(context.Background(), "Age") + age, err := c.String("Age") assert.Nil(t, err) assert.Equal(t, "11", age) } @@ -323,24 +322,24 @@ func TestConfigContainer_SubAndMushall(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, c) - sub, err := c.Sub(context.Background(), "servers") + sub, err := c.Sub("servers") assert.Nil(t, err) assert.NotNil(t, sub) - sub, err = sub.Sub(context.Background(), "alpha") + sub, err = sub.Sub("alpha") assert.Nil(t, err) assert.NotNil(t, sub) - ip, err := sub.String(context.Background(), "ip") + ip, err := sub.String("ip") assert.Nil(t, err) assert.Equal(t, "10.0.0.1", ip) svr := &Server{} - err = sub.Unmarshaler(context.Background(), "", svr) + err = sub.Unmarshaler("", svr) assert.Nil(t, err) assert.Equal(t, "10.0.0.1", svr.Ip) svr = &Server{} - err = c.Unmarshaler(context.Background(), "servers.alpha", svr) + err = c.Unmarshaler("servers.alpha", svr) assert.Nil(t, err) assert.Equal(t, "10.0.0.1", svr.Ip) } @@ -368,10 +367,10 @@ func TestConfigContainer_SaveConfigFile(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, c) - sub, err := c.Sub(context.Background(), "servers") + sub, err := c.Sub("servers") assert.Nil(t, err) - err = sub.SaveConfigFile(context.Background(), path) + err = sub.SaveConfigFile(path) assert.Nil(t, err) } diff --git a/core/config/xml/xml.go b/core/config/xml/xml.go index 3b1a7051..70f0c23c 100644 --- a/core/config/xml/xml.go +++ b/core/config/xml/xml.go @@ -30,7 +30,6 @@ package xml import ( - "context" "encoding/xml" "errors" "fmt" @@ -87,16 +86,16 @@ type ConfigContainer struct { // So when you use // 1 // The "1" is a string, not int -func (c *ConfigContainer) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { - sub, err := c.sub(ctx, prefix) +func (c *ConfigContainer) Unmarshaler(prefix string, obj interface{}, opt ...config.DecodeOption) error { + sub, err := c.sub(prefix) if err != nil { return err } return mapstructure.Decode(sub, obj) } -func (c *ConfigContainer) Sub(ctx context.Context, key string) (config.Configer, error) { - sub, err := c.sub(ctx, key) +func (c *ConfigContainer) Sub(key string) (config.Configer, error) { + sub, err := c.sub(key) if err != nil { return nil, err } @@ -107,7 +106,7 @@ func (c *ConfigContainer) Sub(ctx context.Context, key string) (config.Configer, } -func (c *ConfigContainer) sub(ctx context.Context, key string) (map[string]interface{}, error) { +func (c *ConfigContainer) sub(key string) (map[string]interface{}, error) { if key == "" { return c.data, nil } @@ -122,12 +121,12 @@ func (c *ConfigContainer) sub(ctx context.Context, key string) (map[string]inter return res, nil } -func (c *ConfigContainer) OnChange(ctx context.Context, key string, fn func(value string)) { +func (c *ConfigContainer) OnChange(key string, fn func(value string)) { logs.Warn("Unsupported operation") } // Bool returns the boolean value for a given key. -func (c *ConfigContainer) Bool(ctx context.Context, key string) (bool, error) { +func (c *ConfigContainer) Bool(key string) (bool, error) { if v := c.data[key]; v != nil { return config.ParseBool(v) } @@ -136,8 +135,8 @@ func (c *ConfigContainer) Bool(ctx context.Context, key string) (bool, error) { // DefaultBool return the bool value if has no error // otherwise return the defaultVal -func (c *ConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { - v, err := c.Bool(ctx, key) +func (c *ConfigContainer) DefaultBool(key string, defaultVal bool) bool { + v, err := c.Bool(key) if err != nil { return defaultVal } @@ -145,14 +144,14 @@ func (c *ConfigContainer) DefaultBool(ctx context.Context, key string, defaultVa } // Int returns the integer value for a given key. -func (c *ConfigContainer) Int(ctx context.Context, key string) (int, error) { +func (c *ConfigContainer) Int(key string) (int, error) { return strconv.Atoi(c.data[key].(string)) } // DefaultInt returns the integer value for a given key. // if err != nil return defaultVal -func (c *ConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { - v, err := c.Int(ctx, key) +func (c *ConfigContainer) DefaultInt(key string, defaultVal int) int { + v, err := c.Int(key) if err != nil { return defaultVal } @@ -160,14 +159,14 @@ func (c *ConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal } // Int64 returns the int64 value for a given key. -func (c *ConfigContainer) Int64(ctx context.Context, key string) (int64, error) { +func (c *ConfigContainer) Int64(key string) (int64, error) { return strconv.ParseInt(c.data[key].(string), 10, 64) } // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaultVal -func (c *ConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { - v, err := c.Int64(ctx, key) +func (c *ConfigContainer) DefaultInt64(key string, defaultVal int64) int64 { + v, err := c.Int64(key) if err != nil { return defaultVal } @@ -176,14 +175,14 @@ func (c *ConfigContainer) DefaultInt64(ctx context.Context, key string, defaultV } // Float returns the float value for a given key. -func (c *ConfigContainer) Float(ctx context.Context, key string) (float64, error) { +func (c *ConfigContainer) Float(key string) (float64, error) { return strconv.ParseFloat(c.data[key].(string), 64) } // DefaultFloat returns the float64 value for a given key. // if err != nil return defaultVal -func (c *ConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { - v, err := c.Float(ctx, key) +func (c *ConfigContainer) DefaultFloat(key string, defaultVal float64) float64 { + v, err := c.Float(key) if err != nil { return defaultVal } @@ -191,7 +190,7 @@ func (c *ConfigContainer) DefaultFloat(ctx context.Context, key string, defaultV } // String returns the string value for a given key. -func (c *ConfigContainer) String(ctx context.Context, key string) (string, error) { +func (c *ConfigContainer) String(key string) (string, error) { if v, ok := c.data[key].(string); ok { return v, nil } @@ -200,8 +199,8 @@ func (c *ConfigContainer) String(ctx context.Context, key string) (string, error // DefaultString returns the string value for a given key. // if err != nil return defaultVal -func (c *ConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { - v, err := c.String(ctx, key) +func (c *ConfigContainer) DefaultString(key string, defaultVal string) string { + v, err := c.String(key) if v == "" || err != nil { return defaultVal } @@ -209,8 +208,8 @@ func (c *ConfigContainer) DefaultString(ctx context.Context, key string, default } // Strings returns the []string value for a given key. -func (c *ConfigContainer) Strings(ctx context.Context, key string) ([]string, error) { - v, err := c.String(ctx, key) +func (c *ConfigContainer) Strings(key string) ([]string, error) { + v, err := c.String(key) if v == "" || err != nil { return nil, err } @@ -219,8 +218,8 @@ func (c *ConfigContainer) Strings(ctx context.Context, key string) ([]string, er // DefaultStrings returns the []string value for a given key. // if err != nil return defaultVal -func (c *ConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { - v, err := c.Strings(ctx, key) +func (c *ConfigContainer) DefaultStrings(key string, defaultVal []string) []string { + v, err := c.Strings(key) if v == nil || err != nil { return defaultVal } @@ -228,7 +227,7 @@ func (c *ConfigContainer) DefaultStrings(ctx context.Context, key string, defaul } // GetSection returns map for the given section -func (c *ConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { +func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { if v, ok := c.data[section].(map[string]interface{}); ok { mapstr := make(map[string]string) for k, val := range v { @@ -240,7 +239,7 @@ func (c *ConfigContainer) GetSection(ctx context.Context, section string) (map[s } // SaveConfigFile save the config into file -func (c *ConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) { +func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -256,7 +255,7 @@ func (c *ConfigContainer) SaveConfigFile(ctx context.Context, filename string) ( } // Set writes a new value for key. -func (c *ConfigContainer) Set(ctx context.Context, key, val string) error { +func (c *ConfigContainer) Set(key, val string) error { c.Lock() defer c.Unlock() c.data[key] = val @@ -264,7 +263,7 @@ func (c *ConfigContainer) Set(ctx context.Context, key, val string) error { } // DIY returns the raw value by a given key. -func (c *ConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) { +func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { if v, ok := c.data[key]; ok { return v, nil } diff --git a/core/config/xml/xml_test.go b/core/config/xml/xml_test.go index 0266e270..c6cf970d 100644 --- a/core/config/xml/xml_test.go +++ b/core/config/xml/xml_test.go @@ -15,7 +15,6 @@ package xml import ( - "context" "fmt" "os" "testing" @@ -79,7 +78,7 @@ func TestXML(t *testing.T) { } var xmlsection map[string]string - xmlsection, err = xmlconf.GetSection(nil, "mysection") + xmlsection, err = xmlconf.GetSection("mysection") if err != nil { t.Fatal(err) } @@ -97,19 +96,19 @@ func TestXML(t *testing.T) { switch v.(type) { case int: - value, err = xmlconf.Int(nil, k) + value, err = xmlconf.Int(k) case int64: - value, err = xmlconf.Int64(nil, k) + value, err = xmlconf.Int64(k) case float64: - value, err = xmlconf.Float(nil, k) + value, err = xmlconf.Float(k) case bool: - value, err = xmlconf.Bool(nil, k) + value, err = xmlconf.Bool(k) case []string: - value, err = xmlconf.Strings(nil, k) + value, err = xmlconf.Strings(k) case string: - value, err = xmlconf.String(nil, k) + value, err = xmlconf.String(k) default: - value, err = xmlconf.DIY(nil, k) + value, err = xmlconf.DIY(k) } if err != nil { t.Errorf("get key %q value fatal,%v err %s", k, v, err) @@ -119,35 +118,35 @@ func TestXML(t *testing.T) { } - if err = xmlconf.Set(nil, "name", "astaxie"); err != nil { + if err = xmlconf.Set("name", "astaxie"); err != nil { t.Fatal(err) } - res, _ := xmlconf.String(context.Background(), "name") + res, _ := xmlconf.String("name") if res != "astaxie" { t.Fatal("get name error") } - sub, err := xmlconf.Sub(context.Background(), "mysection") + sub, err := xmlconf.Sub("mysection") assert.Nil(t, err) assert.NotNil(t, sub) - name, err := sub.String(context.Background(), "name") + name, err := sub.String("name") assert.Nil(t, err) assert.Equal(t, "MySection", name) - id, err := sub.Int(context.Background(), "id") + id, err := sub.Int("id") assert.Nil(t, err) assert.Equal(t, 1, id) sec := &Section{} - err = sub.Unmarshaler(context.Background(), "", sec) + err = sub.Unmarshaler("", sec) assert.Nil(t, err) assert.Equal(t, "MySection", sec.Name) sec = &Section{} - err = xmlconf.Unmarshaler(context.Background(), "mysection", sec) + err = xmlconf.Unmarshaler("mysection", sec) assert.Nil(t, err) assert.Equal(t, "MySection", sec.Name) diff --git a/core/config/yaml/yaml.go b/core/config/yaml/yaml.go index 6d9abb4e..71daabee 100644 --- a/core/config/yaml/yaml.go +++ b/core/config/yaml/yaml.go @@ -31,7 +31,6 @@ package yaml import ( "bytes" - "context" "encoding/json" "errors" "fmt" @@ -41,10 +40,11 @@ import ( "strings" "sync" - "github.com/astaxie/beego/core/config" - "github.com/astaxie/beego/core/logs" "github.com/beego/goyaml2" "gopkg.in/yaml.v2" + + "github.com/astaxie/beego/core/config" + "github.com/astaxie/beego/core/logs" ) // Config is a yaml config parser and implements Config interface. @@ -126,8 +126,8 @@ type ConfigContainer struct { } // Unmarshaler is similar to Sub -func (c *ConfigContainer) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { - sub, err := c.sub(ctx, prefix) +func (c *ConfigContainer) Unmarshaler(prefix string, obj interface{}, opt ...config.DecodeOption) error { + sub, err := c.sub(prefix) if err != nil { return err } @@ -139,8 +139,8 @@ func (c *ConfigContainer) Unmarshaler(ctx context.Context, prefix string, obj in return yaml.Unmarshal(bytes, obj) } -func (c *ConfigContainer) Sub(ctx context.Context, key string) (config.Configer, error) { - sub, err := c.sub(ctx, key) +func (c *ConfigContainer) Sub(key string) (config.Configer, error) { + sub, err := c.sub(key) if err != nil { return nil, err } @@ -149,7 +149,7 @@ func (c *ConfigContainer) Sub(ctx context.Context, key string) (config.Configer, }, nil } -func (c *ConfigContainer) sub(ctx context.Context, key string) (map[string]interface{}, error) { +func (c *ConfigContainer) sub(key string) (map[string]interface{}, error) { tmpData := c.data keys := strings.Split(key, ".") for idx, k := range keys { @@ -171,13 +171,13 @@ func (c *ConfigContainer) sub(ctx context.Context, key string) (map[string]inter return tmpData, nil } -func (c *ConfigContainer) OnChange(ctx context.Context, key string, fn func(value string)) { +func (c *ConfigContainer) OnChange(key string, fn func(value string)) { // do nothing logs.Warn("Unsupported operation: OnChange") } // Bool returns the boolean value for a given key. -func (c *ConfigContainer) Bool(ctx context.Context, key string) (bool, error) { +func (c *ConfigContainer) Bool(key string) (bool, error) { v, err := c.getData(key) if err != nil { return false, err @@ -187,8 +187,8 @@ func (c *ConfigContainer) Bool(ctx context.Context, key string) (bool, error) { // DefaultBool return the bool value if has no error // otherwise return the defaultVal -func (c *ConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { - v, err := c.Bool(ctx, key) +func (c *ConfigContainer) DefaultBool(key string, defaultVal bool) bool { + v, err := c.Bool(key) if err != nil { return defaultVal } @@ -196,7 +196,7 @@ func (c *ConfigContainer) DefaultBool(ctx context.Context, key string, defaultVa } // Int returns the integer value for a given key. -func (c *ConfigContainer) Int(ctx context.Context, key string) (int, error) { +func (c *ConfigContainer) Int(key string) (int, error) { if v, err := c.getData(key); err != nil { return 0, err } else if vv, ok := v.(int); ok { @@ -209,8 +209,8 @@ func (c *ConfigContainer) Int(ctx context.Context, key string) (int, error) { // DefaultInt returns the integer value for a given key. // if err != nil return defaultVal -func (c *ConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { - v, err := c.Int(ctx, key) +func (c *ConfigContainer) DefaultInt(key string, defaultVal int) int { + v, err := c.Int(key) if err != nil { return defaultVal } @@ -218,7 +218,7 @@ func (c *ConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal } // Int64 returns the int64 value for a given key. -func (c *ConfigContainer) Int64(ctx context.Context, key string) (int64, error) { +func (c *ConfigContainer) Int64(key string) (int64, error) { if v, err := c.getData(key); err != nil { return 0, err } else if vv, ok := v.(int64); ok { @@ -229,8 +229,8 @@ func (c *ConfigContainer) Int64(ctx context.Context, key string) (int64, error) // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaultVal -func (c *ConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { - v, err := c.Int64(ctx, key) +func (c *ConfigContainer) DefaultInt64(key string, defaultVal int64) int64 { + v, err := c.Int64(key) if err != nil { return defaultVal } @@ -238,7 +238,7 @@ func (c *ConfigContainer) DefaultInt64(ctx context.Context, key string, defaultV } // Float returns the float value for a given key. -func (c *ConfigContainer) Float(ctx context.Context, key string) (float64, error) { +func (c *ConfigContainer) Float(key string) (float64, error) { if v, err := c.getData(key); err != nil { return 0.0, err } else if vv, ok := v.(float64); ok { @@ -253,8 +253,8 @@ func (c *ConfigContainer) Float(ctx context.Context, key string) (float64, error // DefaultFloat returns the float64 value for a given key. // if err != nil return defaultVal -func (c *ConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { - v, err := c.Float(ctx, key) +func (c *ConfigContainer) DefaultFloat(key string, defaultVal float64) float64 { + v, err := c.Float(key) if err != nil { return defaultVal } @@ -262,7 +262,7 @@ func (c *ConfigContainer) DefaultFloat(ctx context.Context, key string, defaultV } // String returns the string value for a given key. -func (c *ConfigContainer) String(ctx context.Context, key string) (string, error) { +func (c *ConfigContainer) String(key string) (string, error) { if v, err := c.getData(key); err == nil { if vv, ok := v.(string); ok { return vv, nil @@ -273,8 +273,8 @@ func (c *ConfigContainer) String(ctx context.Context, key string) (string, error // DefaultString returns the string value for a given key. // if err != nil return defaultVal -func (c *ConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { - v, err := c.String(nil, key) +func (c *ConfigContainer) DefaultString(key string, defaultVal string) string { + v, err := c.String(key) if v == "" || err != nil { return defaultVal } @@ -282,8 +282,8 @@ func (c *ConfigContainer) DefaultString(ctx context.Context, key string, default } // Strings returns the []string value for a given key. -func (c *ConfigContainer) Strings(ctx context.Context, key string) ([]string, error) { - v, err := c.String(nil, key) +func (c *ConfigContainer) Strings(key string) ([]string, error) { + v, err := c.String(key) if v == "" || err != nil { return nil, err } @@ -292,8 +292,8 @@ func (c *ConfigContainer) Strings(ctx context.Context, key string) ([]string, er // DefaultStrings returns the []string value for a given key. // if err != nil return defaultVal -func (c *ConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { - v, err := c.Strings(ctx, key) +func (c *ConfigContainer) DefaultStrings(key string, defaultVal []string) []string { + v, err := c.Strings(key) if v == nil || err != nil { return defaultVal } @@ -301,7 +301,7 @@ func (c *ConfigContainer) DefaultStrings(ctx context.Context, key string, defaul } // GetSection returns map for the given section -func (c *ConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { +func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { if v, ok := c.data[section]; ok { return v.(map[string]string), nil @@ -310,7 +310,7 @@ func (c *ConfigContainer) GetSection(ctx context.Context, section string) (map[s } // SaveConfigFile save the config into file -func (c *ConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) { +func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -322,7 +322,7 @@ func (c *ConfigContainer) SaveConfigFile(ctx context.Context, filename string) ( } // Set writes a new value for key. -func (c *ConfigContainer) Set(ctx context.Context, key, val string) error { +func (c *ConfigContainer) Set(key, val string) error { c.Lock() defer c.Unlock() c.data[key] = val @@ -330,7 +330,7 @@ func (c *ConfigContainer) Set(ctx context.Context, key, val string) error { } // DIY returns the raw value by a given key. -func (c *ConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) { +func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { return c.getData(key) } diff --git a/core/config/yaml/yaml_test.go b/core/config/yaml/yaml_test.go index a7c3a92e..d18317db 100644 --- a/core/config/yaml/yaml_test.go +++ b/core/config/yaml/yaml_test.go @@ -15,7 +15,6 @@ package yaml import ( - "context" "fmt" "os" "testing" @@ -76,7 +75,7 @@ func TestYaml(t *testing.T) { t.Fatal(err) } - res, _ := yamlconf.String(nil, "appname") + res, _ := yamlconf.String("appname") if res != "beeapi" { t.Fatal("appname not equal to beeapi") } @@ -90,19 +89,19 @@ func TestYaml(t *testing.T) { switch v.(type) { case int: - value, err = yamlconf.Int(nil, k) + value, err = yamlconf.Int(k) case int64: - value, err = yamlconf.Int64(nil, k) + value, err = yamlconf.Int64(k) case float64: - value, err = yamlconf.Float(nil, k) + value, err = yamlconf.Float(k) case bool: - value, err = yamlconf.Bool(nil, k) + value, err = yamlconf.Bool(k) case []string: - value, err = yamlconf.Strings(nil, k) + value, err = yamlconf.Strings(k) case string: - value, err = yamlconf.String(nil, k) + value, err = yamlconf.String(k) default: - value, err = yamlconf.DIY(nil, k) + value, err = yamlconf.DIY(k) } if err != nil { t.Errorf("get key %q value fatal,%v err %s", k, v, err) @@ -112,35 +111,35 @@ func TestYaml(t *testing.T) { } - if err = yamlconf.Set(nil, "name", "astaxie"); err != nil { + if err = yamlconf.Set("name", "astaxie"); err != nil { t.Fatal(err) } - res, _ = yamlconf.String(nil, "name") + res, _ = yamlconf.String("name") if res != "astaxie" { t.Fatal("get name error") } - sub, err := yamlconf.Sub(context.Background(), "user") + sub, err := yamlconf.Sub("user") assert.Nil(t, err) assert.NotNil(t, sub) - name, err := sub.String(context.Background(), "name") + name, err := sub.String("name") assert.Nil(t, err) assert.Equal(t, "tom", name) - age, err := sub.Int(context.Background(), "age") + age, err := sub.Int("age") assert.Nil(t, err) assert.Equal(t, 13, age) user := &User{} - err = sub.Unmarshaler(context.Background(), "", user) + err = sub.Unmarshaler("", user) assert.Nil(t, err) assert.Equal(t, "tom", user.Name) assert.Equal(t, 13, user.Age) user = &User{} - err = yamlconf.Unmarshaler(context.Background(), "user", user) + err = yamlconf.Unmarshaler("user", user) assert.Nil(t, err) assert.Equal(t, "tom", user.Name) assert.Equal(t, 13, user.Age) diff --git a/server/web/config.go b/server/web/config.go index 10dc9c97..47c9686e 100644 --- a/server/web/config.go +++ b/server/web/config.go @@ -15,7 +15,6 @@ package web import ( - context2 "context" "crypto/tls" "fmt" "os" @@ -301,11 +300,11 @@ func assignConfig(ac config.Configer) error { // set the run mode first if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { BConfig.RunMode = envRunMode - } else if runMode, err := ac.String(nil, "RunMode"); runMode != "" && err == nil { + } else if runMode, err := ac.String("RunMode"); runMode != "" && err == nil { BConfig.RunMode = runMode } - if sd, err := ac.String(nil, "StaticDir"); sd != "" && err == nil { + if sd, err := ac.String("StaticDir"); sd != "" && err == nil { BConfig.WebConfig.StaticDir = map[string]string{} sds := strings.Fields(sd) for _, v := range sds { @@ -317,7 +316,7 @@ func assignConfig(ac config.Configer) error { } } - if sgz, err := ac.String(nil, "StaticExtensionsToGzip"); sgz != "" && err == nil { + if sgz, err := ac.String("StaticExtensionsToGzip"); sgz != "" && err == nil { extensions := strings.Split(sgz, ",") fileExts := []string{} for _, ext := range extensions { @@ -335,15 +334,15 @@ func assignConfig(ac config.Configer) error { } } - if sfs, err := ac.Int(nil, "StaticCacheFileSize"); err == nil { + if sfs, err := ac.Int("StaticCacheFileSize"); err == nil { BConfig.WebConfig.StaticCacheFileSize = sfs } - if sfn, err := ac.Int(nil, "StaticCacheFileNum"); err == nil { + if sfn, err := ac.Int("StaticCacheFileNum"); err == nil { BConfig.WebConfig.StaticCacheFileNum = sfn } - if lo, err := ac.String(nil, "LogOutputs"); lo != "" && err == nil { + if lo, err := ac.String("LogOutputs"); lo != "" && err == nil { // if lo is not nil or empty // means user has set his own LogOutputs // clear the default setting to BConfig.Log.Outputs @@ -390,11 +389,11 @@ func assignSingleConfig(p interface{}, ac config.Configer) { name := pt.Field(i).Name switch pf.Kind() { case reflect.String: - pf.SetString(ac.DefaultString(nil, name, pf.String())) + pf.SetString(ac.DefaultString(name, pf.String())) case reflect.Int, reflect.Int64: - pf.SetInt(ac.DefaultInt64(nil, name, pf.Int())) + pf.SetInt(ac.DefaultInt64(name, pf.Int())) case reflect.Bool: - pf.SetBool(ac.DefaultBool(nil, name, pf.Bool())) + pf.SetBool(ac.DefaultBool(name, pf.Bool())) case reflect.Struct: default: // do nothing here @@ -433,105 +432,105 @@ func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, err return &beegoAppConfig{innerConfig: ac}, nil } -func (b *beegoAppConfig) Set(ctx context2.Context, key, val string) error { - if err := b.innerConfig.Set(nil, BConfig.RunMode+"::"+key, val); err != nil { - return b.innerConfig.Set(nil, key, val) +func (b *beegoAppConfig) Set(key, val string) error { + if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { + return b.innerConfig.Set(key, val) } return nil } -func (b *beegoAppConfig) String(ctx context2.Context, key string) (string, error) { - if v, err := b.innerConfig.String(nil, BConfig.RunMode+"::"+key); v != "" && err == nil { +func (b *beegoAppConfig) String(key string) (string, error) { + if v, err := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" && err == nil { return v, nil } - return b.innerConfig.String(nil, key) + return b.innerConfig.String(key) } -func (b *beegoAppConfig) Strings(ctx context2.Context, key string) ([]string, error) { - if v, err := b.innerConfig.Strings(nil, BConfig.RunMode+"::"+key); len(v) > 0 && err == nil { +func (b *beegoAppConfig) Strings(key string) ([]string, error) { + if v, err := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 && err == nil { return v, nil } - return b.innerConfig.Strings(nil, key) + return b.innerConfig.Strings(key) } -func (b *beegoAppConfig) Int(ctx context2.Context, key string) (int, error) { - if v, err := b.innerConfig.Int(nil, BConfig.RunMode+"::"+key); err == nil { +func (b *beegoAppConfig) Int(key string) (int, error) { + if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil { return v, nil } - return b.innerConfig.Int(nil, key) + return b.innerConfig.Int(key) } -func (b *beegoAppConfig) Int64(ctx context2.Context, key string) (int64, error) { - if v, err := b.innerConfig.Int64(nil, BConfig.RunMode+"::"+key); err == nil { +func (b *beegoAppConfig) Int64(key string) (int64, error) { + if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil { return v, nil } - return b.innerConfig.Int64(nil, key) + return b.innerConfig.Int64(key) } -func (b *beegoAppConfig) Bool(ctx context2.Context, key string) (bool, error) { - if v, err := b.innerConfig.Bool(nil, BConfig.RunMode+"::"+key); err == nil { +func (b *beegoAppConfig) Bool(key string) (bool, error) { + if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil { return v, nil } - return b.innerConfig.Bool(nil, key) + return b.innerConfig.Bool(key) } -func (b *beegoAppConfig) Float(ctx context2.Context, key string) (float64, error) { - if v, err := b.innerConfig.Float(nil, BConfig.RunMode+"::"+key); err == nil { +func (b *beegoAppConfig) Float(key string) (float64, error) { + if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil { return v, nil } - return b.innerConfig.Float(nil, key) + return b.innerConfig.Float(key) } -func (b *beegoAppConfig) DefaultString(ctx context2.Context, key string, defaultVal string) string { - if v, err := b.String(nil, key); v != "" && err == nil { +func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { + if v, err := b.String(key); v != "" && err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultStrings(ctx context2.Context, key string, defaultVal []string) []string { - if v, err := b.Strings(ctx, key); len(v) != 0 && err == nil { +func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { + if v, err := b.Strings(key); len(v) != 0 && err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultInt(ctx context2.Context, key string, defaultVal int) int { - if v, err := b.Int(ctx, key); err == nil { +func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { + if v, err := b.Int(key); err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultInt64(ctx context2.Context, key string, defaultVal int64) int64 { - if v, err := b.Int64(ctx, key); err == nil { +func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { + if v, err := b.Int64(key); err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultBool(ctx context2.Context, key string, defaultVal bool) bool { - if v, err := b.Bool(ctx, key); err == nil { +func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { + if v, err := b.Bool(key); err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultFloat(ctx context2.Context, key string, defaultVal float64) float64 { - if v, err := b.Float(ctx, key); err == nil { +func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { + if v, err := b.Float(key); err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DIY(ctx context2.Context, key string) (interface{}, error) { - return b.innerConfig.DIY(nil, key) +func (b *beegoAppConfig) DIY(key string) (interface{}, error) { + return b.innerConfig.DIY(key) } -func (b *beegoAppConfig) GetSection(ctx context2.Context, section string) (map[string]string, error) { - return b.innerConfig.GetSection(nil, section) +func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { + return b.innerConfig.GetSection(section) } -func (b *beegoAppConfig) SaveConfigFile(ctx context2.Context, filename string) error { - return b.innerConfig.SaveConfigFile(nil, filename) +func (b *beegoAppConfig) SaveConfigFile(filename string) error { + return b.innerConfig.SaveConfigFile(filename) } diff --git a/server/web/config_test.go b/server/web/config_test.go index 88fa8b8c..0129ebb4 100644 --- a/server/web/config_test.go +++ b/server/web/config_test.go @@ -111,12 +111,12 @@ func TestAssignConfig_02(t *testing.T) { func TestAssignConfig_03(t *testing.T) { jcf := &beeJson.JSONConfig{} ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`)) - ac.Set(nil, "AppName", "test_app") - ac.Set(nil, "RunMode", "online") - ac.Set(nil, "StaticDir", "download:down download2:down2") - ac.Set(nil, "StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") - ac.Set(nil, "StaticCacheFileSize", "87456") - ac.Set(nil, "StaticCacheFileNum", "1254") + ac.Set("AppName", "test_app") + ac.Set("RunMode", "online") + ac.Set("StaticDir", "download:down download2:down2") + ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") + ac.Set("StaticCacheFileSize", "87456") + ac.Set("StaticCacheFileNum", "1254") assignConfig(ac) t.Logf("%#v", BConfig) diff --git a/server/web/hooks.go b/server/web/hooks.go index 090e45d3..58e2c0f3 100644 --- a/server/web/hooks.go +++ b/server/web/hooks.go @@ -1,7 +1,6 @@ package web import ( - context2 "context" "encoding/json" "mime" "net/http" @@ -48,7 +47,7 @@ func registerDefaultErrorHandler() error { func registerSession() error { if BConfig.WebConfig.Session.SessionOn { var err error - sessionConfig, err := AppConfig.String(nil, "sessionConfig") + sessionConfig, err := AppConfig.String("sessionConfig") conf := new(session.ManagerConfig) if sessionConfig == "" || err != nil { conf.CookieName = BConfig.WebConfig.Session.SessionName @@ -89,9 +88,9 @@ func registerTemplate() error { func registerGzip() error { if BConfig.EnableGzip { context.InitGzip( - AppConfig.DefaultInt(context2.Background(), "gzipMinLength", -1), - AppConfig.DefaultInt(context2.Background(), "gzipCompressLevel", -1), - AppConfig.DefaultStrings(context2.Background(), "includedMethods", []string{"GET"}), + AppConfig.DefaultInt("gzipMinLength", -1), + AppConfig.DefaultInt("gzipCompressLevel", -1), + AppConfig.DefaultStrings("includedMethods", []string{"GET"}), ) } return nil diff --git a/server/web/parser.go b/server/web/parser.go index 1a8a33df..820c8b10 100644 --- a/server/web/parser.go +++ b/server/web/parser.go @@ -15,7 +15,6 @@ package web import ( - "context" "encoding/json" "errors" "fmt" @@ -222,7 +221,7 @@ func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.Meth func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam { options := []param.MethodParamOption{} if cparam, ok := pc.params[name]; ok { - //Build param from comment info + // Build param from comment info name = cparam.name if cparam.required { options = append(options, param.IsRequired) @@ -359,10 +358,10 @@ filterLoop: methods := matches[2] if methods == "" { pc.methods = []string{"get"} - //pc.hasGet = true + // pc.hasGet = true } else { pc.methods = strings.Split(methods, ",") - //pc.hasGet = strings.Contains(methods, "get") + // pc.hasGet = strings.Contains(methods, "get") } pcs = append(pcs, pc) } else { @@ -517,7 +516,7 @@ func genRouterCode(pkgRealpath string) { } defer f.Close() - routersDir := AppConfig.DefaultString(context.Background(), "routersdir", "routers") + routersDir := AppConfig.DefaultString("routersdir", "routers") content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) content = strings.Replace(content, "{{.routersDir}}", routersDir, -1) content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) @@ -586,7 +585,7 @@ func getpathTime(pkgRealpath string) (lastupdate int64, err error) { func getRouterDir(pkgRealpath string) string { dir := filepath.Dir(pkgRealpath) for { - routersDir := AppConfig.DefaultString(context.Background(), "routersdir", "routers") + routersDir := AppConfig.DefaultString("routersdir", "routers") d := filepath.Join(dir, routersDir) if utils.FileExists(d) { return d diff --git a/server/web/templatefunc.go b/server/web/templatefunc.go index f3301e50..53c99018 100644 --- a/server/web/templatefunc.go +++ b/server/web/templatefunc.go @@ -15,7 +15,6 @@ package web import ( - "context" "errors" "fmt" "html" @@ -161,17 +160,17 @@ func NotNil(a interface{}) (isNil bool) { func GetConfig(returnType, key string, defaultVal interface{}) (value interface{}, err error) { switch returnType { case "String": - value, err = AppConfig.String(context.Background(), key) + value, err = AppConfig.String(key) case "Bool": - value, err = AppConfig.Bool(context.Background(), key) + value, err = AppConfig.Bool(key) case "Int": - value, err = AppConfig.Int(context.Background(), key) + value, err = AppConfig.Int(key) case "Int64": - value, err = AppConfig.Int64(context.Background(), key) + value, err = AppConfig.Int64(key) case "Float": - value, err = AppConfig.Float(context.Background(), key) + value, err = AppConfig.Float(key) case "DIY": - value, err = AppConfig.DIY(context.Background(), key) + value, err = AppConfig.DIY(key) default: err = errors.New("config keys must be of type String, Bool, Int, Int64, Float, or DIY") } From 3fc21ae6ec4dd9913d00f85f1dd5e3c90312c2cf Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 14 Oct 2020 00:24:22 +0800 Subject: [PATCH 196/207] Upgrade toml version --- core/config/config.go | 4 ---- core/config/fake.go | 4 ++++ core/config/ini.go | 9 +++++++++ go.mod | 2 +- go.sum | 2 ++ server/web/config.go | 4 ++++ 6 files changed, 20 insertions(+), 5 deletions(-) diff --git a/core/config/config.go b/core/config/config.go index 908c65a5..a4a24fff 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -186,10 +186,6 @@ func (c *BaseConfiger) Strings(key string) ([]string, error) { return strings.Split(res, ";"), nil } -func (c *BaseConfiger) Unmarshaler(prefix string, obj interface{}, opt ...DecodeOption) error { - return errors.New("unsupported operation") -} - func (c *BaseConfiger) Sub(key string) (Configer, error) { return nil, errors.New("unsupported operation") } diff --git a/core/config/fake.go b/core/config/fake.go index 332eaf1e..3f6f4682 100644 --- a/core/config/fake.go +++ b/core/config/fake.go @@ -98,6 +98,10 @@ func (c *fakeConfigContainer) SaveConfigFile(filename string) error { return errors.New("not implement in the fakeConfigContainer") } +func (c *fakeConfigContainer) Unmarshaler(prefix string, obj interface{}, opt ...DecodeOption) error { + return errors.New("unsupported operation") +} + var _ Configer = new(fakeConfigContainer) // NewFakeConfig return a fake Configer diff --git a/core/config/ini.go b/core/config/ini.go index a78f0170..3d869eb4 100644 --- a/core/config/ini.go +++ b/core/config/ini.go @@ -27,6 +27,8 @@ import ( "strconv" "strings" "sync" + + "github.com/mitchellh/mapstructure" ) var ( @@ -505,6 +507,13 @@ func (c *IniConfigContainer) getdata(key string) string { return "" } +func (c *IniConfigContainer) Unmarshaler(prefix string, obj interface{}, opt ...DecodeOption) error { + if len(prefix) > 0 { + return errors.New("unsupported prefix params") + } + return mapstructure.Decode(c.data, opt) +} + func init() { Register("ini", &IniConfig{}) } diff --git a/go.mod b/go.mod index 697c9951..7527aa47 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/mattn/go-sqlite3 v2.0.3+incompatible github.com/mitchellh/mapstructure v1.3.3 github.com/opentracing/opentracing-go v1.2.0 - github.com/pelletier/go-toml v1.2.0 + github.com/pelletier/go-toml v1.8.1 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.7.0 github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644 diff --git a/go.sum b/go.sum index 545dbae5..994d1ec4 100644 --- a/go.sum +++ b/go.sum @@ -157,6 +157,8 @@ github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYr github.com/pelletier/go-toml v1.0.1/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/pelletier/go-toml v1.8.1 h1:1Nf83orprkJyknT6h7zbuEGUEjcyVlCxSUGTENmNCRM= +github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc= github.com/peterh/liner v1.0.1-0.20171122030339-3681c2a91233/go.mod h1:xIteQHvHuaLYG9IFj6mSxM0fCKrs34IrEQUhOYuGPHc= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/server/web/config.go b/server/web/config.go index 47c9686e..404ac249 100644 --- a/server/web/config.go +++ b/server/web/config.go @@ -432,6 +432,10 @@ func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, err return &beegoAppConfig{innerConfig: ac}, nil } +func (b *beegoAppConfig) Unmarshaler(prefix string, obj interface{}, opt ...config.DecodeOption) error { + return b.innerConfig.Unmarshaler(prefix, obj, opt...) +} + func (b *beegoAppConfig) Set(key, val string) error { if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { return b.innerConfig.Set(key, val) From c07acaebbc4fabc3a09d596634bda40c19a41008 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 14 Oct 2020 22:10:39 +0800 Subject: [PATCH 197/207] Support unmarshaler --- server/web/config.go | 40 ++++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/server/web/config.go b/server/web/config.go index 404ac249..15a2dffe 100644 --- a/server/web/config.go +++ b/server/web/config.go @@ -293,10 +293,38 @@ func parseConfig(appConfigPath string) (err error) { return assignConfig(AppConfig) } +// assignConfig is tricky. +// For 1.x, it use assignSingleConfig to parse the file +// but for 2.x, we use Unmarshaler method func assignConfig(ac config.Configer) error { + + parseConfigForV1(ac) + + err := ac.Unmarshaler("", BConfig) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, fmt.Sprintf("Unmarshaler config file to BConfig failed. " + + "And if you are working on v1.x config file, please ignore this, err: %s", err)) + return err + } + + // init log + logs.Reset() + for adaptor, cfg := range BConfig.Log.Outputs { + err := logs.SetLogger(adaptor, cfg) + if err != nil { + fmt.Fprintln(os.Stderr, fmt.Sprintf("%s with the config %q got err:%s", adaptor, cfg, err.Error())) + return err + } + } + logs.SetLogFuncCall(BConfig.Log.FileLineNum) + return nil +} + +func parseConfigForV1(ac config.Configer) { for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} { assignSingleConfig(i, ac) } + // set the run mode first if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { BConfig.RunMode = envRunMode @@ -356,18 +384,6 @@ func assignConfig(ac config.Configer) error { } } } - - // init log - logs.Reset() - for adaptor, config := range BConfig.Log.Outputs { - err := logs.SetLogger(adaptor, config) - if err != nil { - fmt.Fprintln(os.Stderr, fmt.Sprintf("%s with the config %q got err:%s", adaptor, config, err.Error())) - } - } - logs.SetLogFuncCall(BConfig.Log.FileLineNum) - - return nil } func assignSingleConfig(p interface{}, ac config.Configer) { From c510926cb89d99fca4c22d967e93cbe14b82e948 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sun, 18 Oct 2020 23:18:13 +0800 Subject: [PATCH 198/207] =?UTF-8?q?fix=204224=EF=BC=9Aform=20entity=20too?= =?UTF-8?q?=20large=20casue=20run=20out=20of=20memory?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- adapter/context/input.go | 2 +- server/web/context/input.go | 5 +++-- server/web/router.go | 10 ++++++++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/adapter/context/input.go b/adapter/context/input.go index 4d62d3c1..51bb9ea5 100644 --- a/adapter/context/input.go +++ b/adapter/context/input.go @@ -266,7 +266,7 @@ func (input *BeegoInput) SetData(key, val interface{}) { // ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { - return (*context.BeegoInput)(input).ParseFormOrMulitForm(maxMemory) + return (*context.BeegoInput)(input).ParseFormOrMultiForm(maxMemory) } // Bind data from request.Form[key] to dest diff --git a/server/web/context/input.go b/server/web/context/input.go index 641e15cc..21cae0b0 100644 --- a/server/web/context/input.go +++ b/server/web/context/input.go @@ -420,10 +420,11 @@ func (input *BeegoInput) SetData(key, val interface{}) { input.data[key] = val } -// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type -func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { +// ParseFormOrMultiForm parseForm or parseMultiForm based on Content-type +func (input *BeegoInput) ParseFormOrMultiForm(maxMemory int64) error { // Parse the body depending on the content type. if strings.Contains(input.Header("Content-Type"), "multipart/form-data") { + input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, input.Context.Request.Body, maxMemory) if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil { return errors.New("Error parsing request body:" + err.Error()) } diff --git a/server/web/router.go b/server/web/router.go index 0f383f99..ca0918a0 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -666,6 +666,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { + var err error startTime := time.Now() r := ctx.Request rw := ctx.ResponseWriter.ResponseWriter @@ -718,12 +719,17 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { } ctx.Input.CopyBody(p.cfg.MaxMemory) } - ctx.Input.ParseFormOrMulitForm(p.cfg.MaxMemory) + + err = ctx.Input.ParseFormOrMultiForm(p.cfg.MaxMemory) + if err != nil { + logs.Error(errors.New("payload too large")) + exception("413", ctx) + goto Admin + } } // session init if p.cfg.WebConfig.Session.SessionOn { - var err error ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) if err != nil { logs.Error(err) From cbb3de741d9307fb18b3b9cad0d430ea70849a97 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sun, 18 Oct 2020 23:38:08 +0800 Subject: [PATCH 199/207] fix application/x-www-form-urlencoded request body oversize --- server/web/context/input.go | 2 +- server/web/router.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/web/context/input.go b/server/web/context/input.go index 21cae0b0..027ad527 100644 --- a/server/web/context/input.go +++ b/server/web/context/input.go @@ -423,8 +423,8 @@ func (input *BeegoInput) SetData(key, val interface{}) { // ParseFormOrMultiForm parseForm or parseMultiForm based on Content-type func (input *BeegoInput) ParseFormOrMultiForm(maxMemory int64) error { // Parse the body depending on the content type. + input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, input.Context.Request.Body, maxMemory) if strings.Contains(input.Header("Content-Type"), "multipart/form-data") { - input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, input.Context.Request.Body, maxMemory) if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil { return errors.New("Error parsing request body:" + err.Error()) } diff --git a/server/web/router.go b/server/web/router.go index ca0918a0..ba70d340 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -722,7 +722,7 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { err = ctx.Input.ParseFormOrMultiForm(p.cfg.MaxMemory) if err != nil { - logs.Error(errors.New("payload too large")) + logs.Error(err) exception("413", ctx) goto Admin } From 3c487199995173c7f02f6286f1f22c3d0abf40c0 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Mon, 19 Oct 2020 00:22:55 +0800 Subject: [PATCH 200/207] complete condition --- server/web/router.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/server/web/router.go b/server/web/router.go index ba70d340..515abcb0 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -723,7 +723,11 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { err = ctx.Input.ParseFormOrMultiForm(p.cfg.MaxMemory) if err != nil { logs.Error(err) - exception("413", ctx) + if strings.Contains(err.Error(), `http: request body too large`) { + exception("413", ctx) + } else { + exception("500", ctx) + } goto Admin } } From 93bdf970680b48b2deee1cb8aea55b30b48c9b00 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 19 Oct 2020 21:04:45 +0800 Subject: [PATCH 201/207] Fix ini Unmarshall method --- core/config/ini.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/config/ini.go b/core/config/ini.go index 3d869eb4..93dd774a 100644 --- a/core/config/ini.go +++ b/core/config/ini.go @@ -511,7 +511,7 @@ func (c *IniConfigContainer) Unmarshaler(prefix string, obj interface{}, opt ... if len(prefix) > 0 { return errors.New("unsupported prefix params") } - return mapstructure.Decode(c.data, opt) + return mapstructure.Decode(c.data, obj) } func init() { From d26683799a6b2de721530696dd6dfb6494ff1ea9 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 20 Oct 2020 22:06:24 +0800 Subject: [PATCH 202/207] add MaxUploadFile to provide more safety uploading controll --- server/web/config.go | 19 ++++++++++++------- server/web/context/input.go | 3 +-- server/web/router.go | 11 ++++++++++- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/server/web/config.go b/server/web/config.go index 15a2dffe..10138e63 100644 --- a/server/web/config.go +++ b/server/web/config.go @@ -43,12 +43,16 @@ type Config struct { RecoverFunc func(*context.Context, *Config) CopyRequestBody bool EnableGzip bool - MaxMemory int64 - EnableErrorsShow bool - EnableErrorsRender bool - Listen Listen - WebConfig WebConfig - Log LogConfig + // MaxMemory and MaxUploadSize are used to limit the request body + // if the request is not uploading file, MaxMemory is the max size of request body + // if the request is uploading file, MaxUploadSize is the max size of request body + MaxMemory int64 + MaxUploadSize int64 + EnableErrorsShow bool + EnableErrorsRender bool + Listen Listen + WebConfig WebConfig + Log LogConfig } // Listen holds for http and https related config @@ -215,6 +219,7 @@ func newBConfig() *Config { CopyRequestBody: false, EnableGzip: false, MaxMemory: 1 << 26, // 64MB + MaxUploadSize: 1 << 30, // 1GB EnableErrorsShow: true, EnableErrorsRender: true, Listen: Listen{ @@ -302,7 +307,7 @@ func assignConfig(ac config.Configer) error { err := ac.Unmarshaler("", BConfig) if err != nil { - _, _ = fmt.Fprintln(os.Stderr, fmt.Sprintf("Unmarshaler config file to BConfig failed. " + + _, _ = fmt.Fprintln(os.Stderr, fmt.Sprintf("Unmarshaler config file to BConfig failed. "+ "And if you are working on v1.x config file, please ignore this, err: %s", err)) return err } diff --git a/server/web/context/input.go b/server/web/context/input.go index 027ad527..504838a3 100644 --- a/server/web/context/input.go +++ b/server/web/context/input.go @@ -423,8 +423,7 @@ func (input *BeegoInput) SetData(key, val interface{}) { // ParseFormOrMultiForm parseForm or parseMultiForm based on Content-type func (input *BeegoInput) ParseFormOrMultiForm(maxMemory int64) error { // Parse the body depending on the content type. - input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, input.Context.Request.Body, maxMemory) - if strings.Contains(input.Header("Content-Type"), "multipart/form-data") { + if input.IsUpload() { if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil { return errors.New("Error parsing request body:" + err.Error()) } diff --git a/server/web/router.go b/server/web/router.go index 515abcb0..7bb89d82 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -710,7 +710,12 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { } if r.Method != http.MethodGet && r.Method != http.MethodHead { - if p.cfg.CopyRequestBody && !ctx.Input.IsUpload() { + + if ctx.Input.IsUpload() { + ctx.Input.Context.Request.Body = http.MaxBytesReader(ctx.Input.Context.ResponseWriter, + ctx.Input.Context.Request.Body, + p.cfg.MaxUploadSize) + } else if p.cfg.CopyRequestBody { // connection will close if the incoming data are larger (RFC 7231, 6.5.11) if r.ContentLength > p.cfg.MaxMemory { logs.Error(errors.New("payload too large")) @@ -718,6 +723,10 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { goto Admin } ctx.Input.CopyBody(p.cfg.MaxMemory) + } else { + ctx.Input.Context.Request.Body = http.MaxBytesReader(ctx.Input.Context.ResponseWriter, + ctx.Input.Context.Request.Body, + p.cfg.MaxMemory) } err = ctx.Input.ParseFormOrMultiForm(p.cfg.MaxMemory) From 7c61eb058f964b09f5245940438d8a2b7c71d10b Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 21 Oct 2020 20:53:59 +0800 Subject: [PATCH 203/207] Change NewHttpServer API --- server/web/admin.go | 2 +- server/web/server.go | 11 +++++------ server/web/server_test.go | 7 +++---- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/server/web/admin.go b/server/web/admin.go index a1c47e0c..1b06f486 100644 --- a/server/web/admin.go +++ b/server/web/admin.go @@ -109,7 +109,7 @@ func registerAdmin() error { servers: make([]*HttpServer, 0, 2), } beeAdminApp = &adminApp{ - HttpServer: NewHttpServerWithCfg(*BConfig), + HttpServer: NewHttpServerWithCfg(BConfig), } // keep in mind that all data should be html escaped to avoid XSS attack beeAdminApp.Router("/", c, "get:AdminIndex") diff --git a/server/web/server.go b/server/web/server.go index f289fd9b..25841563 100644 --- a/server/web/server.go +++ b/server/web/server.go @@ -59,19 +59,18 @@ type HttpServer struct { // NewHttpSever returns a new beego application. // this method will use the BConfig as the configure to create HttpServer -// Be careful that when you update BConfig, the server's Cfg will not be updated +// Be careful that when you update BConfig, the server's Cfg will be updated too func NewHttpSever() *HttpServer { - return NewHttpServerWithCfg(*BConfig) + return NewHttpServerWithCfg(BConfig) } // NewHttpServerWithCfg will create an sever with specific cfg -func NewHttpServerWithCfg(cfg Config) *HttpServer { - cfgPtr := &cfg - cr := NewControllerRegisterWithCfg(cfgPtr) +func NewHttpServerWithCfg(cfg *Config) *HttpServer { + cr := NewControllerRegisterWithCfg(cfg) app := &HttpServer{ Handlers: cr, Server: &http.Server{}, - Cfg: cfgPtr, + Cfg: cfg, } return app diff --git a/server/web/server_test.go b/server/web/server_test.go index 45ab2d4f..0b0c601c 100644 --- a/server/web/server_test.go +++ b/server/web/server_test.go @@ -21,11 +21,10 @@ import ( ) func TestNewHttpServerWithCfg(t *testing.T) { - // we should make sure that update server's config won't change + BConfig.AppName = "Before" - svr := NewHttpServerWithCfg(*BConfig) + svr := NewHttpServerWithCfg(BConfig) svr.Cfg.AppName = "hello" - assert.NotEqual(t, "hello", BConfig.AppName) - assert.Equal(t, "Before", BConfig.AppName) + assert.Equal(t, "hello", BConfig.AppName) } From 05f4e0c146742affc33e95d8402c93919217711d Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 21 Oct 2020 22:12:25 +0800 Subject: [PATCH 204/207] support using json string to init session --- .../web/session/couchbase/sess_couchbase.go | 38 +++++-- .../session/couchbase/sess_couchbase_test.go | 43 ++++++++ server/web/session/ledis/ledis_session.go | 47 +++++--- .../web/session/ledis/ledis_session_test.go | 41 +++++++ server/web/session/redis/sess_redis.go | 94 ++++++++++------ server/web/session/redis/sess_redis_test.go | 16 +++ .../session/redis_cluster/redis_cluster.go | 92 ++++++++++------ .../redis_cluster/redis_cluster_test.go | 35 ++++++ .../redis_sentinel/sess_redis_sentinel.go | 103 +++++++++++------- .../sess_redis_sentinel_test.go | 18 ++- server/web/session/ssdb/sess_ssdb.go | 32 ++++-- server/web/session/ssdb/sess_ssdb_test.go | 41 +++++++ 12 files changed, 464 insertions(+), 136 deletions(-) create mode 100644 server/web/session/couchbase/sess_couchbase_test.go create mode 100644 server/web/session/ledis/ledis_session_test.go create mode 100644 server/web/session/redis_cluster/redis_cluster_test.go create mode 100644 server/web/session/ssdb/sess_ssdb_test.go diff --git a/server/web/session/couchbase/sess_couchbase.go b/server/web/session/couchbase/sess_couchbase.go index ddd46401..7f15956a 100644 --- a/server/web/session/couchbase/sess_couchbase.go +++ b/server/web/session/couchbase/sess_couchbase.go @@ -34,6 +34,7 @@ package couchbase import ( "context" + "encoding/json" "net/http" "strings" "sync" @@ -57,9 +58,9 @@ type SessionStore struct { // Provider couchabse provided type Provider struct { maxlifetime int64 - savePath string - pool string - bucket string + SavePath string `json:"save_path"` + Pool string `json:"pool"` + Bucket string `json:"bucket"` b *couchbase.Bucket } @@ -115,17 +116,17 @@ func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite } func (cp *Provider) getBucket() *couchbase.Bucket { - c, err := couchbase.Connect(cp.savePath) + c, err := couchbase.Connect(cp.SavePath) if err != nil { return nil } - pool, err := c.GetPool(cp.pool) + pool, err := c.GetPool(cp.Pool) if err != nil { return nil } - bucket, err := pool.GetBucket(cp.bucket) + bucket, err := pool.GetBucket(cp.Bucket) if err != nil { return nil } @@ -135,18 +136,31 @@ func (cp *Provider) getBucket() *couchbase.Bucket { // SessionInit init couchbase session // savepath like couchbase server REST/JSON URL -// e.g. http://host:port/, Pool, Bucket -func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { +// For v1.x e.g. http://host:port/, Pool, Bucket +// For v2.x, you should pass json string. +// e.g. { "save_path": "http://host:port/", "pool": "mypool", "bucket": "mybucket"} +func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfg string) error { cp.maxlifetime = maxlifetime + cfg = strings.TrimSpace(cfg) + // we think this is v2.0, using json to init the session + if strings.HasPrefix(cfg, "{") { + return json.Unmarshal([]byte(cfg), cp) + } else { + return cp.initOldStyle(cfg) + } +} + +// initOldStyle keep compatible with v1.x +func (cp *Provider) initOldStyle(savePath string) error { configs := strings.Split(savePath, ",") if len(configs) > 0 { - cp.savePath = configs[0] + cp.SavePath = configs[0] } if len(configs) > 1 { - cp.pool = configs[1] + cp.Pool = configs[1] } if len(configs) > 2 { - cp.bucket = configs[2] + cp.Bucket = configs[2] } return nil @@ -225,7 +239,7 @@ func (cp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) ( return cs, nil } -// SessionDestroy Remove bucket in this couchbase +// SessionDestroy Remove Bucket in this couchbase func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error { cp.b = cp.getBucket() defer cp.b.Close() diff --git a/server/web/session/couchbase/sess_couchbase_test.go b/server/web/session/couchbase/sess_couchbase_test.go new file mode 100644 index 00000000..5959f9c3 --- /dev/null +++ b/server/web/session/couchbase/sess_couchbase_test.go @@ -0,0 +1,43 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package couchbase + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProvider_SessionInit(t *testing.T) { + // using old style + savePath := `http://host:port/,Pool,Bucket` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "http://host:port/", cp.SavePath) + assert.Equal(t, "Pool", cp.Pool) + assert.Equal(t, "Bucket", cp.Bucket) + assert.Equal(t, int64(12), cp.maxlifetime) + + savePath = ` +{ "save_path": "my save path", "pool": "mypool", "bucket": "mybucket"} +` + cp = &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, "mypool", cp.Pool) + assert.Equal(t, "mybucket", cp.Bucket) + assert.Equal(t, int64(12), cp.maxlifetime) +} diff --git a/server/web/session/ledis/ledis_session.go b/server/web/session/ledis/ledis_session.go index a920ff7c..5b930fcd 100644 --- a/server/web/session/ledis/ledis_session.go +++ b/server/web/session/ledis/ledis_session.go @@ -3,6 +3,7 @@ package ledis import ( "context" + "encoding/json" "net/http" "strconv" "strings" @@ -79,35 +80,51 @@ func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite // Provider ledis session provider type Provider struct { maxlifetime int64 - savePath string - db int + SavePath string `json:"save_path"` + Db int `json:"db"` } // SessionInit init ledis session // savepath like ledis server saveDataPath,pool size -// e.g. 127.0.0.1:6379,100,astaxie -func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { +// v1.x e.g. 127.0.0.1:6379,100 +// v2.x you should pass a json string +// e.g. { "save_path": "my save path", "db": 100} +func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error { var err error lp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) == 1 { - lp.savePath = configs[0] - } else if len(configs) == 2 { - lp.savePath = configs[0] - lp.db, err = strconv.Atoi(configs[1]) - if err != nil { - return err - } + cfgStr = strings.TrimSpace(cfgStr) + // we think cfgStr is v2.0, using json to init the session + if strings.HasPrefix(cfgStr, "{") { + err = json.Unmarshal([]byte(cfgStr), lp) + } else { + err = lp.initOldStyle(cfgStr) } + + if err != nil { + return err + } + cfg := new(config.Config) - cfg.DataDir = lp.savePath + cfg.DataDir = lp.SavePath var ledisInstance *ledis.Ledis ledisInstance, err = ledis.Open(cfg) if err != nil { return err } - c, err = ledisInstance.Select(lp.db) + c, err = ledisInstance.Select(lp.Db) + return err +} + +func (lp *Provider) initOldStyle(cfgStr string) error { + var err error + configs := strings.Split(cfgStr, ",") + if len(configs) == 1 { + lp.SavePath = configs[0] + } else if len(configs) == 2 { + lp.SavePath = configs[0] + lp.Db, err = strconv.Atoi(configs[1]) + } return err } diff --git a/server/web/session/ledis/ledis_session_test.go b/server/web/session/ledis/ledis_session_test.go new file mode 100644 index 00000000..1cfb3ed1 --- /dev/null +++ b/server/web/session/ledis/ledis_session_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ledis + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProvider_SessionInit(t *testing.T) { + // using old style + savePath := `http://host:port/,100` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "http://host:port/", cp.SavePath) + assert.Equal(t, 100, cp.Db) + assert.Equal(t, int64(12), cp.maxlifetime) + + savePath = ` +{ "save_path": "my save path", "db": 100} +` + cp = &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, 100, cp.Db) + assert.Equal(t, int64(12), cp.maxlifetime) +} diff --git a/server/web/session/redis/sess_redis.go b/server/web/session/redis/sess_redis.go index 6ee28e2f..c6e3bcbb 100644 --- a/server/web/session/redis/sess_redis.go +++ b/server/web/session/redis/sess_redis.go @@ -34,6 +34,7 @@ package redis import ( "context" + "encoding/json" "net/http" "strconv" "strings" @@ -110,48 +111,89 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite // Provider redis session provider type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - idleTimeout time.Duration - idleCheckFrequency time.Duration - maxRetries int - poollist *redis.Client + maxlifetime int64 + SavePath string `json:"save_path"` + Poolsize int `json:"poolsize"` + Password string `json:"password"` + DbNum int `json:"db_num"` + + idleTimeout time.Duration + IdleTimeoutStr string `json:"idle_timeout"` + + idleCheckFrequency time.Duration + IdleCheckFrequencyStr string `json:"idle_check_frequency"` + MaxRetries int `json:"max_retries"` + poollist *redis.Client } // SessionInit init redis session // savepath like redis server addr,pool size,password,dbnum,IdleTimeout second -// e.g. 127.0.0.1:6379,100,astaxie,0,30 -func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { +// v1.x e.g. 127.0.0.1:6379,100,astaxie,0,30 +// v2.0 you should pass json string +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error { rp.maxlifetime = maxlifetime + + cfgStr = strings.TrimSpace(cfgStr) + // we think cfgStr is v2.0, using json to init the session + if strings.HasPrefix(cfgStr, "{") { + err := json.Unmarshal([]byte(cfgStr), rp) + if err != nil { + return err + } + rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr) + if err != nil { + return err + } + + rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr) + if err != nil { + return err + } + + } else { + rp.initOldStyle(cfgStr) + } + + rp.poollist = redis.NewClient(&redis.Options{ + Addr: rp.SavePath, + Password: rp.Password, + PoolSize: rp.Poolsize, + DB: rp.DbNum, + IdleTimeout: rp.idleTimeout, + IdleCheckFrequency: rp.idleCheckFrequency, + MaxRetries: rp.MaxRetries, + }) + + return rp.poollist.Ping().Err() +} + +func (rp *Provider) initOldStyle(savePath string) { configs := strings.Split(savePath, ",") if len(configs) > 0 { - rp.savePath = configs[0] + rp.SavePath = configs[0] } if len(configs) > 1 { poolsize, err := strconv.Atoi(configs[1]) if err != nil || poolsize < 0 { - rp.poolsize = MaxPoolSize + rp.Poolsize = MaxPoolSize } else { - rp.poolsize = poolsize + rp.Poolsize = poolsize } } else { - rp.poolsize = MaxPoolSize + rp.Poolsize = MaxPoolSize } if len(configs) > 2 { - rp.password = configs[2] + rp.Password = configs[2] } if len(configs) > 3 { dbnum, err := strconv.Atoi(configs[3]) if err != nil || dbnum < 0 { - rp.dbNum = 0 + rp.DbNum = 0 } else { - rp.dbNum = dbnum + rp.DbNum = dbnum } } else { - rp.dbNum = 0 + rp.DbNum = 0 } if len(configs) > 4 { timeout, err := strconv.Atoi(configs[4]) @@ -168,21 +210,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath if len(configs) > 6 { retries, err := strconv.Atoi(configs[6]) if err == nil && retries > 0 { - rp.maxRetries = retries + rp.MaxRetries = retries } } - - rp.poollist = redis.NewClient(&redis.Options{ - Addr: rp.savePath, - Password: rp.password, - PoolSize: rp.poolsize, - DB: rp.dbNum, - IdleTimeout: rp.idleTimeout, - IdleCheckFrequency: rp.idleCheckFrequency, - MaxRetries: rp.maxRetries, - }) - - return rp.poollist.Ping().Err() } // SessionRead read redis session by sid diff --git a/server/web/session/redis/sess_redis_test.go b/server/web/session/redis/sess_redis_test.go index 19c8c025..64dbc9f9 100644 --- a/server/web/session/redis/sess_redis_test.go +++ b/server/web/session/redis/sess_redis_test.go @@ -1,11 +1,15 @@ package redis import ( + "context" "fmt" "net/http" "net/http/httptest" "os" "testing" + "time" + + "github.com/stretchr/testify/assert" "github.com/astaxie/beego/server/web/session" ) @@ -94,3 +98,15 @@ func TestRedis(t *testing.T) { sess.SessionRelease(nil, w) } + +func TestProvider_SessionInit(t *testing.T) { + + savePath := ` +{ "save_path": "my save path", "idle_timeout": "3s"} +` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, 3*time.Second, cp.idleTimeout) + assert.Equal(t, int64(12), cp.maxlifetime) +} diff --git a/server/web/session/redis_cluster/redis_cluster.go b/server/web/session/redis_cluster/redis_cluster.go index 17653d56..d2971e71 100644 --- a/server/web/session/redis_cluster/redis_cluster.go +++ b/server/web/session/redis_cluster/redis_cluster.go @@ -34,14 +34,16 @@ package redis_cluster import ( "context" + "encoding/json" "net/http" "strconv" "strings" "sync" "time" - "github.com/astaxie/beego/server/web/session" rediss "github.com/go-redis/redis/v7" + + "github.com/astaxie/beego/server/web/session" ) var redispder = &Provider{} @@ -109,48 +111,86 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite // Provider redis_cluster session provider type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - idleTimeout time.Duration - idleCheckFrequency time.Duration - maxRetries int - poollist *rediss.ClusterClient + maxlifetime int64 + SavePath string `json:"save_path"` + Poolsize int `json:"poolsize"` + Password string `json:"password"` + DbNum int `json:"db_num"` + + idleTimeout time.Duration + IdleTimeoutStr string `json:"idle_timeout"` + + idleCheckFrequency time.Duration + IdleCheckFrequencyStr string `json:"idle_check_frequency"` + MaxRetries int `json:"max_retries"` + poollist *rediss.ClusterClient } // SessionInit init redis_cluster session -// savepath like redis server addr,pool size,password,dbnum +// cfgStr like redis server addr,pool size,password,dbnum // e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 -func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error { rp.maxlifetime = maxlifetime + cfgStr = strings.TrimSpace(cfgStr) + // we think cfgStr is v2.0, using json to init the session + if strings.HasPrefix(cfgStr, "{") { + err := json.Unmarshal([]byte(cfgStr), rp) + if err != nil { + return err + } + rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr) + if err != nil { + return err + } + + rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr) + if err != nil { + return err + } + + } else { + rp.initOldStyle(cfgStr) + } + + rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ + Addrs: strings.Split(rp.SavePath, ";"), + Password: rp.Password, + PoolSize: rp.Poolsize, + IdleTimeout: rp.idleTimeout, + IdleCheckFrequency: rp.idleCheckFrequency, + MaxRetries: rp.MaxRetries, + }) + return rp.poollist.Ping().Err() +} + +// for v1.x +func (rp *Provider) initOldStyle(savePath string) { configs := strings.Split(savePath, ",") if len(configs) > 0 { - rp.savePath = configs[0] + rp.SavePath = configs[0] } if len(configs) > 1 { poolsize, err := strconv.Atoi(configs[1]) if err != nil || poolsize < 0 { - rp.poolsize = MaxPoolSize + rp.Poolsize = MaxPoolSize } else { - rp.poolsize = poolsize + rp.Poolsize = poolsize } } else { - rp.poolsize = MaxPoolSize + rp.Poolsize = MaxPoolSize } if len(configs) > 2 { - rp.password = configs[2] + rp.Password = configs[2] } if len(configs) > 3 { dbnum, err := strconv.Atoi(configs[3]) if err != nil || dbnum < 0 { - rp.dbNum = 0 + rp.DbNum = 0 } else { - rp.dbNum = dbnum + rp.DbNum = dbnum } } else { - rp.dbNum = 0 + rp.DbNum = 0 } if len(configs) > 4 { timeout, err := strconv.Atoi(configs[4]) @@ -167,19 +207,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath if len(configs) > 6 { retries, err := strconv.Atoi(configs[6]) if err == nil && retries > 0 { - rp.maxRetries = retries + rp.MaxRetries = retries } } - - rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ - Addrs: strings.Split(rp.savePath, ";"), - Password: rp.password, - PoolSize: rp.poolsize, - IdleTimeout: rp.idleTimeout, - IdleCheckFrequency: rp.idleCheckFrequency, - MaxRetries: rp.maxRetries, - }) - return rp.poollist.Ping().Err() } // SessionRead read redis_cluster session by sid diff --git a/server/web/session/redis_cluster/redis_cluster_test.go b/server/web/session/redis_cluster/redis_cluster_test.go new file mode 100644 index 00000000..0192cd87 --- /dev/null +++ b/server/web/session/redis_cluster/redis_cluster_test.go @@ -0,0 +1,35 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package redis_cluster + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestProvider_SessionInit(t *testing.T) { + + savePath := ` +{ "save_path": "my save path", "idle_timeout": "3s"} +` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, 3*time.Second, cp.idleTimeout) + assert.Equal(t, int64(12), cp.maxlifetime) +} diff --git a/server/web/session/redis_sentinel/sess_redis_sentinel.go b/server/web/session/redis_sentinel/sess_redis_sentinel.go index d68b8767..89d73b86 100644 --- a/server/web/session/redis_sentinel/sess_redis_sentinel.go +++ b/server/web/session/redis_sentinel/sess_redis_sentinel.go @@ -34,6 +34,7 @@ package redis_sentinel import ( "context" + "encoding/json" "net/http" "strconv" "strings" @@ -110,58 +111,99 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite // Provider redis_sentinel session provider type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - idleTimeout time.Duration - idleCheckFrequency time.Duration - maxRetries int - poollist *redis.Client - masterName string + maxlifetime int64 + SavePath string `json:"save_path"` + Poolsize int `json:"poolsize"` + Password string `json:"password"` + DbNum int `json:"db_num"` + + idleTimeout time.Duration + IdleTimeoutStr string `json:"idle_timeout"` + + idleCheckFrequency time.Duration + IdleCheckFrequencyStr string `json:"idle_check_frequency"` + MaxRetries int `json:"max_retries"` + poollist *redis.Client + MasterName string `json:"master_name"` } // SessionInit init redis_sentinel session -// savepath like redis sentinel addr,pool size,password,dbnum,masterName +// cfgStr like redis sentinel addr,pool size,password,dbnum,masterName // e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster -func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error { rp.maxlifetime = maxlifetime + cfgStr = strings.TrimSpace(cfgStr) + // we think cfgStr is v2.0, using json to init the session + if strings.HasPrefix(cfgStr, "{") { + err := json.Unmarshal([]byte(cfgStr), rp) + if err != nil { + return err + } + rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr) + if err != nil { + return err + } + + rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr) + if err != nil { + return err + } + + } else { + rp.initOldStyle(cfgStr) + } + + rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ + SentinelAddrs: strings.Split(rp.SavePath, ";"), + Password: rp.Password, + PoolSize: rp.Poolsize, + DB: rp.DbNum, + MasterName: rp.MasterName, + IdleTimeout: rp.idleTimeout, + IdleCheckFrequency: rp.idleCheckFrequency, + MaxRetries: rp.MaxRetries, + }) + + return rp.poollist.Ping().Err() +} + +// for v1.x +func (rp *Provider) initOldStyle(savePath string) { configs := strings.Split(savePath, ",") if len(configs) > 0 { - rp.savePath = configs[0] + rp.SavePath = configs[0] } if len(configs) > 1 { poolsize, err := strconv.Atoi(configs[1]) if err != nil || poolsize < 0 { - rp.poolsize = DefaultPoolSize + rp.Poolsize = DefaultPoolSize } else { - rp.poolsize = poolsize + rp.Poolsize = poolsize } } else { - rp.poolsize = DefaultPoolSize + rp.Poolsize = DefaultPoolSize } if len(configs) > 2 { - rp.password = configs[2] + rp.Password = configs[2] } if len(configs) > 3 { dbnum, err := strconv.Atoi(configs[3]) if err != nil || dbnum < 0 { - rp.dbNum = 0 + rp.DbNum = 0 } else { - rp.dbNum = dbnum + rp.DbNum = dbnum } } else { - rp.dbNum = 0 + rp.DbNum = 0 } if len(configs) > 4 { if configs[4] != "" { - rp.masterName = configs[4] + rp.MasterName = configs[4] } else { - rp.masterName = "mymaster" + rp.MasterName = "mymaster" } } else { - rp.masterName = "mymaster" + rp.MasterName = "mymaster" } if len(configs) > 5 { timeout, err := strconv.Atoi(configs[4]) @@ -178,22 +220,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath if len(configs) > 7 { retries, err := strconv.Atoi(configs[6]) if err == nil && retries > 0 { - rp.maxRetries = retries + rp.MaxRetries = retries } } - - rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ - SentinelAddrs: strings.Split(rp.savePath, ";"), - Password: rp.password, - PoolSize: rp.poolsize, - DB: rp.dbNum, - MasterName: rp.masterName, - IdleTimeout: rp.idleTimeout, - IdleCheckFrequency: rp.idleCheckFrequency, - MaxRetries: rp.maxRetries, - }) - - return rp.poollist.Ping().Err() } // SessionRead read redis_sentinel session by sid diff --git a/server/web/session/redis_sentinel/sess_redis_sentinel_test.go b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go index e4822d11..f052a14a 100644 --- a/server/web/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go @@ -1,9 +1,13 @@ package redis_sentinel import ( + "context" "net/http" "net/http/httptest" "testing" + "time" + + "github.com/stretchr/testify/assert" "github.com/astaxie/beego/server/web/session" ) @@ -23,7 +27,7 @@ func TestRedisSentinel(t *testing.T) { t.Log(e) return } - //todo test if e==nil + // todo test if e==nil go globalSessions.GC() r, _ := http.NewRequest("GET", "/", nil) @@ -88,3 +92,15 @@ func TestRedisSentinel(t *testing.T) { sess.SessionRelease(nil, w) } + +func TestProvider_SessionInit(t *testing.T) { + + savePath := ` +{ "save_path": "my save path", "idle_timeout": "3s"} +` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, 3*time.Second, cp.idleTimeout) + assert.Equal(t, int64(12), cp.maxlifetime) +} diff --git a/server/web/session/ssdb/sess_ssdb.go b/server/web/session/ssdb/sess_ssdb.go index 9d1230d0..0adc41bd 100644 --- a/server/web/session/ssdb/sess_ssdb.go +++ b/server/web/session/ssdb/sess_ssdb.go @@ -2,6 +2,7 @@ package ssdb import ( "context" + "encoding/json" "errors" "net/http" "strconv" @@ -18,33 +19,48 @@ var ssdbProvider = &Provider{} // Provider holds ssdb client and configs type Provider struct { client *ssdb.Client - host string - port int + Host string `json:"host"` + Port int `json:"port"` maxLifetime int64 } func (p *Provider) connectInit() error { var err error - if p.host == "" || p.port == 0 { + if p.Host == "" || p.Port == 0 { return errors.New("SessionInit First") } - p.client, err = ssdb.Connect(p.host, p.port) + p.client, err = ssdb.Connect(p.Host, p.Port) return err } // SessionInit init the ssdb with the config -func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, savePath string) error { +func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, cfg string) error { p.maxLifetime = maxLifetime - address := strings.Split(savePath, ":") - p.host = address[0] + cfg = strings.TrimSpace(cfg) var err error - if p.port, err = strconv.Atoi(address[1]); err != nil { + // we think this is v2.0, using json to init the session + if strings.HasPrefix(cfg, "{") { + err = json.Unmarshal([]byte(cfg), p) + } else { + err = p.initOldStyle(cfg) + } + if err != nil { return err } return p.connectInit() } +// for v1.x +func (p *Provider) initOldStyle(savePath string) error { + address := strings.Split(savePath, ":") + p.Host = address[0] + + var err error + p.Port, err = strconv.Atoi(address[1]) + return err +} + // SessionRead return a ssdb client session Store func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { if p.client == nil { diff --git a/server/web/session/ssdb/sess_ssdb_test.go b/server/web/session/ssdb/sess_ssdb_test.go new file mode 100644 index 00000000..3de5da0a --- /dev/null +++ b/server/web/session/ssdb/sess_ssdb_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssdb + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProvider_SessionInit(t *testing.T) { + // using old style + savePath := `localhost:8080` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "localhost", cp.Host) + assert.Equal(t, 8080, cp.Port) + assert.Equal(t, int64(12), cp.maxLifetime) + + savePath = ` +{ "host": "localhost", "port": 8080} +` + cp = &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "localhost", cp.Host) + assert.Equal(t, 8080, cp.Port) + assert.Equal(t, int64(12), cp.maxLifetime) +} From 45260e4119fa5d2cb2ff9863dd8ab931fd55e3f0 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 24 Oct 2020 21:24:31 +0800 Subject: [PATCH 205/207] Add global instance for config module --- core/config/global.go | 115 +++++++++++++++++++++++++++++++++++++ core/config/global_test.go | 104 +++++++++++++++++++++++++++++++++ 2 files changed, 219 insertions(+) create mode 100644 core/config/global.go create mode 100644 core/config/global_test.go diff --git a/core/config/global.go b/core/config/global.go new file mode 100644 index 00000000..c5c59ba7 --- /dev/null +++ b/core/config/global.go @@ -0,0 +1,115 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "github.com/astaxie/beego/core/logs" +) + +// We use this to simply application's development +// for most users, they only need to use those methods +var globalInstance Configer + +func init() { + // Ignore this error + err := InitGlobalInstance("ini", "config/app.conf") + if err != nil { + logs.Warn("init global config instance failed. If you donot use this, just ignore it. ", err) + } + +} + +// InitGlobalInstance will ini the global instance +// If you want to use specific implementation, don't forget to import it. +// e.g. _ import "github.com/astaxie/beego/core/config/etcd" +// err := InitGlobalInstance("etcd", "someconfig") +func InitGlobalInstance(name string, cfg string) error { + var err error + globalInstance, err = NewConfig(name, cfg) + return err +} + +// support section::key type in given key when using ini type. +func Set(key, val string) error { + return globalInstance.Set(key, val) +} + +// support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. +func String(key string) (string, error) { + return globalInstance.String(key) +} + +// get string slice +func Strings(key string) ([]string, error) { + return globalInstance.Strings(key) +} +func Int(key string) (int, error) { + return globalInstance.Int(key) +} +func Int64(key string) (int64, error) { + return globalInstance.Int64(key) +} +func Bool(key string) (bool, error) { + return globalInstance.Bool(key) +} +func Float(key string) (float64, error) { + return globalInstance.Float(key) +} + +// support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. +func DefaultString(key string, defaultVal string) string { + return globalInstance.DefaultString(key, defaultVal) +} + +// get string slice +func DefaultStrings(key string, defaultVal []string) []string { + return globalInstance.DefaultStrings(key, defaultVal) +} +func DefaultInt(key string, defaultVal int) int { + return globalInstance.DefaultInt(key, defaultVal) +} +func DefaultInt64(key string, defaultVal int64) int64 { + return globalInstance.DefaultInt64(key, defaultVal) +} +func DefaultBool(key string, defaultVal bool) bool { + return globalInstance.DefaultBool(key, defaultVal) +} +func DefaultFloat(key string, defaultVal float64) float64 { + return globalInstance.DefaultFloat(key, defaultVal) +} + +// DIY return the original value +func DIY(key string) (interface{}, error) { + return globalInstance.DIY(key) +} + +func GetSection(section string) (map[string]string, error) { + return globalInstance.GetSection(section) +} + +func Unmarshaler(prefix string, obj interface{}, opt ...DecodeOption) error { + return globalInstance.Unmarshaler(prefix, obj, opt...) +} +func Sub(key string) (Configer, error) { + return globalInstance.Sub(key) +} + +func OnChange(key string, fn func(value string)) { + globalInstance.OnChange(key, fn) +} + +func SaveConfigFile(filename string) error { + return globalInstance.SaveConfigFile(filename) +} diff --git a/core/config/global_test.go b/core/config/global_test.go new file mode 100644 index 00000000..ff01b043 --- /dev/null +++ b/core/config/global_test.go @@ -0,0 +1,104 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGlobalInstance(t *testing.T) { + cfgStr := ` +appname = beeapi +httpport = 8080 +mysqlport = 3600 +PI = 3.1415926 +runmode = "dev" +autorender = false +copyrequestbody = true +session= on +cookieon= off +newreg = OFF +needlogin = ON +enableSession = Y +enableCookie = N +developer="tom;jerry" +flag = 1 +path1 = ${GOPATH} +path2 = ${GOPATH||/home/go} +[demo] +key1="asta" +key2 = "xie" +CaseInsensitive = true +peers = one;two;three +password = ${GOPATH} +` + path := os.TempDir() + string(os.PathSeparator) + "test_global_instance.ini" + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(cfgStr) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove(path) + + err = InitGlobalInstance("ini", path) + assert.Nil(t, err) + + val, err := String("appname") + assert.Nil(t, err) + assert.Equal(t, "beeapi", val) + + val = DefaultString("appname__", "404") + assert.Equal(t, "404", val) + + vi, err := Int("httpport") + assert.Nil(t, err) + assert.Equal(t, 8080, vi) + vi = DefaultInt("httpport__", 404) + assert.Equal(t, 404, vi) + + vi64, err := Int64("mysqlport") + assert.Nil(t, err) + assert.Equal(t, int64(3600), vi64) + vi64 = DefaultInt64("mysqlport__", 404) + assert.Equal(t, int64(404), vi64) + + vf, err := Float("PI") + assert.Nil(t, err) + assert.Equal(t, 3.1415926, vf) + vf = DefaultFloat("PI__", 4.04) + assert.Equal(t, 4.04, vf) + + vb, err := Bool("copyrequestbody") + assert.Nil(t, err) + assert.True(t, vb) + + vb = DefaultBool("copyrequestbody__", false) + assert.False(t, vb) + + vss := DefaultStrings("developer__", []string{"tom", ""}) + assert.Equal(t, []string{"tom", ""}, vss) + + vss, err = Strings("developer") + assert.Nil(t, err) + assert.Equal(t, []string{"tom", "jerry"}, vss) +} From d07a1eaa8e8b917bb7330a04b0e1c4eb8cdb055f Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 27 Oct 2020 21:54:43 +0800 Subject: [PATCH 206/207] Add test for httplib --- client/httplib/httplib_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/client/httplib/httplib_test.go b/client/httplib/httplib_test.go index f6be8571..88935715 100644 --- a/client/httplib/httplib_test.go +++ b/client/httplib/httplib_test.go @@ -15,6 +15,7 @@ package httplib import ( + "context" "errors" "io/ioutil" "net" @@ -23,6 +24,8 @@ import ( "strings" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestResponse(t *testing.T) { @@ -284,3 +287,16 @@ func TestHeader(t *testing.T) { } t.Log(str) } + +// TestAddFilter make sure that AddFilters only work for the specific request +func TestAddFilter(t *testing.T) { + req := Get("http://beego.me") + req.AddFilters(func(next Filter) Filter { + return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return next(ctx, req) + } + }) + + r := Get("http://beego.me") + assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains)) +} From b4396c97bb713bd450edb1cf01ced6095fb03755 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Thu, 5 Nov 2020 22:00:43 +0800 Subject: [PATCH 207/207] fix init error of global instance --- core/config/global.go | 12 ------------ core/config/ini.go | 11 ++++++++++- core/logs/log.go | 2 +- server/web/parser.go | 2 +- 4 files changed, 12 insertions(+), 15 deletions(-) diff --git a/core/config/global.go b/core/config/global.go index c5c59ba7..5491fe2c 100644 --- a/core/config/global.go +++ b/core/config/global.go @@ -14,22 +14,10 @@ package config -import ( - "github.com/astaxie/beego/core/logs" -) - // We use this to simply application's development // for most users, they only need to use those methods var globalInstance Configer -func init() { - // Ignore this error - err := InitGlobalInstance("ini", "config/app.conf") - if err != nil { - logs.Warn("init global config instance failed. If you donot use this, just ignore it. ", err) - } - -} // InitGlobalInstance will ini the global instance // If you want to use specific implementation, don't forget to import it. diff --git a/core/config/ini.go b/core/config/ini.go index 93dd774a..4d17fb7a 100644 --- a/core/config/ini.go +++ b/core/config/ini.go @@ -29,6 +29,8 @@ import ( "sync" "github.com/mitchellh/mapstructure" + + "github.com/astaxie/beego/core/logs" ) var ( @@ -97,7 +99,7 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e break } - //It might be a good idea to throw a error on all unknonw errors? + // It might be a good idea to throw a error on all unknonw errors? if _, ok := err.(*os.PathError); ok { return nil, err } @@ -516,4 +518,11 @@ func (c *IniConfigContainer) Unmarshaler(prefix string, obj interface{}, opt ... func init() { Register("ini", &IniConfig{}) + + err := InitGlobalInstance("ini", "config/app.conf") + if err != nil { + logs.Warn("init global config instance failed. If you donot use this, just ignore it. ", err) + } } + +// Ignore this error diff --git a/core/logs/log.go b/core/logs/log.go index b05abd3b..d5953dfb 100644 --- a/core/logs/log.go +++ b/core/logs/log.go @@ -687,7 +687,7 @@ func EnableFuncCallDepth(b bool) { // SetLogFuncCall set the CallDepth, default is 4 func SetLogFuncCall(b bool) { beeLogger.EnableFuncCallDepth(b) - beeLogger.SetLogFuncCallDepth(4) + beeLogger.SetLogFuncCallDepth(3) } // SetLogFuncCallDepth set log funcCallDepth diff --git a/server/web/parser.go b/server/web/parser.go index 820c8b10..c3434501 100644 --- a/server/web/parser.go +++ b/server/web/parser.go @@ -39,7 +39,7 @@ import ( var globalRouterTemplate = `package {{.routersDir}} import ( - "github.com/astaxie/beego" + beego "github.com/astaxie/beego/server/web" "github.com/astaxie/beego/server/web/context/param"{{.globalimport}} )