From ca9b21bb30e1d287f72a9e0febc6f007c545268b Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 4 Jul 2020 13:41:19 +0000 Subject: [PATCH 001/207] Add docker-compose to support running test --- test.sh | 14 ++++++++++++++ test_docker_compose.yaml | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 test.sh create mode 100644 test_docker_compose.yaml diff --git a/test.sh b/test.sh new file mode 100644 index 00000000..78928fea --- /dev/null +++ b/test.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +docker-compose -f test_docker_compose.yaml up -d + +export ORM_DRIVER=mysql +export TZ=UTC +export ORM_SOURCE="beego:test@tcp(localhost:13306)/orm_test?charset=utf8" + +go test ./... + +# clear all container +docker-compose -f test_docker_compose.yaml down + + diff --git a/test_docker_compose.yaml b/test_docker_compose.yaml new file mode 100644 index 00000000..54ca4097 --- /dev/null +++ b/test_docker_compose.yaml @@ -0,0 +1,39 @@ +version: "3.8" +services: + redis: + container_name: "beego-redis" + image: redis + environment: + - ALLOW_EMPTY_PASSWORD=yes + ports: + - "6379:6379" + + mysql: + container_name: "beego-mysql" + image: mysql:5.7.30 + ports: + - "13306:3306" + environment: + - MYSQL_ROOT_PASSWORD=1q2w3e + - MYSQL_DATABASE=orm_test + - MYSQL_USER=beego + - MYSQL_PASSWORD=test + + postgresql: + container_name: "beego-postgresql" + image: bitnami/postgresql:latest + ports: + - "5432:5432" + environment: + - ALLOW_EMPTY_PASSWORD=yes + ssdb: + container_name: "beego-ssdb" + image: wendal/ssdb + ports: + - "8888:8888" + memcache: + container_name: "beego-memcache" + image: memcached + ports: + - "11211:11211" + From 7c575585e9e305a28cc138d778f3f0bef5a6909a Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Mon, 6 Jul 2020 15:27:12 +0100 Subject: [PATCH 002/207] added conditional json flag when trying to view healthchecks --- admin.go | 38 ++++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/admin.go b/admin.go index 3e538a0e..c5ae686d 100644 --- a/admin.go +++ b/admin.go @@ -279,9 +279,7 @@ func profIndex(rw http.ResponseWriter, r *http.Request) { http.Error(rw, err.Error(), http.StatusInternalServerError) return } - - rw.Header().Set("Content-Type", "application/json") - rw.Write(dataJSON) + execJSON(rw, dataJSON) return } @@ -295,7 +293,7 @@ func profIndex(rw http.ResponseWriter, r *http.Request) { // Healthcheck is a http.Handler calling health checking and showing the result. // it's in "/healthcheck" pattern in admin module. -func healthcheck(rw http.ResponseWriter, _ *http.Request) { +func healthcheck(rw http.ResponseWriter, r *http.Request) { var ( result []string data = make(map[interface{}]interface{}) @@ -322,12 +320,44 @@ func healthcheck(rw http.ResponseWriter, _ *http.Request) { *resultList = append(*resultList, result) } + queryParams := r.URL.Query() + + if queryParams["json"] != nil { + + type Result map[string]interface{} + + response := make([]Result, len(*resultList)) + + for i, currentResult := range *resultList { + currentResultMap := make(Result) + currentResultMap["name"] = currentResult[0] + currentResultMap["message"] = currentResult[1] + currentResultMap["status"] = currentResult[2] + response[i] = currentResultMap + } + + JSONResponse, err := json.Marshal(response) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } else { + execJSON(rw, JSONResponse) + } + + return + } + content["Data"] = resultList data["Content"] = content data["Title"] = "Health Check" + execTpl(rw, data, healthCheckTpl, defaultScriptsTpl) } +func execJSON(rw http.ResponseWriter, jsonData []byte) { + rw.Header().Set("Content-Type", "application/json") + rw.Write(jsonData) +} + // TaskStatus is a http.Handler with running task status (task name, status and the last execution). // it's in "/task" pattern in admin module. func taskStatus(rw http.ResponseWriter, req *http.Request) { From db547a7c84aa7957b65b84d33f92d79893d3c7ae Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Mon, 6 Jul 2020 16:04:29 +0100 Subject: [PATCH 003/207] added test for execJson --- admin_test.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/admin_test.go b/admin_test.go index 71cc209e..b7a9af9c 100644 --- a/admin_test.go +++ b/admin_test.go @@ -1,7 +1,9 @@ package beego import ( + "encoding/json" "fmt" + "net/http/httptest" "testing" ) @@ -75,3 +77,27 @@ func oldMap() M { m["BConfig.Log.Outputs"] = BConfig.Log.Outputs return m } + +func TestExecJSON(t *testing.T) { + t.Log("Testing the adding of JSON to the response") + + w := httptest.NewRecorder() + originalBody := []int{1, 2, 3} + + res, _ := json.Marshal(originalBody) + + execJSON(w, res) + + decodedBody := []int{} + err := json.NewDecoder(w.Body).Decode(&decodedBody) + + if err != nil { + t.Fatal("Should be able to decode response body into decodedBody slice") + } + + for i := range decodedBody { + if decodedBody[i] != originalBody[i] { + t.Fatalf("Expected %d but got %d in decoded body slice", originalBody[i], decodedBody[i]) + } + } +} From 8d1a9bc92e758e6e174076fcae16ccebb4bc7fba Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Mon, 6 Jul 2020 19:34:48 +0100 Subject: [PATCH 004/207] added tests for health check endpoints --- admin_test.go | 73 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/admin_test.go b/admin_test.go index b7a9af9c..3875a4bb 100644 --- a/admin_test.go +++ b/admin_test.go @@ -2,11 +2,30 @@ package beego import ( "encoding/json" + "errors" "fmt" + "net/http" "net/http/httptest" + "strings" "testing" + + "github.com/astaxie/beego/toolbox" ) +type SampleDatabaseCheck struct { +} + +type SampleCacheCheck struct { +} + +func (dc *SampleDatabaseCheck) Check() error { + return nil +} + +func (cc *SampleCacheCheck) Check() error { + return errors.New("no cache detected") +} + func TestList_01(t *testing.T) { m := make(M) list("BConfig", BConfig, m) @@ -101,3 +120,57 @@ func TestExecJSON(t *testing.T) { } } } + +func TestHealthCheckHandlerDefault(t *testing.T) { + endpointPath := "/healthcheck" + + toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) + toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) + + req, err := http.NewRequest("GET", endpointPath, nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + handler := http.HandlerFunc(healthcheck) + + handler.ServeHTTP(w, req) + + if status := w.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + if !strings.Contains(w.Body.String(), "database") { + t.Errorf("Expected 'database' in generated template.") + } + +} + +func TestHealthCheckHandlerReturnsJSON(t *testing.T) { + + toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) + toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) + + req, err := http.NewRequest("GET", "/healthcheck?json=true", nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + handler := http.HandlerFunc(healthcheck) + + handler.ServeHTTP(w, req) + if status := w.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + expectedResponseBody := `[{"message":"database","name":"success","status":"OK"},{"message":"cache","name":"error","status":"no cache detected"}]` + if w.Body.String() != expectedResponseBody { + t.Errorf("handler returned unexpected body: got %v want %v", + w.Body.String(), expectedResponseBody) + } +} From fc56c562dbadf1a4a35487bd036759d4acb90462 Mon Sep 17 00:00:00 2001 From: Gabriel Cruz Date: Mon, 6 Jul 2020 20:35:56 +0200 Subject: [PATCH 005/207] Fix logger reconnection --- logs/conn.go | 1 - logs/conn_test.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/logs/conn.go b/logs/conn.go index afe0cbb7..5201f30e 100644 --- a/logs/conn.go +++ b/logs/conn.go @@ -101,7 +101,6 @@ func (c *connWriter) connect() error { func (c *connWriter) needToConnectOnMsg() bool { if c.Reconnect { - c.Reconnect = false return true } diff --git a/logs/conn_test.go b/logs/conn_test.go index 747fb890..bb377d41 100644 --- a/logs/conn_test.go +++ b/logs/conn_test.go @@ -15,11 +15,65 @@ package logs import ( + "net" + "os" "testing" ) +// ConnTCPListener takes a TCP listener and accepts n TCP connections +// Returns connections using connChan +func connTCPListener(t *testing.T, n int, ln net.Listener, connChan chan<- net.Conn) { + + // Listen and accept n incoming connections + for i := 0; i < n; i++ { + conn, err := ln.Accept() + if err != nil { + t.Log("Error accepting connection: ", err.Error()) + os.Exit(1) + } + + // Send accepted connection to channel + connChan <- conn + } + ln.Close() + close(connChan) +} + func TestConn(t *testing.T) { log := NewLogger(1000) log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) log.Informational("informational") } + +func TestReconnect(t *testing.T) { + // Setup connection listener + newConns := make(chan net.Conn) + connNum := 2 + ln, err := net.Listen("tcp", ":6002") + if err != nil { + t.Log("Error listening:", err.Error()) + os.Exit(1) + } + go connTCPListener(t, connNum, ln, newConns) + + // Setup logger + log := NewLogger(1000) + log.SetPrefix("test") + log.SetLogger(AdapterConn, `{"net":"tcp","reconnect":true,"level":6,"addr":":6002"}`) + log.Informational("informational 1") + + // Refuse first connection + first := <-newConns + first.Close() + + // Send another log after conn closed + log.Informational("informational 2") + + // Check if there was a second connection attempt + select { + case second := <-newConns: + second.Close() + default: + t.Error("Did not reconnect") + } +} From d8724cb122327327784f182f23cba47817d5e1b8 Mon Sep 17 00:00:00 2001 From: Gabriel Cruz Date: Mon, 6 Jul 2020 21:34:09 +0200 Subject: [PATCH 006/207] Add error returning to writeln --- logs/conn.go | 5 ++++- logs/logger.go | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/logs/conn.go b/logs/conn.go index 5201f30e..74c458ab 100644 --- a/logs/conn.go +++ b/logs/conn.go @@ -63,7 +63,10 @@ func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { defer c.innerWriter.Close() } - c.lg.writeln(when, msg) + _, err := c.lg.writeln(when, msg) + if err != nil { + return err + } return nil } diff --git a/logs/logger.go b/logs/logger.go index c7cf8a56..a28bff6f 100644 --- a/logs/logger.go +++ b/logs/logger.go @@ -30,11 +30,12 @@ func newLogWriter(wr io.Writer) *logWriter { return &logWriter{writer: wr} } -func (lg *logWriter) writeln(when time.Time, msg string) { +func (lg *logWriter) writeln(when time.Time, msg string) (int, error) { lg.Lock() h, _, _ := formatTimeHeader(when) - lg.writer.Write(append(append(h, msg...), '\n')) + n, err := lg.writer.Write(append(append(h, msg...), '\n')) lg.Unlock() + return n, err } const ( From 5a4a082af07dbecb996627e544e09fc32b097521 Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Tue, 7 Jul 2020 14:54:21 +0100 Subject: [PATCH 007/207] renamed functions for clarity --- admin.go | 22 +++++++++++----------- admin_test.go | 4 ++-- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/admin.go b/admin.go index c5ae686d..cc9043e7 100644 --- a/admin.go +++ b/admin.go @@ -71,7 +71,7 @@ func init() { // AdminIndex is the default http.Handler for admin module. // it matches url pattern "/". func adminIndex(rw http.ResponseWriter, _ *http.Request) { - execTpl(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) + writeTemplate(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) } // QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter. @@ -91,7 +91,7 @@ func qpsIndex(rw http.ResponseWriter, _ *http.Request) { } } - execTpl(rw, data, qpsTpl, defaultScriptsTpl) + writeTemplate(rw, data, qpsTpl, defaultScriptsTpl) } // ListConf is the http.Handler of displaying all beego configuration values as key/value pair. @@ -128,7 +128,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) { } data["Content"] = content data["Title"] = "Routers" - execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl) + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) case "filter": var ( content = M{ @@ -171,7 +171,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) { data["Content"] = content data["Title"] = "Filters" - execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl) + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) default: rw.Write([]byte("command not support")) } @@ -279,7 +279,7 @@ func profIndex(rw http.ResponseWriter, r *http.Request) { http.Error(rw, err.Error(), http.StatusInternalServerError) return } - execJSON(rw, dataJSON) + writeJSON(rw, dataJSON) return } @@ -288,7 +288,7 @@ func profIndex(rw http.ResponseWriter, r *http.Request) { if command == "gc summary" { defaultTpl = gcAjaxTpl } - execTpl(rw, data, profillingTpl, defaultTpl) + writeTemplate(rw, data, profillingTpl, defaultTpl) } // Healthcheck is a http.Handler calling health checking and showing the result. @@ -340,7 +340,7 @@ func healthcheck(rw http.ResponseWriter, r *http.Request) { if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) } else { - execJSON(rw, JSONResponse) + writeJSON(rw, JSONResponse) } return @@ -350,10 +350,10 @@ func healthcheck(rw http.ResponseWriter, r *http.Request) { data["Content"] = content data["Title"] = "Health Check" - execTpl(rw, data, healthCheckTpl, defaultScriptsTpl) + writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) } -func execJSON(rw http.ResponseWriter, jsonData []byte) { +func writeJSON(rw http.ResponseWriter, jsonData []byte) { rw.Header().Set("Content-Type", "application/json") rw.Write(jsonData) } @@ -401,10 +401,10 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { content["Data"] = resultList data["Content"] = content data["Title"] = "Tasks" - execTpl(rw, data, tasksTpl, defaultScriptsTpl) + writeTemplate(rw, data, tasksTpl, defaultScriptsTpl) } -func execTpl(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { +func writeTemplate(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) for _, tpl := range tpls { tmpl = template.Must(tmpl.Parse(tpl)) diff --git a/admin_test.go b/admin_test.go index 3875a4bb..1bf8700a 100644 --- a/admin_test.go +++ b/admin_test.go @@ -97,7 +97,7 @@ func oldMap() M { return m } -func TestExecJSON(t *testing.T) { +func TestWriteJSON(t *testing.T) { t.Log("Testing the adding of JSON to the response") w := httptest.NewRecorder() @@ -105,7 +105,7 @@ func TestExecJSON(t *testing.T) { res, _ := json.Marshal(originalBody) - execJSON(w, res) + writeJSON(w, res) decodedBody := []int{} err := json.NewDecoder(w.Body).Decode(&decodedBody) From ca0c64b69e3357ff8ce94a0e52f445bdf471d4b5 Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Tue, 7 Jul 2020 15:21:38 +0100 Subject: [PATCH 008/207] refactored tests for health check endpoint --- admin_test.go | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/admin_test.go b/admin_test.go index 1bf8700a..4a614721 100644 --- a/admin_test.go +++ b/admin_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "reflect" "strings" "testing" @@ -111,7 +112,7 @@ func TestWriteJSON(t *testing.T) { err := json.NewDecoder(w.Body).Decode(&decodedBody) if err != nil { - t.Fatal("Should be able to decode response body into decodedBody slice") + t.Fatal("Could not decode response body into slice.") } for i := range decodedBody { @@ -168,9 +169,31 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { status, http.StatusOK) } - expectedResponseBody := `[{"message":"database","name":"success","status":"OK"},{"message":"cache","name":"error","status":"no cache detected"}]` - if w.Body.String() != expectedResponseBody { + decodedResponseBody := []map[string]interface{}{} + expectedResponseBody := []map[string]interface{}{} + + expectedJSONString := []byte(` + [ + { + "message":"database", + "name":"success", + "status":"OK" + }, + { + "message":"cache", + "name":"error", + "status":"no cache detected" + } + ] + `) + + json.Unmarshal(expectedJSONString, &expectedResponseBody) + + json.Unmarshal(w.Body.Bytes(), &decodedResponseBody) + + if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { t.Errorf("handler returned unexpected body: got %v want %v", - w.Body.String(), expectedResponseBody) + decodedResponseBody, expectedResponseBody) } + } From 469dc7bea9b08add170177cc0a8615dbd0efa944 Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Tue, 7 Jul 2020 16:09:22 +0100 Subject: [PATCH 009/207] refactored the building of healthcheck response map --- admin.go | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/admin.go b/admin.go index cc9043e7..2ee415d4 100644 --- a/admin.go +++ b/admin.go @@ -21,6 +21,7 @@ import ( "net/http" "os" "reflect" + "strconv" "text/template" "time" @@ -321,28 +322,18 @@ func healthcheck(rw http.ResponseWriter, r *http.Request) { } queryParams := r.URL.Query() + jsonFlag := queryParams.Get("json") + shouldReturnJSON, _ := strconv.ParseBool(jsonFlag) - if queryParams["json"] != nil { + if shouldReturnJSON { + responseMap := buildHealthCheckResponseMap(resultList) + jsonResponse, err := json.Marshal(responseMap) - type Result map[string]interface{} - - response := make([]Result, len(*resultList)) - - for i, currentResult := range *resultList { - currentResultMap := make(Result) - currentResultMap["name"] = currentResult[0] - currentResultMap["message"] = currentResult[1] - currentResultMap["status"] = currentResult[2] - response[i] = currentResultMap - } - - JSONResponse, err := json.Marshal(response) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) } else { - writeJSON(rw, JSONResponse) + writeJSON(rw, jsonResponse) } - return } @@ -353,6 +344,23 @@ func healthcheck(rw http.ResponseWriter, r *http.Request) { writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) } +func buildHealthCheckResponseMap(resultList *[][]string) []map[string]interface{} { + response := make([]map[string]interface{}, len(*resultList)) + + for i, currentResult := range *resultList { + currentResultMap := make(map[string]interface{}) + + currentResultMap["name"] = currentResult[0] + currentResultMap["message"] = currentResult[1] + currentResultMap["status"] = currentResult[2] + + response[i] = currentResultMap + } + + return response + +} + func writeJSON(rw http.ResponseWriter, jsonData []byte) { rw.Header().Set("Content-Type", "application/json") rw.Write(jsonData) From e0f8c6832d5477e7b239b60ce25db4b42f01ed49 Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Tue, 7 Jul 2020 16:28:16 +0100 Subject: [PATCH 010/207] added test for buildingHealthCheckResponse --- admin.go | 16 ++++++++-------- admin_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/admin.go b/admin.go index 2ee415d4..db52647e 100644 --- a/admin.go +++ b/admin.go @@ -326,8 +326,8 @@ func healthcheck(rw http.ResponseWriter, r *http.Request) { shouldReturnJSON, _ := strconv.ParseBool(jsonFlag) if shouldReturnJSON { - responseMap := buildHealthCheckResponseMap(resultList) - jsonResponse, err := json.Marshal(responseMap) + response := buildHealthCheckResponseList(resultList) + jsonResponse, err := json.Marshal(response) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) @@ -344,15 +344,15 @@ func healthcheck(rw http.ResponseWriter, r *http.Request) { writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) } -func buildHealthCheckResponseMap(resultList *[][]string) []map[string]interface{} { - response := make([]map[string]interface{}, len(*resultList)) +func buildHealthCheckResponseList(healthCheckResults *[][]string) []map[string]interface{} { + response := make([]map[string]interface{}, len(*healthCheckResults)) - for i, currentResult := range *resultList { + for i, healthCheckResult := range *healthCheckResults { currentResultMap := make(map[string]interface{}) - currentResultMap["name"] = currentResult[0] - currentResultMap["message"] = currentResult[1] - currentResultMap["status"] = currentResult[2] + currentResultMap["name"] = healthCheckResult[0] + currentResultMap["message"] = healthCheckResult[1] + currentResultMap["status"] = healthCheckResult[2] response[i] = currentResultMap } diff --git a/admin_test.go b/admin_test.go index 4a614721..24da3a2b 100644 --- a/admin_test.go +++ b/admin_test.go @@ -149,6 +149,41 @@ func TestHealthCheckHandlerDefault(t *testing.T) { } +func TestBuildHealthCheckResponseList(t *testing.T) { + healthCheckResults := [][]string{ + []string{ + "error", + "Database", + "Error occured whie starting the db", + }, + []string{ + "success", + "Cache", + "Cache started successfully", + }, + } + + responseList := buildHealthCheckResponseList(&healthCheckResults) + + if len(responseList) != len(healthCheckResults) { + t.Errorf("invalid response map length: got %d want %d", + len(responseList), len(healthCheckResults)) + } + + responseFields := []string{"name", "message", "status"} + + for _, response := range responseList { + for _, field := range responseFields { + _, ok := response[field] + if !ok { + t.Errorf("expected %s to be in the response %v", field, response) + } + } + + } + +} + func TestHealthCheckHandlerReturnsJSON(t *testing.T) { toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) From 728bf340064803d5ace628c07f715a2532c1311c Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Tue, 7 Jul 2020 16:46:59 +0100 Subject: [PATCH 011/207] refacted cache health check from toolbox --- admin_test.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/admin_test.go b/admin_test.go index 24da3a2b..d9a66b34 100644 --- a/admin_test.go +++ b/admin_test.go @@ -187,7 +187,6 @@ func TestBuildHealthCheckResponseList(t *testing.T) { func TestHealthCheckHandlerReturnsJSON(t *testing.T) { toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) - toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) req, err := http.NewRequest("GET", "/healthcheck?json=true", nil) if err != nil { @@ -213,11 +212,6 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { "message":"database", "name":"success", "status":"OK" - }, - { - "message":"cache", - "name":"error", - "status":"no cache detected" } ] `) From d7b0d55357dce068c8b216fc49d024071f865f42 Mon Sep 17 00:00:00 2001 From: Eyitayo Ogunbiyi Date: Tue, 7 Jul 2020 17:23:52 +0100 Subject: [PATCH 012/207] added extra check for same response lengths --- admin_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/admin_test.go b/admin_test.go index d9a66b34..3f3612e4 100644 --- a/admin_test.go +++ b/admin_test.go @@ -187,6 +187,7 @@ func TestBuildHealthCheckResponseList(t *testing.T) { func TestHealthCheckHandlerReturnsJSON(t *testing.T) { toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) + toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) req, err := http.NewRequest("GET", "/healthcheck?json=true", nil) if err != nil { @@ -212,6 +213,11 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { "message":"database", "name":"success", "status":"OK" + }, + { + "message":"cache", + "name":"error", + "status":"no cache detected" } ] `) @@ -220,6 +226,11 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { json.Unmarshal(w.Body.Bytes(), &decodedResponseBody) + if len(expectedResponseBody) != len(decodedResponseBody) { + t.Errorf("invalid response map length: got %d want %d", + len(decodedResponseBody), len(expectedResponseBody)) + } + if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { t.Errorf("handler returned unexpected body: got %v want %v", decodedResponseBody, expectedResponseBody) From 946a42c021f3688ca411315b71fe989152df7f27 Mon Sep 17 00:00:00 2001 From: Chenrui <631807682@qq.com> Date: Wed, 8 Jul 2020 17:14:52 +0800 Subject: [PATCH 013/207] fix: response http 413 when body size larger then MaxMemory. --- router.go | 4 ++++ router_test.go | 29 +++++++++++++++++++++++++---- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/router.go b/router.go index e71366b4..9b257391 100644 --- a/router.go +++ b/router.go @@ -707,6 +707,10 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if r.Method != http.MethodGet && r.Method != http.MethodHead { if BConfig.CopyRequestBody && !context.Input.IsUpload() { + if r.ContentLength > BConfig.MaxMemory { + exception("413", context) + goto Admin + } context.Input.CopyBody(BConfig.MaxMemory) } context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) diff --git a/router_test.go b/router_test.go index 2797b33a..e49f38db 100644 --- a/router_test.go +++ b/router_test.go @@ -15,6 +15,7 @@ package beego import ( + "bytes" "net/http" "net/http/httptest" "strings" @@ -71,7 +72,6 @@ func (tc *TestController) GetEmptyBody() { tc.Ctx.Output.Body(res) } - type JSONController struct { Controller } @@ -656,17 +656,14 @@ func beegoBeforeRouter1(ctx *context.Context) { ctx.WriteString("|BeforeRouter1") } - func beegoBeforeExec1(ctx *context.Context) { ctx.WriteString("|BeforeExec1") } - func beegoAfterExec1(ctx *context.Context) { ctx.WriteString("|AfterExec1") } - func beegoFinishRouter1(ctx *context.Context) { ctx.WriteString("|FinishRouter1") } @@ -709,3 +706,27 @@ func TestYAMLPrepare(t *testing.T) { t.Errorf(w.Body.String()) } } + +func TestRouterEntityTooLargeCopyBody(t *testing.T) { + _MaxMemory := BConfig.MaxMemory + _CopyRequestBody := BConfig.CopyRequestBody + BConfig.CopyRequestBody = true + BConfig.MaxMemory = 20 + + b := bytes.NewBuffer([]byte("barbarbarbarbarbarbarbarbarbar")) + r, _ := http.NewRequest("POST", "/user/123", b) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Post("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + handler.ServeHTTP(w, r) + + BConfig.CopyRequestBody = _CopyRequestBody + BConfig.MaxMemory = _MaxMemory + + if w.Code != 413 { + t.Errorf("TestRouterRequestEntityTooLarge can't run") + } +} From 03f78b2e4a5353173d96cbfbe19a3195cdc032f4 Mon Sep 17 00:00:00 2001 From: Chenrui <631807682@qq.com> Date: Wed, 8 Jul 2020 18:09:01 +0800 Subject: [PATCH 014/207] fix: add error code support --- error.go | 15 ++++++++++++++- hooks.go | 1 + 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/error.go b/error.go index e5e9fd47..0b148974 100644 --- a/error.go +++ b/error.go @@ -28,7 +28,7 @@ import ( ) const ( - errorTypeHandler = iota + errorTypeHandler = iota errorTypeController ) @@ -359,6 +359,19 @@ func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { ) } +// show 413 Payload Too Large +func payloadTooLarge(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 413, + "
The page you have requested is unavailable."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The request entity is larger than limits defined by server"+ + "
    Please change the request entity and try again."+ + "
", + ) +} + func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := M{ diff --git a/hooks.go b/hooks.go index b8671d35..49c42d5a 100644 --- a/hooks.go +++ b/hooks.go @@ -34,6 +34,7 @@ func registerDefaultErrorHandler() error { "504": gatewayTimeout, "417": invalidxsrf, "422": missingxsrf, + "413": payloadTooLarge, } for e, h := range m { if _, ok := ErrorMaps[e]; !ok { From c08b27111ca04110da9e92dd5541c25bf712cc47 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 8 Jul 2020 22:50:03 +0800 Subject: [PATCH 015/207] Fix 4059 --- orm/db_alias.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/orm/db_alias.go b/orm/db_alias.go index cf6a5935..bf6c350c 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -425,7 +425,6 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { type stmtDecorator struct { wg sync.WaitGroup - lastUse int64 stmt *sql.Stmt } @@ -433,9 +432,12 @@ func (s *stmtDecorator) getStmt() *sql.Stmt { return s.stmt } +// acquire will add one +// since this method will be used inside read lock scope, +// so we can not do more things here +// we should think about refactor this func (s *stmtDecorator) acquire() { s.wg.Add(1) - s.lastUse = time.Now().Unix() } func (s *stmtDecorator) release() { @@ -453,7 +455,6 @@ func (s *stmtDecorator) destroy() { func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { return &stmtDecorator{ stmt: sqlStmt, - lastUse: time.Now().Unix(), } } From 2eccb23461ccaa545564c865de6191ac2c9dc9e6 Mon Sep 17 00:00:00 2001 From: Cathal Date: Thu, 2 Jul 2020 20:48:43 +0100 Subject: [PATCH 016/207] Add sleep on reconnect functionality --- httplib/httplib.go | 8 ++++++++ httplib/httplib_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/httplib/httplib.go b/httplib/httplib.go index e094a6a6..60aa4e8b 100644 --- a/httplib/httplib.go +++ b/httplib/httplib.go @@ -144,6 +144,7 @@ type BeegoHTTPSettings struct { Gzip bool DumpBody bool Retries int // if set to -1 means will retry forever + RetryDelay time.Duration } // BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. @@ -202,6 +203,11 @@ func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { return b } +func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { + b.setting.RetryDelay = delay + return b +} + // DumpBody setting whether need to Dump the Body. func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { b.setting.DumpBody = isdump @@ -512,11 +518,13 @@ func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { // retries default value is 0, it will run once. // retries equal to -1, it will run forever until success // retries is setted, it will retries fixed times. + // Sleeps for a 400ms inbetween calls to reduce spam for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ { resp, err = client.Do(b.req) if err == nil { break } + time.Sleep(b.setting.RetryDelay) } return resp, err } diff --git a/httplib/httplib_test.go b/httplib/httplib_test.go index dd2a4f1c..f6be8571 100644 --- a/httplib/httplib_test.go +++ b/httplib/httplib_test.go @@ -15,6 +15,7 @@ package httplib import ( + "errors" "io/ioutil" "net" "net/http" @@ -33,6 +34,34 @@ func TestResponse(t *testing.T) { t.Log(resp) } +func TestDoRequest(t *testing.T) { + req := Get("https://goolnk.com/33BD2j") + retryAmount := 1 + req.Retries(1) + req.RetryDelay(1400 * time.Millisecond) + retryDelay := 1400 * time.Millisecond + + req.setting.CheckRedirect = func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + } + + startTime := time.Now().UnixNano() / int64(time.Millisecond) + + _, err := req.Response() + if err == nil { + t.Fatal("Response should have yielded an error") + } + + endTime := time.Now().UnixNano() / int64(time.Millisecond) + elapsedTime := endTime - startTime + delayedTime := int64(retryAmount) * retryDelay.Milliseconds() + + if elapsedTime < delayedTime { + t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) + } + +} + func TestGet(t *testing.T) { req := Get("http://httpbin.org/get") b, err := req.Bytes() From c3f14a0ad6ee54f0e474a220a8eb81a67a6b6335 Mon Sep 17 00:00:00 2001 From: Chenrui <631807682@qq.com> Date: Thu, 9 Jul 2020 09:45:40 +0800 Subject: [PATCH 017/207] refactor: log error when payload too large --- error.go | 13 +++++++------ router.go | 2 ++ router_test.go | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/error.go b/error.go index 0b148974..f268f723 100644 --- a/error.go +++ b/error.go @@ -363,12 +363,13 @@ func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { func payloadTooLarge(rw http.ResponseWriter, r *http.Request) { responseError(rw, r, 413, - "
The page you have requested is unavailable."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    The request entity is larger than limits defined by server"+ - "
    Please change the request entity and try again."+ - "
", + `
The page you have requested is unavailable. +
Perhaps you are here because:

+
    +
    The request entity is larger than limits defined by server. +
    Please change the request entity and try again. +
+ `, ) } diff --git a/router.go b/router.go index 9b257391..d910deb0 100644 --- a/router.go +++ b/router.go @@ -707,7 +707,9 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if r.Method != http.MethodGet && r.Method != http.MethodHead { if BConfig.CopyRequestBody && !context.Input.IsUpload() { + // connection will close if the incoming data are larger (RFC 7231, 6.5.11) if r.ContentLength > BConfig.MaxMemory { + logs.Error(errors.New("payload too large")) exception("413", context) goto Admin } diff --git a/router_test.go b/router_test.go index e49f38db..8ec7927a 100644 --- a/router_test.go +++ b/router_test.go @@ -726,7 +726,7 @@ func TestRouterEntityTooLargeCopyBody(t *testing.T) { BConfig.CopyRequestBody = _CopyRequestBody BConfig.MaxMemory = _MaxMemory - if w.Code != 413 { + if w.Code != http.StatusRequestEntityTooLarge { t.Errorf("TestRouterRequestEntityTooLarge can't run") } } From 76debb1899c10ac69356497d543f854c6a5dfa9f Mon Sep 17 00:00:00 2001 From: Acmefocus <37472851+Acmefocus@users.noreply.github.com> Date: Thu, 9 Jul 2020 17:18:01 +0800 Subject: [PATCH 018/207] Update README.md --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index 3b414c6f..aacd237e 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,15 @@ It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific feature go get github.com/astaxie/beego +#### Create hello directory + + sudo mkdir hello + cd hello + +#### Init module + + go mod init + #### Create file `hello.go` ```go package main From 40cdc877b618279ba598fb550a64c2b0f0325f53 Mon Sep 17 00:00:00 2001 From: Acmefocus <37472851+Acmefocus@users.noreply.github.com> Date: Thu, 9 Jul 2020 17:18:01 +0800 Subject: [PATCH 019/207] Update README.md Signed-off-by: Acmefocus <107723772@qq.com> --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index 3b414c6f..aacd237e 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,15 @@ It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific feature go get github.com/astaxie/beego +#### Create hello directory + + sudo mkdir hello + cd hello + +#### Init module + + go mod init + #### Create file `hello.go` ```go package main From 25ba78ea7247a6d05356a3012f7f0df7ddc92633 Mon Sep 17 00:00:00 2001 From: Acmefocus <107723772@qq.com> Date: Thu, 9 Jul 2020 18:14:31 +0800 Subject: [PATCH 020/207] update README.md Signed-off-by: Acmefocus <107723772@qq.com> --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index aacd237e..a9acfcc1 100644 --- a/README.md +++ b/README.md @@ -12,11 +12,11 @@ It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific feature go get github.com/astaxie/beego -#### Create hello directory +#### Create `hello` directory, cd `hello` directory - sudo mkdir hello + mkdir hello cd hello - + #### Init module go mod init From 2b9aaa5b0d45a6a0a1009eea5575c3d62e8f4210 Mon Sep 17 00:00:00 2001 From: Acmefocus <107723772@qq.com> Date: Fri, 10 Jul 2020 09:58:06 +0800 Subject: [PATCH 021/207] update README.md Signed-off-by: Acmefocus <107723772@qq.com> --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index a9acfcc1..de8f0063 100644 --- a/README.md +++ b/README.md @@ -8,18 +8,18 @@ It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific feature ## Quick Start -#### Download and install - - go get github.com/astaxie/beego - #### Create `hello` directory, cd `hello` directory - mkdir hello - cd hello + mkdir hello + cd hello #### Init module - go mod init + go mod init + +#### Download and install + + go get github.com/astaxie/beego #### Create file `hello.go` ```go From 5940ae33c2174806189419ec35b90a63454f649e Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 13 Jul 2020 19:14:53 +0800 Subject: [PATCH 022/207] fix `index out of range` when sid len = 1 add unit test for sess_file.go --- session/sess_file.go | 6 +- session/sess_file_test.go | 387 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 392 insertions(+), 1 deletion(-) create mode 100644 session/sess_file_test.go diff --git a/session/sess_file.go b/session/sess_file.go index c6dbf209..f7a739cc 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -15,11 +15,11 @@ package session import ( + "errors" "fmt" "io/ioutil" "net/http" "os" - "errors" "path" "path/filepath" "strings" @@ -180,6 +180,10 @@ func (fp *FileProvider) SessionExist(sid string) bool { filepder.lock.Lock() defer filepder.lock.Unlock() + if len(sid) < 2 { + return false + } + _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) return err == nil } diff --git a/session/sess_file_test.go b/session/sess_file_test.go new file mode 100644 index 00000000..0cf021db --- /dev/null +++ b/session/sess_file_test.go @@ -0,0 +1,387 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "fmt" + "os" + "sync" + "testing" + "time" +) + +const sid = "Session_id" +const sidNew = "Session_id_new" +const sessionPath = "./_session_runtime" + +var ( + mutex sync.Mutex +) + +func TestFileProvider_SessionInit(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + if fp.maxlifetime != 180 { + t.Error() + } + + if fp.savePath != sessionPath { + t.Error() + } +} + +func TestFileProvider_SessionExist(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionExist2(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + if fp.SessionExist("") { + t.Error() + } + + if fp.SessionExist("1") { + t.Error() + } +} + +func TestFileProvider_SessionRead(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + _ = s.Set("sessionValue", 18975) + v := s.Get("sessionValue") + + if v.(int) != 18975 { + t.Error() + } +} + +func TestFileProvider_SessionRead1(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead("") + if err == nil { + t.Error(err) + } + + _, err = fp.SessionRead("1") + if err == nil { + t.Error(err) + } +} + +func TestFileProvider_SessionAll(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 546 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + if fp.SessionAll() != sessionCount { + t.Error() + } +} + +func TestFileProvider_SessionRegenerate(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + _, err = fp.SessionRegenerate(sid, sidNew) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } + + if !fp.SessionExist(sidNew) { + t.Error() + } +} + +func TestFileProvider_SessionDestroy(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + err = fp.SessionDestroy(sid) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionGC(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(1, sessionPath) + + sessionCount := 412 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + time.Sleep(2 * time.Second) + + fp.SessionGC() + if fp.SessionAll() != 0 { + t.Error() + } +} + +func TestFileSessionStore_Set(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + err := s.Set(i, i) + if err != nil { + t.Error(err) + } + } +} + +func TestFileSessionStore_Get(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + + v := s.Get(i) + if v.(int) != i { + t.Error() + } + } +} + +func TestFileSessionStore_Delete(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, _ := fp.SessionRead(sid) + s.Set("1", 1) + + if s.Get("1") == nil { + t.Error() + } + + s.Delete("1") + + if s.Get("1") != nil { + t.Error() + } +} + +func TestFileSessionStore_Flush(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + } + + _ = s.Flush() + + for i := 1; i <= sessionCount; i++ { + if s.Get(i) != nil { + t.Error() + } + } +} + +func TestFileSessionStore_SessionID(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) { + t.Error(err) + } + } +} + +func TestFileSessionStore_SessionRelease(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + filepder.savePath = sessionPath + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + + + s.Set(i,i) + s.SessionRelease(nil) + } + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + + if s.Get(i).(int) != i { + t.Error() + } + } +} \ No newline at end of file From 678b90385b36334cb77cffe50e965b0ff13e7ee7 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 14 Jul 2020 09:57:13 +0800 Subject: [PATCH 023/207] add log --- session/sess_file.go | 1 + 1 file changed, 1 insertion(+) diff --git a/session/sess_file.go b/session/sess_file.go index f7a739cc..47ad54a7 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -181,6 +181,7 @@ func (fp *FileProvider) SessionExist(sid string) bool { defer filepder.lock.Unlock() if len(sid) < 2 { + SLogger.Println("min length of session id is 2", sid) return false } From ffe1d5212009bae069acfe3a2d02954c7acb1756 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 15 Jul 2020 10:04:22 +0800 Subject: [PATCH 024/207] Move orm to pkg/orm --- pkg/orm/README.md | 159 +++ pkg/orm/cmd.go | 283 +++++ pkg/orm/cmd_utils.go | 320 +++++ pkg/orm/db.go | 1902 +++++++++++++++++++++++++++++ pkg/orm/db_alias.go | 466 +++++++ pkg/orm/db_mysql.go | 183 +++ pkg/orm/db_oracle.go | 137 +++ pkg/orm/db_postgres.go | 189 +++ pkg/orm/db_sqlite.go | 161 +++ pkg/orm/db_tables.go | 482 ++++++++ pkg/orm/db_tidb.go | 63 + pkg/orm/db_utils.go | 177 +++ pkg/orm/models.go | 99 ++ pkg/orm/models_boot.go | 347 ++++++ pkg/orm/models_fields.go | 783 ++++++++++++ pkg/orm/models_info_f.go | 473 ++++++++ pkg/orm/models_info_m.go | 148 +++ pkg/orm/models_test.go | 497 ++++++++ pkg/orm/models_utils.go | 227 ++++ pkg/orm/orm.go | 579 +++++++++ pkg/orm/orm_conds.go | 153 +++ pkg/orm/orm_log.go | 222 ++++ pkg/orm/orm_object.go | 87 ++ pkg/orm/orm_querym2m.go | 140 +++ pkg/orm/orm_queryset.go | 300 +++++ pkg/orm/orm_raw.go | 867 +++++++++++++ pkg/orm/orm_test.go | 2494 ++++++++++++++++++++++++++++++++++++++ pkg/orm/qb.go | 62 + pkg/orm/qb_mysql.go | 185 +++ pkg/orm/qb_tidb.go | 182 +++ pkg/orm/types.go | 473 ++++++++ pkg/orm/utils.go | 319 +++++ pkg/orm/utils_test.go | 70 ++ 33 files changed, 13229 insertions(+) create mode 100644 pkg/orm/README.md create mode 100644 pkg/orm/cmd.go create mode 100644 pkg/orm/cmd_utils.go create mode 100644 pkg/orm/db.go create mode 100644 pkg/orm/db_alias.go create mode 100644 pkg/orm/db_mysql.go create mode 100644 pkg/orm/db_oracle.go create mode 100644 pkg/orm/db_postgres.go create mode 100644 pkg/orm/db_sqlite.go create mode 100644 pkg/orm/db_tables.go create mode 100644 pkg/orm/db_tidb.go create mode 100644 pkg/orm/db_utils.go create mode 100644 pkg/orm/models.go create mode 100644 pkg/orm/models_boot.go create mode 100644 pkg/orm/models_fields.go create mode 100644 pkg/orm/models_info_f.go create mode 100644 pkg/orm/models_info_m.go create mode 100644 pkg/orm/models_test.go create mode 100644 pkg/orm/models_utils.go create mode 100644 pkg/orm/orm.go create mode 100644 pkg/orm/orm_conds.go create mode 100644 pkg/orm/orm_log.go create mode 100644 pkg/orm/orm_object.go create mode 100644 pkg/orm/orm_querym2m.go create mode 100644 pkg/orm/orm_queryset.go create mode 100644 pkg/orm/orm_raw.go create mode 100644 pkg/orm/orm_test.go create mode 100644 pkg/orm/qb.go create mode 100644 pkg/orm/qb_mysql.go create mode 100644 pkg/orm/qb_tidb.go create mode 100644 pkg/orm/types.go create mode 100644 pkg/orm/utils.go create mode 100644 pkg/orm/utils_test.go diff --git a/pkg/orm/README.md b/pkg/orm/README.md new file mode 100644 index 00000000..6e808d2a --- /dev/null +++ b/pkg/orm/README.md @@ -0,0 +1,159 @@ +# beego orm + +[![Build Status](https://drone.io/github.com/astaxie/beego/status.png)](https://drone.io/github.com/astaxie/beego/latest) + +A powerful orm framework for go. + +It is heavily influenced by Django ORM, SQLAlchemy. + +**Support Database:** + +* MySQL: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) +* PostgreSQL: [github.com/lib/pq](https://github.com/lib/pq) +* Sqlite3: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) + +Passed all test, but need more feedback. + +**Features:** + +* full go type support +* easy for usage, simple CRUD operation +* auto join with relation table +* cross DataBase compatible query +* Raw SQL query / mapper without orm model +* full test keep stable and strong + +more features please read the docs + +**Install:** + + go get github.com/astaxie/beego/orm + +## Changelog + +* 2013-08-19: support table auto create +* 2013-08-13: update test for database types +* 2013-08-13: go type support, such as int8, uint8, byte, rune +* 2013-08-13: date / datetime timezone support very well + +## Quick Start + +#### Simple Usage + +```go +package main + +import ( + "fmt" + "github.com/astaxie/beego/orm" + _ "github.com/go-sql-driver/mysql" // import your used driver +) + +// Model Struct +type User struct { + Id int `orm:"auto"` + Name string `orm:"size(100)"` +} + +func init() { + // register model + orm.RegisterModel(new(User)) + + // set default database + orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) + + // create table + orm.RunSyncdb("default", false, true) +} + +func main() { + o := orm.NewOrm() + + user := User{Name: "slene"} + + // insert + id, err := o.Insert(&user) + + // update + user.Name = "astaxie" + num, err := o.Update(&user) + + // read one + u := User{Id: user.Id} + err = o.Read(&u) + + // delete + num, err = o.Delete(&u) +} +``` + +#### Next with relation + +```go +type Post struct { + Id int `orm:"auto"` + Title string `orm:"size(100)"` + User *User `orm:"rel(fk)"` +} + +var posts []*Post +qs := o.QueryTable("post") +num, err := qs.Filter("User__Name", "slene").All(&posts) +``` + +#### Use Raw sql + +If you don't like ORM,use Raw SQL to query / mapping without ORM setting + +```go +var maps []Params +num, err := o.Raw("SELECT id FROM user WHERE name = ?", "slene").Values(&maps) +if num > 0 { + fmt.Println(maps[0]["id"]) +} +``` + +#### Transaction + +```go +o.Begin() +... +user := User{Name: "slene"} +id, err := o.Insert(&user) +if err == nil { + o.Commit() +} else { + o.Rollback() +} + +``` + +#### Debug Log Queries + +In development env, you can simple use + +```go +func main() { + orm.Debug = true +... +``` + +enable log queries. + +output include all queries, such as exec / prepare / transaction. + +like this: + +```go +[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Exec / 0.4ms] - [INSERT INTO `user` (`name`) VALUES (?)] - `slene` +... +``` + +note: not recommend use this in product env. + +## Docs + +more details and examples in docs and test + +[documents](http://beego.me/docs/mvc/model/overview.md) + diff --git a/pkg/orm/cmd.go b/pkg/orm/cmd.go new file mode 100644 index 00000000..0ff4dc40 --- /dev/null +++ b/pkg/orm/cmd.go @@ -0,0 +1,283 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "flag" + "fmt" + "os" + "strings" +) + +type commander interface { + Parse([]string) + Run() error +} + +var ( + commands = make(map[string]commander) +) + +// print help. +func printHelp(errs ...string) { + content := `orm command usage: + + syncdb - auto create tables + sqlall - print sql of create tables + help - print this help +` + + if len(errs) > 0 { + fmt.Println(errs[0]) + } + fmt.Println(content) + os.Exit(2) +} + +// RunCommand listen for orm command and then run it if command arguments passed. +func RunCommand() { + if len(os.Args) < 2 || os.Args[1] != "orm" { + return + } + + BootStrap() + + args := argString(os.Args[2:]) + name := args.Get(0) + + if name == "help" { + printHelp() + } + + if cmd, ok := commands[name]; ok { + cmd.Parse(os.Args[3:]) + cmd.Run() + os.Exit(0) + } else { + if name == "" { + printHelp() + } else { + printHelp(fmt.Sprintf("unknown command %s", name)) + } + } +} + +// sync database struct command interface. +type commandSyncDb struct { + al *alias + force bool + verbose bool + noInfo bool + rtOnError bool +} + +// parse orm command line arguments. +func (d *commandSyncDb) Parse(args []string) { + var name string + + flagSet := flag.NewFlagSet("orm command: syncdb", flag.ExitOnError) + flagSet.StringVar(&name, "db", "default", "DataBase alias name") + flagSet.BoolVar(&d.force, "force", false, "drop tables before create") + flagSet.BoolVar(&d.verbose, "v", false, "verbose info") + flagSet.Parse(args) + + d.al = getDbAlias(name) +} + +// run orm line command. +func (d *commandSyncDb) Run() error { + var drops []string + if d.force { + drops = getDbDropSQL(d.al) + } + + db := d.al.DB + + if d.force { + for i, mi := range modelCache.allOrdered() { + query := drops[i] + if !d.noInfo { + fmt.Printf("drop table `%s`\n", mi.table) + } + _, err := db.Exec(query) + if d.verbose { + fmt.Printf(" %s\n\n", query) + } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + } + } + + sqls, indexes := getDbCreateSQL(d.al) + + tables, err := d.al.DbBaser.GetTables(db) + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + + for i, mi := range modelCache.allOrdered() { + if tables[mi.table] { + if !d.noInfo { + fmt.Printf("table `%s` already exists, skip\n", mi.table) + } + + var fields []*fieldInfo + columns, err := d.al.DbBaser.GetColumns(db, mi.table) + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + + for _, fi := range mi.fields.fieldsDB { + if _, ok := columns[fi.column]; !ok { + fields = append(fields, fi) + } + } + + for _, fi := range fields { + query := getColumnAddQuery(d.al, fi) + + if !d.noInfo { + fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table) + } + + _, err := db.Exec(query) + if d.verbose { + fmt.Printf(" %s\n", query) + } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + } + + for _, idx := range indexes[mi.table] { + if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { + if !d.noInfo { + fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) + } + + query := idx.SQL + _, err := db.Exec(query) + if d.verbose { + fmt.Printf(" %s\n", query) + } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + } + } + + continue + } + + if !d.noInfo { + fmt.Printf("create table `%s` \n", mi.table) + } + + queries := []string{sqls[i]} + for _, idx := range indexes[mi.table] { + queries = append(queries, idx.SQL) + } + + for _, query := range queries { + _, err := db.Exec(query) + if d.verbose { + query = " " + strings.Join(strings.Split(query, "\n"), "\n ") + fmt.Println(query) + } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + } + if d.verbose { + fmt.Println("") + } + } + + return nil +} + +// database creation commander interface implement. +type commandSQLAll struct { + al *alias +} + +// parse orm command line arguments. +func (d *commandSQLAll) Parse(args []string) { + var name string + + flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError) + flagSet.StringVar(&name, "db", "default", "DataBase alias name") + flagSet.Parse(args) + + d.al = getDbAlias(name) +} + +// run orm line command. +func (d *commandSQLAll) Run() error { + sqls, indexes := getDbCreateSQL(d.al) + var all []string + for i, mi := range modelCache.allOrdered() { + queries := []string{sqls[i]} + for _, idx := range indexes[mi.table] { + queries = append(queries, idx.SQL) + } + sql := strings.Join(queries, "\n") + all = append(all, sql) + } + fmt.Println(strings.Join(all, "\n\n")) + + return nil +} + +func init() { + commands["syncdb"] = new(commandSyncDb) + commands["sqlall"] = new(commandSQLAll) +} + +// RunSyncdb run syncdb command line. +// name means table's alias name. default is "default". +// force means run next sql if the current is error. +// verbose means show all info when running command or not. +func RunSyncdb(name string, force bool, verbose bool) error { + BootStrap() + + al := getDbAlias(name) + cmd := new(commandSyncDb) + cmd.al = al + cmd.force = force + cmd.noInfo = !verbose + cmd.verbose = verbose + cmd.rtOnError = true + return cmd.Run() +} diff --git a/pkg/orm/cmd_utils.go b/pkg/orm/cmd_utils.go new file mode 100644 index 00000000..61f17346 --- /dev/null +++ b/pkg/orm/cmd_utils.go @@ -0,0 +1,320 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "os" + "strings" +) + +type dbIndex struct { + Table string + Name string + SQL string +} + +// create database drop sql. +func getDbDropSQL(al *alias) (sqls []string) { + if len(modelCache.cache) == 0 { + fmt.Println("no Model found, need register your model") + os.Exit(2) + } + + Q := al.DbBaser.TableQuote() + + for _, mi := range modelCache.allOrdered() { + sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) + } + return sqls +} + +// get database column type string. +func getColumnTyp(al *alias, fi *fieldInfo) (col string) { + T := al.DbBaser.DbTypes() + fieldType := fi.fieldType + fieldSize := fi.size + +checkColumn: + switch fieldType { + case TypeBooleanField: + col = T["bool"] + case TypeVarCharField: + if al.Driver == DRPostgres && fi.toText { + col = T["string-text"] + } else { + col = fmt.Sprintf(T["string"], fieldSize) + } + case TypeCharField: + col = fmt.Sprintf(T["string-char"], fieldSize) + case TypeTextField: + col = T["string-text"] + case TypeTimeField: + col = T["time.Time-clock"] + case TypeDateField: + col = T["time.Time-date"] + case TypeDateTimeField: + col = T["time.Time"] + case TypeBitField: + col = T["int8"] + case TypeSmallIntegerField: + col = T["int16"] + case TypeIntegerField: + col = T["int32"] + case TypeBigIntegerField: + if al.Driver == DRSqlite { + fieldType = TypeIntegerField + goto checkColumn + } + col = T["int64"] + case TypePositiveBitField: + col = T["uint8"] + case TypePositiveSmallIntegerField: + col = T["uint16"] + case TypePositiveIntegerField: + col = T["uint32"] + case TypePositiveBigIntegerField: + col = T["uint64"] + case TypeFloatField: + col = T["float64"] + case TypeDecimalField: + s := T["float64-decimal"] + if !strings.Contains(s, "%d") { + col = s + } else { + col = fmt.Sprintf(s, fi.digits, fi.decimals) + } + case TypeJSONField: + if al.Driver != DRPostgres { + fieldType = TypeVarCharField + goto checkColumn + } + col = T["json"] + case TypeJsonbField: + if al.Driver != DRPostgres { + fieldType = TypeVarCharField + goto checkColumn + } + col = T["jsonb"] + case RelForeignKey, RelOneToOne: + fieldType = fi.relModelInfo.fields.pk.fieldType + fieldSize = fi.relModelInfo.fields.pk.size + goto checkColumn + } + + return +} + +// create alter sql string. +func getColumnAddQuery(al *alias, fi *fieldInfo) string { + Q := al.DbBaser.TableQuote() + typ := getColumnTyp(al, fi) + + if !fi.null { + typ += " " + "NOT NULL" + } + + return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", + Q, fi.mi.table, Q, + Q, fi.column, Q, + typ, getColumnDefault(fi), + ) +} + +// create database creation string. +func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { + if len(modelCache.cache) == 0 { + fmt.Println("no Model found, need register your model") + os.Exit(2) + } + + Q := al.DbBaser.TableQuote() + T := al.DbBaser.DbTypes() + sep := fmt.Sprintf("%s, %s", Q, Q) + + tableIndexes = make(map[string][]dbIndex) + + for _, mi := range modelCache.allOrdered() { + sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) + sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) + sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) + + sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) + + columns := make([]string, 0, len(mi.fields.fieldsDB)) + + sqlIndexes := [][]string{} + + for _, fi := range mi.fields.fieldsDB { + + column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) + col := getColumnTyp(al, fi) + + if fi.auto { + switch al.Driver { + case DRSqlite, DRPostgres: + column += T["auto"] + default: + column += col + " " + T["auto"] + } + } else if fi.pk { + column += col + " " + T["pk"] + } else { + column += col + + if !fi.null { + column += " " + "NOT NULL" + } + + //if fi.initial.String() != "" { + // column += " DEFAULT " + fi.initial.String() + //} + + // Append attribute DEFAULT + column += getColumnDefault(fi) + + if fi.unique { + column += " " + "UNIQUE" + } + + if fi.index { + sqlIndexes = append(sqlIndexes, []string{fi.column}) + } + } + + if strings.Contains(column, "%COL%") { + column = strings.Replace(column, "%COL%", fi.column, -1) + } + + if fi.description != "" && al.Driver!=DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) + } + + columns = append(columns, column) + } + + if mi.model != nil { + allnames := getTableUnique(mi.addrField) + if !mi.manual && len(mi.uniques) > 0 { + allnames = append(allnames, mi.uniques) + } + for _, names := range allnames { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { + cols = append(cols, fi.column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) + } + } + column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) + columns = append(columns, column) + } + } + + sql += strings.Join(columns, ",\n") + sql += "\n)" + + if al.Driver == DRMySQL { + var engine string + if mi.model != nil { + engine = getTableEngine(mi.addrField) + } + if engine == "" { + engine = al.Engine + } + sql += " ENGINE=" + engine + } + + sql += ";" + sqls = append(sqls, sql) + + if mi.model != nil { + for _, names := range getTableIndex(mi.addrField) { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { + cols = append(cols, fi.column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) + } + } + sqlIndexes = append(sqlIndexes, cols) + } + } + + for _, names := range sqlIndexes { + name := mi.table + "_" + strings.Join(names, "_") + cols := strings.Join(names, sep) + sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) + + index := dbIndex{} + index.Table = mi.table + index.Name = name + index.SQL = sql + + tableIndexes[mi.table] = append(tableIndexes[mi.table], index) + } + + } + + return +} + +// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands +func getColumnDefault(fi *fieldInfo) string { + var ( + v, t, d string + ) + + // Skip default attribute if field is in relations + if fi.rel || fi.reverse { + return v + } + + t = " DEFAULT '%s' " + + // These defaults will be useful if there no config value orm:"default" and NOT NULL is on + switch fi.fieldType { + case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField: + return v + + case TypeBitField, TypeSmallIntegerField, TypeIntegerField, + TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, + TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField, + TypeDecimalField: + t = " DEFAULT %s " + d = "0" + case TypeBooleanField: + t = " DEFAULT %s " + d = "FALSE" + case TypeJSONField, TypeJsonbField: + d = "{}" + } + + if fi.colDefault { + if !fi.initial.Exist() { + v = fmt.Sprintf(t, "") + } else { + v = fmt.Sprintf(t, fi.initial.String()) + } + } else { + if !fi.null { + v = fmt.Sprintf(t, d) + } + } + + return v +} diff --git a/pkg/orm/db.go b/pkg/orm/db.go new file mode 100644 index 00000000..9a1827e8 --- /dev/null +++ b/pkg/orm/db.go @@ -0,0 +1,1902 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "time" +) + +const ( + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" +) + +var ( + // ErrMissPK missing pk error + ErrMissPK = errors.New("missed pk value") +) + +var ( + operators = map[string]bool{ + "exact": true, + "iexact": true, + "contains": true, + "icontains": true, + // "regex": true, + // "iregex": true, + "gt": true, + "gte": true, + "lt": true, + "lte": true, + "eq": true, + "nq": true, + "ne": true, + "startswith": true, + "endswith": true, + "istartswith": true, + "iendswith": true, + "in": true, + "between": true, + // "year": true, + // "month": true, + // "day": true, + // "week_day": true, + "isnull": true, + // "search": true, + } +) + +// an instance of dbBaser interface/ +type dbBase struct { + ins dbBaser +} + +// check dbBase implements dbBaser interface. +var _ dbBaser = new(dbBase) + +// get struct columns values as interface slice. +func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { + if names == nil { + ns := make([]string, 0, len(cols)) + names = &ns + } + values = make([]interface{}, 0, len(cols)) + + for _, column := range cols { + var fi *fieldInfo + if fi, _ = mi.fields.GetByAny(column); fi != nil { + column = fi.column + } else { + panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName)) + } + if !fi.dbcol || fi.auto && skipAuto { + continue + } + value, err := d.collectFieldValue(mi, fi, ind, insert, tz) + if err != nil { + return nil, nil, err + } + + // ignore empty value auto field + if insert && fi.auto { + if fi.fieldType&IsPositiveIntegerField > 0 { + if vu, ok := value.(uint64); !ok || vu == 0 { + continue + } + } else { + if vu, ok := value.(int64); !ok || vu == 0 { + continue + } + } + autoFields = append(autoFields, fi.column) + } + + *names, values = append(*names, column), append(values, value) + } + + return +} + +// get one field value in struct column as interface. +func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { + var value interface{} + if fi.pk { + _, value, _ = getExistPk(mi, ind) + } else { + field := ind.FieldByIndex(fi.fieldIndex) + if fi.isFielder { + f := field.Addr().Interface().(Fielder) + value = f.RawValue() + } else { + switch fi.fieldType { + case TypeBooleanField: + if nb, ok := field.Interface().(sql.NullBool); ok { + value = nil + if nb.Valid { + value = nb.Bool + } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Bool() + } + } else { + value = field.Bool() + } + case TypeVarCharField, TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField: + if ns, ok := field.Interface().(sql.NullString); ok { + value = nil + if ns.Valid { + value = ns.String + } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().String() + } + } else { + value = field.String() + } + case TypeFloatField, TypeDecimalField: + if nf, ok := field.Interface().(sql.NullFloat64); ok { + value = nil + if nf.Valid { + value = nf.Float64 + } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Float() + } + } else { + vu := field.Interface() + if _, ok := vu.(float32); ok { + value, _ = StrTo(ToStr(vu)).Float64() + } else { + value = field.Float() + } + } + case TypeTimeField, TypeDateField, TypeDateTimeField: + value = field.Interface() + if t, ok := value.(time.Time); ok { + d.ins.TimeToDB(&t, tz) + if t.IsZero() { + value = nil + } else { + value = t + } + } + default: + switch { + case fi.fieldType&IsPositiveIntegerField > 0: + if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Uint() + } + } else { + value = field.Uint() + } + case fi.fieldType&IsIntegerField > 0: + if ni, ok := field.Interface().(sql.NullInt64); ok { + value = nil + if ni.Valid { + value = ni.Int64 + } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Int() + } + } else { + value = field.Int() + } + case fi.fieldType&IsRelField > 0: + if field.IsNil() { + value = nil + } else { + if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok { + value = vu + } else { + value = nil + } + } + if !fi.null && value == nil { + return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName) + } + } + } + } + switch fi.fieldType { + case TypeTimeField, TypeDateField, TypeDateTimeField: + if fi.autoNow || fi.autoNowAdd && insert { + if insert { + if t, ok := value.(time.Time); ok && !t.IsZero() { + break + } + } + tnow := time.Now() + d.ins.TimeToDB(&tnow, tz) + value = tnow + if fi.isFielder { + f := field.Addr().Interface().(Fielder) + f.SetRaw(tnow.In(DefaultTimeLoc)) + } else if field.Kind() == reflect.Ptr { + v := tnow.In(DefaultTimeLoc) + field.Set(reflect.ValueOf(&v)) + } else { + field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc))) + } + } + case TypeJSONField, TypeJsonbField: + if s, ok := value.(string); (ok && len(s) == 0) || value == nil { + if fi.colDefault && fi.initial.Exist() { + value = fi.initial.String() + } else { + value = nil + } + } + } + } + return value, nil +} + +// create insert sql preparation statement object. +func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { + Q := d.ins.TableQuote() + + dbcols := make([]string, 0, len(mi.fields.dbcols)) + marks := make([]string, 0, len(mi.fields.dbcols)) + for _, fi := range mi.fields.fieldsDB { + if !fi.auto { + dbcols = append(dbcols, fi.column) + marks = append(marks, "?") + } + } + qmarks := strings.Join(marks, ", ") + sep := fmt.Sprintf("%s, %s", Q, Q) + columns := strings.Join(dbcols, sep) + + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + + d.ins.ReplaceMarks(&query) + + d.ins.HasReturningID(mi, &query) + + stmt, err := q.Prepare(query) + return stmt, query, err +} + +// insert struct with prepared statement and given struct reflect value. +func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) + if err != nil { + return 0, err + } + + if d.ins.HasReturningID(mi, nil) { + row := stmt.QueryRow(values...) + var id int64 + err := row.Scan(&id) + return id, err + } + res, err := stmt.Exec(values...) + if err == nil { + return res.LastInsertId() + } + return 0, err +} + +// query sql ,read records and persist in dbBaser. +func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { + var whereCols []string + var args []interface{} + + // if specify cols length > 0, then use it for where condition. + if len(cols) > 0 { + var err error + whereCols = make([]string, 0, len(cols)) + args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) + if err != nil { + return err + } + } else { + // default use pk value as where condtion. + pkColumn, pkValue, ok := getExistPk(mi, ind) + if !ok { + return ErrMissPK + } + whereCols = []string{pkColumn} + args = append(args, pkValue) + } + + Q := d.ins.TableQuote() + + sep := fmt.Sprintf("%s, %s", Q, Q) + sels := strings.Join(mi.fields.dbcols, sep) + colsNum := len(mi.fields.dbcols) + + sep = fmt.Sprintf("%s = ? AND %s", Q, Q) + wheres := strings.Join(whereCols, sep) + + forUpdate := "" + if isForUpdate { + forUpdate = "FOR UPDATE" + } + + query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q, forUpdate) + + refs := make([]interface{}, colsNum) + for i := range refs { + var ref interface{} + refs[i] = &ref + } + + d.ins.ReplaceMarks(&query) + + row := q.QueryRow(query, args...) + if err := row.Scan(refs...); err != nil { + if err == sql.ErrNoRows { + return ErrNoRows + } + return err + } + elm := reflect.New(mi.addrField.Elem().Type()) + mind := reflect.Indirect(elm) + d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz) + ind.Set(mind) + return nil +} + +// execute insert sql dbQuerier with given struct reflect.Value. +func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { + names := make([]string, 0, len(mi.fields.dbcols)) + values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) + if err != nil { + return 0, err + } + + id, err := d.InsertValue(q, mi, false, names, values) + if err != nil { + return 0, err + } + + if len(autoFields) > 0 { + err = d.ins.setval(q, mi, autoFields) + } + return id, err +} + +// multi-insert sql with given slice struct reflect.Value. +func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { + var ( + cnt int64 + nums int + values []interface{} + names []string + ) + + // typ := reflect.Indirect(mi.addrField).Type() + + length, autoFields := sind.Len(), make([]string, 0, 1) + + for i := 1; i <= length; i++ { + + ind := reflect.Indirect(sind.Index(i - 1)) + + // Is this needed ? + // if !ind.Type().AssignableTo(typ) { + // return cnt, ErrArgs + // } + + if i == 1 { + var ( + vus []interface{} + err error + ) + vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) + if err != nil { + return cnt, err + } + values = make([]interface{}, bulk*len(vus)) + nums += copy(values, vus) + } else { + vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) + if err != nil { + return cnt, err + } + + if len(vus) != len(names) { + return cnt, ErrArgs + } + + nums += copy(values[nums:], vus) + } + + if i > 1 && i%bulk == 0 || length == i { + num, err := d.InsertValue(q, mi, true, names, values[:nums]) + if err != nil { + return cnt, err + } + cnt += num + nums = 0 + } + } + + var err error + if len(autoFields) > 0 { + err = d.ins.setval(q, mi, autoFields) + } + + return cnt, err +} + +// execute insert sql with given struct and given values. +// insert the given values, not the field values in struct. +func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { + Q := d.ins.TableQuote() + + marks := make([]string, len(names)) + for i := range marks { + marks[i] = "?" + } + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti && multi > 1 { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + row := q.QueryRow(query, values...) + var id int64 + err := row.Scan(&id) + return id, err +} + +// InsertOrUpdate a row +// If your primary key or unique column conflict will update +// If no will insert +func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { + args0 := "" + iouStr := "" + argsMap := map[string]string{} + switch a.Driver { + case DRMySQL: + iouStr = "ON DUPLICATE KEY UPDATE" + case DRPostgres: + if len(args) == 0 { + return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName) + } + args0 = strings.ToLower(args[0]) + iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0) + default: + return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName) + } + + //Get on the key-value pairs + for _, v := range args { + kv := strings.Split(v, "=") + if len(kv) == 2 { + argsMap[strings.ToLower(kv[0])] = kv[1] + } + } + + isMulti := false + names := make([]string, 0, len(mi.fields.dbcols)-1) + Q := d.ins.TableQuote() + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) + + if err != nil { + return 0, err + } + + marks := make([]string, len(names)) + updateValues := make([]interface{}, 0) + updates := make([]string, len(names)) + var conflitValue interface{} + for i, v := range names { + // identifier in database may not be case-sensitive, so quote it + v = fmt.Sprintf("%s%s%s", Q, v, Q) + marks[i] = "?" + valueStr := argsMap[strings.ToLower(v)] + if v == args0 { + conflitValue = values[i] + } + if valueStr != "" { + switch a.Driver { + case DRMySQL: + updates[i] = v + "=" + valueStr + case DRPostgres: + if conflitValue != nil { + //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0) + updateValues = append(updateValues, conflitValue) + } else { + return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) + } + } + } else { + updates[i] = v + "=?" + updateValues = append(updateValues, values[i]) + } + } + + values = append(values, updateValues...) + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + qupdates := strings.Join(updates, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + //conflitValue maybe is a int,can`t use fmt.Sprintf + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + + row := q.QueryRow(query, values...) + var id int64 + err = row.Scan(&id) + if err != nil && err.Error() == `pq: syntax error at or near "ON"` { + err = fmt.Errorf("postgres version must 9.5 or higher") + } + return id, err +} + +// execute update sql dbQuerier with given struct reflect.Value. +func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { + pkName, pkValue, ok := getExistPk(mi, ind) + if !ok { + return 0, ErrMissPK + } + + var setNames []string + + // if specify cols length is zero, then commit all columns. + if len(cols) == 0 { + cols = mi.fields.dbcols + setNames = make([]string, 0, len(mi.fields.dbcols)-1) + } else { + setNames = make([]string, 0, len(cols)) + } + + setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) + if err != nil { + return 0, err + } + + var findAutoNowAdd, findAutoNow bool + var index int + for i, col := range setNames { + if mi.fields.GetByColumn(col).autoNowAdd { + index = i + findAutoNowAdd = true + } + if mi.fields.GetByColumn(col).autoNow { + findAutoNow = true + } + } + if findAutoNowAdd { + setNames = append(setNames[0:index], setNames[index+1:]...) + setValues = append(setValues[0:index], setValues[index+1:]...) + } + + if !findAutoNow { + for col, info := range mi.fields.columns { + if info.autoNow { + setNames = append(setNames, col) + setValues = append(setValues, time.Now()) + } + } + } + + setValues = append(setValues, pkValue) + + Q := d.ins.TableQuote() + + sep := fmt.Sprintf("%s = ?, %s", Q, Q) + setColumns := strings.Join(setNames, sep) + + query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.table, Q, Q, setColumns, Q, Q, pkName, Q) + + d.ins.ReplaceMarks(&query) + + res, err := q.Exec(query, setValues...) + if err == nil { + return res.RowsAffected() + } + return 0, err +} + +// execute delete sql dbQuerier with given struct reflect.Value. +// delete index is pk. +func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { + var whereCols []string + var args []interface{} + // if specify cols length > 0, then use it for where condition. + if len(cols) > 0 { + var err error + whereCols = make([]string, 0, len(cols)) + args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) + if err != nil { + return 0, err + } + } else { + // default use pk value as where condtion. + pkColumn, pkValue, ok := getExistPk(mi, ind) + if !ok { + return 0, ErrMissPK + } + whereCols = []string{pkColumn} + args = append(args, pkValue) + } + + Q := d.ins.TableQuote() + + sep := fmt.Sprintf("%s = ? AND %s", Q, Q) + wheres := strings.Join(whereCols, sep) + + query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) + + d.ins.ReplaceMarks(&query) + res, err := q.Exec(query, args...) + if err == nil { + num, err := res.RowsAffected() + if err != nil { + return 0, err + } + if num > 0 { + if mi.fields.pk.auto { + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0) + } else { + ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) + } + } + err := d.deleteRels(q, mi, args, tz) + if err != nil { + return num, err + } + } + return num, err + } + return 0, err +} + +// update table-related record by querySet. +// need querySet not struct reflect.Value to update related records. +func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { + columns := make([]string, 0, len(params)) + values := make([]interface{}, 0, len(params)) + for col, val := range params { + if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol { + panic(fmt.Errorf("wrong field/column name `%s`", col)) + } else { + columns = append(columns, fi.column) + values = append(values, val) + } + } + + if len(columns) == 0 { + panic(fmt.Errorf("update params cannot empty")) + } + + tables := newDbTables(mi, d.ins) + if qs != nil { + tables.parseRelated(qs.related, qs.relDepth) + } + + where, args := tables.getCondSQL(cond, false, tz) + + values = append(values, args...) + + join := tables.getJoinSQL() + + var query, T string + + Q := d.ins.TableQuote() + + if d.ins.SupportUpdateJoin() { + T = "T0." + } + + cols := make([]string, 0, len(columns)) + + for i, v := range columns { + col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q) + if c, ok := values[i].(colValue); ok { + switch c.opt { + case ColAdd: + cols = append(cols, col+" = "+col+" + ?") + case ColMinus: + cols = append(cols, col+" = "+col+" - ?") + case ColMultiply: + cols = append(cols, col+" = "+col+" * ?") + case ColExcept: + cols = append(cols, col+" = "+col+" / ?") + case ColBitAnd: + cols = append(cols, col+" = "+col+" & ?") + case ColBitRShift: + cols = append(cols, col+" = "+col+" >> ?") + case ColBitLShift: + cols = append(cols, col+" = "+col+" << ?") + case ColBitXOR: + cols = append(cols, col+" = "+col+" ^ ?") + case ColBitOr: + cols = append(cols, col+" = "+col+" | ?") + } + values[i] = c.value + } else { + cols = append(cols, col+" = ?") + } + } + + sets := strings.Join(cols, ", ") + " " + + if d.ins.SupportUpdateJoin() { + query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET %s%s", Q, mi.table, Q, join, sets, where) + } else { + supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where) + query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery) + } + + d.ins.ReplaceMarks(&query) + var err error + var res sql.Result + if qs != nil && qs.forContext { + res, err = q.ExecContext(qs.ctx, query, values...) + } else { + res, err = q.Exec(query, values...) + } + if err == nil { + return res.RowsAffected() + } + return 0, err +} + +// delete related records. +// do UpdateBanch or DeleteBanch by condition of tables' relationship. +func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { + for _, fi := range mi.fields.fieldsReverse { + fi = fi.reverseFieldInfo + switch fi.onDelete { + case odCascade: + cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) + _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) + if err != nil { + return err + } + case odSetDefault, odSetNULL: + cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) + params := Params{fi.column: nil} + if fi.onDelete == odSetDefault { + params[fi.column] = fi.initial.String() + } + _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) + if err != nil { + return err + } + case odDoNothing: + } + } + return nil +} + +// delete table-related records. +func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { + tables := newDbTables(mi, d.ins) + tables.skipEnd = true + + if qs != nil { + tables.parseRelated(qs.related, qs.relDepth) + } + + if cond == nil || cond.IsEmpty() { + panic(fmt.Errorf("delete operation cannot execute without condition")) + } + + Q := d.ins.TableQuote() + + where, args := tables.getCondSQL(cond, false, tz) + join := tables.getJoinSQL() + + cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) + query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) + + d.ins.ReplaceMarks(&query) + + var rs *sql.Rows + r, err := q.Query(query, args...) + if err != nil { + return 0, err + } + rs = r + defer rs.Close() + + var ref interface{} + args = make([]interface{}, 0) + cnt := 0 + for rs.Next() { + if err := rs.Scan(&ref); err != nil { + return 0, err + } + pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz) + if err != nil { + return 0, err + } + args = append(args, pkValue) + cnt++ + } + + if cnt == 0 { + return 0, nil + } + + marks := make([]string, len(args)) + for i := range marks { + marks[i] = "?" + } + sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) + query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) + + d.ins.ReplaceMarks(&query) + var res sql.Result + if qs != nil && qs.forContext { + res, err = q.ExecContext(qs.ctx, query, args...) + } else { + res, err = q.Exec(query, args...) + } + if err == nil { + num, err := res.RowsAffected() + if err != nil { + return 0, err + } + if num > 0 { + err := d.deleteRels(q, mi, args, tz) + if err != nil { + return num, err + } + } + return num, nil + } + return 0, err +} + +// read related records. +func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { + + val := reflect.ValueOf(container) + ind := reflect.Indirect(val) + + errTyp := true + one := true + isPtr := true + + if val.Kind() == reflect.Ptr { + fn := "" + if ind.Kind() == reflect.Slice { + one = false + typ := ind.Type().Elem() + switch typ.Kind() { + case reflect.Ptr: + fn = getFullName(typ.Elem()) + case reflect.Struct: + isPtr = false + fn = getFullName(typ) + } + } else { + fn = getFullName(ind.Type()) + } + errTyp = fn != mi.fullName + } + + if errTyp { + if one { + panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName)) + } else { + panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName)) + } + } + + rlimit := qs.limit + offset := qs.offset + + Q := d.ins.TableQuote() + + var tCols []string + if len(cols) > 0 { + hasRel := len(qs.related) > 0 || qs.relDepth > 0 + tCols = make([]string, 0, len(cols)) + var maps map[string]bool + if hasRel { + maps = make(map[string]bool) + } + for _, col := range cols { + if fi, ok := mi.fields.GetByAny(col); ok { + tCols = append(tCols, fi.column) + if hasRel { + maps[fi.column] = true + } + } else { + return 0, fmt.Errorf("wrong field/column name `%s`", col) + } + } + if hasRel { + for _, fi := range mi.fields.fieldsDB { + if fi.fieldType&IsRelField > 0 { + if !maps[fi.column] { + tCols = append(tCols, fi.column) + } + } + } + } + } else { + tCols = mi.fields.dbcols + } + + colsNum := len(tCols) + sep := fmt.Sprintf("%s, T0.%s", Q, Q) + sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q) + + tables := newDbTables(mi, d.ins) + tables.parseRelated(qs.related, qs.relDepth) + + where, args := tables.getCondSQL(cond, false, tz) + groupBy := tables.getGroupSQL(qs.groups) + orderBy := tables.getOrderSQL(qs.orders) + limit := tables.getLimitSQL(mi, offset, rlimit) + join := tables.getJoinSQL() + + for _, tbl := range tables.tables { + if tbl.sel { + colsNum += len(tbl.mi.fields.dbcols) + sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q) + sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q) + } + } + + sqlSelect := "SELECT" + if qs.distinct { + sqlSelect += " DISTINCT" + } + query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) + + if qs.forupdate { + query += " FOR UPDATE" + } + + d.ins.ReplaceMarks(&query) + + var rs *sql.Rows + var err error + if qs != nil && qs.forContext { + rs, err = q.QueryContext(qs.ctx, query, args...) + if err != nil { + return 0, err + } + } else { + rs, err = q.Query(query, args...) + if err != nil { + return 0, err + } + } + + refs := make([]interface{}, colsNum) + for i := range refs { + var ref interface{} + refs[i] = &ref + } + + defer rs.Close() + + slice := ind + + var cnt int64 + for rs.Next() { + if one && cnt == 0 || !one { + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + elm := reflect.New(mi.addrField.Elem().Type()) + mind := reflect.Indirect(elm) + + cacheV := make(map[string]*reflect.Value) + cacheM := make(map[string]*modelInfo) + trefs := refs + + d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz) + trefs = refs[len(tCols):] + + for _, tbl := range tables.tables { + // loop selected tables + if tbl.sel { + last := mind + names := "" + mmi := mi + // loop cascade models + for _, name := range tbl.names { + names += name + if val, ok := cacheV[names]; ok { + last = *val + mmi = cacheM[names] + } else { + fi := mmi.fields.GetByName(name) + lastm := mmi + mmi = fi.relModelInfo + field := last + if last.Kind() != reflect.Invalid { + field = reflect.Indirect(last.FieldByIndex(fi.fieldIndex)) + if field.IsValid() { + d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz) + for _, fi := range mmi.fields.fieldsReverse { + if fi.inModel && fi.reverseFieldInfo.mi == lastm { + if fi.reverseFieldInfo != nil { + f := field.FieldByIndex(fi.fieldIndex) + if f.Kind() == reflect.Ptr { + f.Set(last.Addr()) + } + } + } + } + last = field + } + } + cacheV[names] = &field + cacheM[names] = mmi + } + } + trefs = trefs[len(mmi.fields.dbcols):] + } + } + + if one { + ind.Set(mind) + } else { + if cnt == 0 { + // you can use a empty & caped container list + // orm will not replace it + if ind.Len() != 0 { + // if container is not empty + // create a new one + slice = reflect.New(ind.Type()).Elem() + } + } + + if isPtr { + slice = reflect.Append(slice, mind.Addr()) + } else { + slice = reflect.Append(slice, mind) + } + } + } + cnt++ + } + + if !one { + if cnt > 0 { + ind.Set(slice) + } else { + // when a result is empty and container is nil + // to set a empty container + if ind.IsNil() { + ind.Set(reflect.MakeSlice(ind.Type(), 0, 0)) + } + } + } + + return cnt, nil +} + +// excute count sql and return count result int64. +func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { + tables := newDbTables(mi, d.ins) + tables.parseRelated(qs.related, qs.relDepth) + + where, args := tables.getCondSQL(cond, false, tz) + groupBy := tables.getGroupSQL(qs.groups) + tables.getOrderSQL(qs.orders) + join := tables.getJoinSQL() + + Q := d.ins.TableQuote() + + query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s", Q, mi.table, Q, join, where, groupBy) + + if groupBy != "" { + query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query) + } + + d.ins.ReplaceMarks(&query) + + var row *sql.Row + if qs != nil && qs.forContext { + row = q.QueryRowContext(qs.ctx, query, args...) + } else { + row = q.QueryRow(query, args...) + } + err = row.Scan(&cnt) + return +} + +// generate sql with replacing operator string placeholders and replaced values. +func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { + var sql string + params := getFlatParams(fi, args, tz) + + if len(params) == 0 { + panic(fmt.Errorf("operator `%s` need at least one args", operator)) + } + arg := params[0] + + switch operator { + case "in": + marks := make([]string, len(params)) + for i := range marks { + marks[i] = "?" + } + sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) + case "between": + if len(params) != 2 { + panic(fmt.Errorf("operator `%s` need 2 args not %d", operator, len(params))) + } + sql = "BETWEEN ? AND ?" + default: + if len(params) > 1 { + panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params))) + } + sql = d.ins.OperatorSQL(operator) + switch operator { + case "exact": + if arg == nil { + params[0] = "IS NULL" + } + case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith": + param := strings.Replace(ToStr(arg), `%`, `\%`, -1) + switch operator { + case "iexact": + case "contains", "icontains": + param = fmt.Sprintf("%%%s%%", param) + case "startswith", "istartswith": + param = fmt.Sprintf("%s%%", param) + case "endswith", "iendswith": + param = fmt.Sprintf("%%%s", param) + } + params[0] = param + case "isnull": + if b, ok := arg.(bool); ok { + if b { + sql = "IS NULL" + } else { + sql = "IS NOT NULL" + } + params = nil + } else { + panic(fmt.Errorf("operator `%s` need a bool value not `%T`", operator, arg)) + } + } + } + return sql, params +} + +// gernerate sql string with inner function, such as UPPER(text). +func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { + // default not use +} + +// set values to struct column. +func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { + for i, column := range cols { + val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() + + fi := mi.fields.GetByColumn(column) + + field := ind.FieldByIndex(fi.fieldIndex) + + value, err := d.convertValueFromDB(fi, val, tz) + if err != nil { + panic(fmt.Errorf("Raw value: `%v` %s", val, err.Error())) + } + + _, err = d.setFieldValue(fi, value, field) + + if err != nil { + panic(fmt.Errorf("Raw value: `%v` %s", val, err.Error())) + } + } +} + +// convert value from database result to value following in field type. +func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { + if val == nil { + return nil, nil + } + + var value interface{} + var tErr error + + var str *StrTo + switch v := val.(type) { + case []byte: + s := StrTo(string(v)) + str = &s + case string: + s := StrTo(v) + str = &s + } + + fieldType := fi.fieldType + +setValue: + switch { + case fieldType == TypeBooleanField: + if str == nil { + switch v := val.(type) { + case int64: + b := v == 1 + value = b + default: + s := StrTo(ToStr(v)) + str = &s + } + } + if str != nil { + b, err := str.Bool() + if err != nil { + tErr = err + goto end + } + value = b + } + case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: + if str == nil { + value = ToStr(val) + } else { + value = str.String() + } + case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: + if str == nil { + switch t := val.(type) { + case time.Time: + d.ins.TimeFromDB(&t, tz) + value = t + default: + s := StrTo(ToStr(t)) + str = &s + } + } + if str != nil { + s := str.String() + var ( + t time.Time + err error + ) + if len(s) >= 19 { + s = s[:19] + t, err = time.ParseInLocation(formatDateTime, s, tz) + } else if len(s) >= 10 { + if len(s) > 10 { + s = s[:10] + } + t, err = time.ParseInLocation(formatDate, s, tz) + } else if len(s) >= 8 { + if len(s) > 8 { + s = s[:8] + } + t, err = time.ParseInLocation(formatTime, s, tz) + } + t = t.In(DefaultTimeLoc) + + if err != nil && s != "00:00:00" && s != "0000-00-00" && s != "0000-00-00 00:00:00" { + tErr = err + goto end + } + value = t + } + case fieldType&IsIntegerField > 0: + if str == nil { + s := StrTo(ToStr(val)) + str = &s + } + if str != nil { + var err error + switch fieldType { + case TypeBitField: + _, err = str.Int8() + case TypeSmallIntegerField: + _, err = str.Int16() + case TypeIntegerField: + _, err = str.Int32() + case TypeBigIntegerField: + _, err = str.Int64() + case TypePositiveBitField: + _, err = str.Uint8() + case TypePositiveSmallIntegerField: + _, err = str.Uint16() + case TypePositiveIntegerField: + _, err = str.Uint32() + case TypePositiveBigIntegerField: + _, err = str.Uint64() + } + if err != nil { + tErr = err + goto end + } + if fieldType&IsPositiveIntegerField > 0 { + v, _ := str.Uint64() + value = v + } else { + v, _ := str.Int64() + value = v + } + } + case fieldType == TypeFloatField || fieldType == TypeDecimalField: + if str == nil { + switch v := val.(type) { + case float64: + value = v + default: + s := StrTo(ToStr(v)) + str = &s + } + } + if str != nil { + v, err := str.Float64() + if err != nil { + tErr = err + goto end + } + value = v + } + case fieldType&IsRelField > 0: + fi = fi.relModelInfo.fields.pk + fieldType = fi.fieldType + goto setValue + } + +end: + if tErr != nil { + err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.addrValue.Type(), fi.fullName, tErr) + return nil, err + } + + return value, nil + +} + +// set one value to struct column field. +func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { + + fieldType := fi.fieldType + isNative := !fi.isFielder + +setValue: + switch { + case fieldType == TypeBooleanField: + if isNative { + if nb, ok := field.Interface().(sql.NullBool); ok { + if value == nil { + nb.Valid = false + } else { + nb.Bool = value.(bool) + nb.Valid = true + } + field.Set(reflect.ValueOf(nb)) + } else if field.Kind() == reflect.Ptr { + if value != nil { + v := value.(bool) + field.Set(reflect.ValueOf(&v)) + } + } else { + if value == nil { + value = false + } + field.SetBool(value.(bool)) + } + } + case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: + if isNative { + if ns, ok := field.Interface().(sql.NullString); ok { + if value == nil { + ns.Valid = false + } else { + ns.String = value.(string) + ns.Valid = true + } + field.Set(reflect.ValueOf(ns)) + } else if field.Kind() == reflect.Ptr { + if value != nil { + v := value.(string) + field.Set(reflect.ValueOf(&v)) + } + } else { + if value == nil { + value = "" + } + field.SetString(value.(string)) + } + } + case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: + if isNative { + if value == nil { + value = time.Time{} + } else if field.Kind() == reflect.Ptr { + if value != nil { + v := value.(time.Time) + field.Set(reflect.ValueOf(&v)) + } + } else { + field.Set(reflect.ValueOf(value)) + } + } + case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr: + if value != nil { + v := uint8(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypePositiveSmallIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := uint16(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypePositiveIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + if field.Type() == reflect.TypeOf(new(uint)) { + v := uint(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } else { + v := uint32(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } + } + case fieldType == TypePositiveBigIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := value.(uint64) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypeBitField && field.Kind() == reflect.Ptr: + if value != nil { + v := int8(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypeSmallIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := int16(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypeIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + if field.Type() == reflect.TypeOf(new(int)) { + v := int(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } else { + v := int32(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } + } + case fieldType == TypeBigIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := value.(int64) + field.Set(reflect.ValueOf(&v)) + } + case fieldType&IsIntegerField > 0: + if fieldType&IsPositiveIntegerField > 0 { + if isNative { + if value == nil { + value = uint64(0) + } + field.SetUint(value.(uint64)) + } + } else { + if isNative { + if ni, ok := field.Interface().(sql.NullInt64); ok { + if value == nil { + ni.Valid = false + } else { + ni.Int64 = value.(int64) + ni.Valid = true + } + field.Set(reflect.ValueOf(ni)) + } else { + if value == nil { + value = int64(0) + } + field.SetInt(value.(int64)) + } + } + } + case fieldType == TypeFloatField || fieldType == TypeDecimalField: + if isNative { + if nf, ok := field.Interface().(sql.NullFloat64); ok { + if value == nil { + nf.Valid = false + } else { + nf.Float64 = value.(float64) + nf.Valid = true + } + field.Set(reflect.ValueOf(nf)) + } else if field.Kind() == reflect.Ptr { + if value != nil { + if field.Type() == reflect.TypeOf(new(float32)) { + v := float32(value.(float64)) + field.Set(reflect.ValueOf(&v)) + } else { + v := value.(float64) + field.Set(reflect.ValueOf(&v)) + } + } + } else { + + if value == nil { + value = float64(0) + } + field.SetFloat(value.(float64)) + } + } + case fieldType&IsRelField > 0: + if value != nil { + fieldType = fi.relModelInfo.fields.pk.fieldType + mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + field.Set(mf) + f := mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + field = f + goto setValue + } + } + + if !isNative { + fd := field.Addr().Interface().(Fielder) + err := fd.SetRaw(value) + if err != nil { + err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.fullName, err) + return nil, err + } + } + + return value, nil +} + +// query sql, read values , save to *[]ParamList. +func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { + + var ( + maps []Params + lists []ParamsList + list ParamsList + ) + + typ := 0 + switch v := container.(type) { + case *[]Params: + d := *v + if len(d) == 0 { + maps = d + } + typ = 1 + case *[]ParamsList: + d := *v + if len(d) == 0 { + lists = d + } + typ = 2 + case *ParamsList: + d := *v + if len(d) == 0 { + list = d + } + typ = 3 + default: + panic(fmt.Errorf("unsupport read values type `%T`", container)) + } + + tables := newDbTables(mi, d.ins) + + var ( + cols []string + infos []*fieldInfo + ) + + hasExprs := len(exprs) > 0 + + Q := d.ins.TableQuote() + + if hasExprs { + cols = make([]string, 0, len(exprs)) + infos = make([]*fieldInfo, 0, len(exprs)) + for _, ex := range exprs { + index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", ex)) + } + cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) + infos = append(infos, fi) + } + } else { + cols = make([]string, 0, len(mi.fields.dbcols)) + infos = make([]*fieldInfo, 0, len(exprs)) + for _, fi := range mi.fields.fieldsDB { + cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.column, Q, Q, fi.name, Q)) + infos = append(infos, fi) + } + } + + where, args := tables.getCondSQL(cond, false, tz) + groupBy := tables.getGroupSQL(qs.groups) + orderBy := tables.getOrderSQL(qs.orders) + limit := tables.getLimitSQL(mi, qs.offset, qs.limit) + join := tables.getJoinSQL() + + sels := strings.Join(cols, ", ") + + sqlSelect := "SELECT" + if qs.distinct { + sqlSelect += " DISTINCT" + } + query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) + + d.ins.ReplaceMarks(&query) + + rs, err := q.Query(query, args...) + if err != nil { + return 0, err + } + refs := make([]interface{}, len(cols)) + for i := range refs { + var ref interface{} + refs[i] = &ref + } + + defer rs.Close() + + var ( + cnt int64 + columns []string + ) + for rs.Next() { + if cnt == 0 { + cols, err := rs.Columns() + if err != nil { + return 0, err + } + columns = cols + } + + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + switch typ { + case 1: + params := make(Params, len(cols)) + for i, ref := range refs { + fi := infos[i] + + val := reflect.Indirect(reflect.ValueOf(ref)).Interface() + + value, err := d.convertValueFromDB(fi, val, tz) + if err != nil { + panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) + } + + params[columns[i]] = value + } + maps = append(maps, params) + case 2: + params := make(ParamsList, 0, len(cols)) + for i, ref := range refs { + fi := infos[i] + + val := reflect.Indirect(reflect.ValueOf(ref)).Interface() + + value, err := d.convertValueFromDB(fi, val, tz) + if err != nil { + panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) + } + + params = append(params, value) + } + lists = append(lists, params) + case 3: + for i, ref := range refs { + fi := infos[i] + + val := reflect.Indirect(reflect.ValueOf(ref)).Interface() + + value, err := d.convertValueFromDB(fi, val, tz) + if err != nil { + panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) + } + + list = append(list, value) + } + } + + cnt++ + } + + switch v := container.(type) { + case *[]Params: + *v = maps + case *[]ParamsList: + *v = lists + case *ParamsList: + *v = list + } + + return cnt, nil +} + +func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) { + return 0, nil +} + +// flag of update joined record. +func (d *dbBase) SupportUpdateJoin() bool { + return true +} + +func (d *dbBase) MaxLimit() uint64 { + return 18446744073709551615 +} + +// return quote. +func (d *dbBase) TableQuote() string { + return "`" +} + +// replace value placeholder in parametered sql string. +func (d *dbBase) ReplaceMarks(query *string) { + // default use `?` as mark, do nothing +} + +// flag of RETURNING sql. +func (d *dbBase) HasReturningID(*modelInfo, *string) bool { + return false +} + +// sync auto key +func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { + return nil +} + +// convert time from db. +func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { + *t = t.In(tz) +} + +// convert time to db. +func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { + *t = t.In(tz) +} + +// get database types. +func (d *dbBase) DbTypes() map[string]string { + return nil +} + +// gt all tables. +func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { + tables := make(map[string]bool) + query := d.ins.ShowTablesQuery() + rows, err := db.Query(query) + if err != nil { + return tables, err + } + + defer rows.Close() + + for rows.Next() { + var table string + err := rows.Scan(&table) + if err != nil { + return tables, err + } + if table != "" { + tables[table] = true + } + } + + return tables, nil +} + +// get all cloumns in table. +func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { + columns := make(map[string][3]string) + query := d.ins.ShowColumnsQuery(table) + rows, err := db.Query(query) + if err != nil { + return columns, err + } + + defer rows.Close() + + for rows.Next() { + var ( + name string + typ string + null string + ) + err := rows.Scan(&name, &typ, &null) + if err != nil { + return columns, err + } + columns[name] = [3]string{name, typ, null} + } + + return columns, nil +} + +// not implement. +func (d *dbBase) OperatorSQL(operator string) string { + panic(ErrNotImplement) +} + +// not implement. +func (d *dbBase) ShowTablesQuery() string { + panic(ErrNotImplement) +} + +// not implement. +func (d *dbBase) ShowColumnsQuery(table string) string { + panic(ErrNotImplement) +} + +// not implement. +func (d *dbBase) IndexExists(dbQuerier, string, string) bool { + panic(ErrNotImplement) +} diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go new file mode 100644 index 00000000..bf6c350c --- /dev/null +++ b/pkg/orm/db_alias.go @@ -0,0 +1,466 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "fmt" + lru "github.com/hashicorp/golang-lru" + "reflect" + "sync" + "time" +) + +// DriverType database driver constant int. +type DriverType int + +// Enum the Database driver +const ( + _ DriverType = iota // int enum type + DRMySQL // mysql + DRSqlite // sqlite + DROracle // oracle + DRPostgres // pgsql + DRTiDB // TiDB +) + +// database driver string. +type driver string + +// get type constant int of current driver.. +func (d driver) Type() DriverType { + a, _ := dataBaseCache.get(string(d)) + return a.Driver +} + +// get name of current driver +func (d driver) Name() string { + return string(d) +} + +// check driver iis implemented Driver interface or not. +var _ Driver = new(driver) + +var ( + dataBaseCache = &_dbCache{cache: make(map[string]*alias)} + drivers = map[string]DriverType{ + "mysql": DRMySQL, + "postgres": DRPostgres, + "sqlite3": DRSqlite, + "tidb": DRTiDB, + "oracle": DROracle, + "oci8": DROracle, // github.com/mattn/go-oci8 + "ora": DROracle, //https://github.com/rana/ora + } + dbBasers = map[DriverType]dbBaser{ + DRMySQL: newdbBaseMysql(), + DRSqlite: newdbBaseSqlite(), + DROracle: newdbBaseOracle(), + DRPostgres: newdbBasePostgres(), + DRTiDB: newdbBaseTidb(), + } +) + +// database alias cacher. +type _dbCache struct { + mux sync.RWMutex + cache map[string]*alias +} + +// add database alias with original name. +func (ac *_dbCache) add(name string, al *alias) (added bool) { + ac.mux.Lock() + defer ac.mux.Unlock() + if _, ok := ac.cache[name]; !ok { + ac.cache[name] = al + added = true + } + return +} + +// get database alias if cached. +func (ac *_dbCache) get(name string) (al *alias, ok bool) { + ac.mux.RLock() + defer ac.mux.RUnlock() + al, ok = ac.cache[name] + return +} + +// get default alias. +func (ac *_dbCache) getDefault() (al *alias) { + al, _ = ac.get("default") + return +} + +type DB struct { + *sync.RWMutex + DB *sql.DB + stmtDecorators *lru.Cache +} + +func (d *DB) Begin() (*sql.Tx, error) { + return d.DB.Begin() +} + +func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return d.DB.BeginTx(ctx, opts) +} + +//su must call release to release *sql.Stmt after using +func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { + d.RLock() + c, ok := d.stmtDecorators.Get(query) + if ok { + c.(*stmtDecorator).acquire() + d.RUnlock() + return c.(*stmtDecorator), nil + } + d.RUnlock() + + d.Lock() + c, ok = d.stmtDecorators.Get(query) + if ok { + c.(*stmtDecorator).acquire() + d.Unlock() + return c.(*stmtDecorator), nil + } + + stmt, err := d.Prepare(query) + if err != nil { + d.Unlock() + return nil, err + } + sd := newStmtDecorator(stmt) + sd.acquire() + d.stmtDecorators.Add(query, sd) + d.Unlock() + + return sd, nil +} + +func (d *DB) Prepare(query string) (*sql.Stmt, error) { + return d.DB.Prepare(query) +} + +func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return d.DB.PrepareContext(ctx, query) +} + +func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + sd, err := d.getStmtDecorator(query) + if err != nil { + return nil, err + } + stmt := sd.getStmt() + defer sd.release() + return stmt.Exec(args...) +} + +func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + sd, err := d.getStmtDecorator(query) + if err != nil { + return nil, err + } + stmt := sd.getStmt() + defer sd.release() + return stmt.ExecContext(ctx, args...) +} + +func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { + sd, err := d.getStmtDecorator(query) + if err != nil { + return nil, err + } + stmt := sd.getStmt() + defer sd.release() + return stmt.Query(args...) +} + +func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + sd, err := d.getStmtDecorator(query) + if err != nil { + return nil, err + } + stmt := sd.getStmt() + defer sd.release() + return stmt.QueryContext(ctx, args...) +} + +func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { + sd, err := d.getStmtDecorator(query) + if err != nil { + panic(err) + } + stmt := sd.getStmt() + defer sd.release() + return stmt.QueryRow(args...) + +} + +func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + sd, err := d.getStmtDecorator(query) + if err != nil { + panic(err) + } + stmt := sd.getStmt() + defer sd.release() + return stmt.QueryRowContext(ctx, args) +} + +type alias struct { + Name string + Driver DriverType + DriverName string + DataSource string + MaxIdleConns int + MaxOpenConns int + DB *DB + DbBaser dbBaser + TZ *time.Location + Engine string +} + +func detectTZ(al *alias) { + // orm timezone system match database + // default use Local + al.TZ = DefaultTimeLoc + + if al.DriverName == "sphinx" { + return + } + + switch al.Driver { + case DRMySQL: + row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") + var tz string + row.Scan(&tz) + if len(tz) >= 8 { + if tz[0] != '-' { + tz = "+" + tz + } + t, err := time.Parse("-07:00:00", tz) + if err == nil { + if t.Location().String() != "" { + al.TZ = t.Location() + } + } else { + DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) + } + } + + // get default engine from current database + row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'") + var engine string + var tx bool + row.Scan(&engine, &tx) + + if engine != "" { + al.Engine = engine + } else { + al.Engine = "INNODB" + } + + case DRSqlite, DROracle: + al.TZ = time.UTC + + case DRPostgres: + row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") + var tz string + row.Scan(&tz) + loc, err := time.LoadLocation(tz) + if err == nil { + al.TZ = loc + } else { + DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) + } + } +} + +func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { + al := new(alias) + al.Name = aliasName + al.DriverName = driverName + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: newStmtDecoratorLruWithEvict(), + } + + if dr, ok := drivers[driverName]; ok { + al.DbBaser = dbBasers[dr] + al.Driver = dr + } else { + return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + } + + err := db.Ping() + if err != nil { + return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) + } + + if !dataBaseCache.add(aliasName, al) { + return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) + } + + return al, nil +} + +// AddAliasWthDB add a aliasName for the drivename +func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { + _, err := addAliasWthDB(aliasName, driverName, db) + return err +} + +// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. +func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { + var ( + err error + db *sql.DB + al *alias + ) + + db, err = sql.Open(driverName, dataSource) + if err != nil { + err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) + goto end + } + + al, err = addAliasWthDB(aliasName, driverName, db) + if err != nil { + goto end + } + + al.DataSource = dataSource + + detectTZ(al) + + for i, v := range params { + switch i { + case 0: + SetMaxIdleConns(al.Name, v) + case 1: + SetMaxOpenConns(al.Name, v) + } + } + +end: + if err != nil { + if db != nil { + db.Close() + } + DebugLog.Println(err.Error()) + } + + return err +} + +// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. +func RegisterDriver(driverName string, typ DriverType) error { + if t, ok := drivers[driverName]; !ok { + drivers[driverName] = typ + } else { + if t != typ { + return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName) + } + } + return nil +} + +// SetDataBaseTZ Change the database default used timezone +func SetDataBaseTZ(aliasName string, tz *time.Location) error { + if al, ok := dataBaseCache.get(aliasName); ok { + al.TZ = tz + } else { + return fmt.Errorf("DataBase alias name `%s` not registered", aliasName) + } + return nil +} + +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +func SetMaxIdleConns(aliasName string, maxIdleConns int) { + al := getDbAlias(aliasName) + al.MaxIdleConns = maxIdleConns + al.DB.DB.SetMaxIdleConns(maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +func SetMaxOpenConns(aliasName string, maxOpenConns int) { + al := getDbAlias(aliasName) + al.MaxOpenConns = maxOpenConns + al.DB.DB.SetMaxOpenConns(maxOpenConns) + // for tip go 1.2 + if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() { + fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) + } +} + +// GetDB Get *sql.DB from registered database by db alias name. +// Use "default" as alias name if you not set. +func GetDB(aliasNames ...string) (*sql.DB, error) { + var name string + if len(aliasNames) > 0 { + name = aliasNames[0] + } else { + name = "default" + } + al, ok := dataBaseCache.get(name) + if ok { + return al.DB.DB, nil + } + return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) +} + +type stmtDecorator struct { + wg sync.WaitGroup + stmt *sql.Stmt +} + +func (s *stmtDecorator) getStmt() *sql.Stmt { + return s.stmt +} + +// acquire will add one +// since this method will be used inside read lock scope, +// so we can not do more things here +// we should think about refactor this +func (s *stmtDecorator) acquire() { + s.wg.Add(1) +} + +func (s *stmtDecorator) release() { + s.wg.Done() +} + +//garbage recycle for stmt +func (s *stmtDecorator) destroy() { + go func() { + s.wg.Wait() + _ = s.stmt.Close() + }() +} + +func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { + return &stmtDecorator{ + stmt: sqlStmt, + } +} + +func newStmtDecoratorLruWithEvict() *lru.Cache { + cache, _ := lru.NewWithEvict(1000, func(key interface{}, value interface{}) { + value.(*stmtDecorator).destroy() + }) + return cache +} diff --git a/pkg/orm/db_mysql.go b/pkg/orm/db_mysql.go new file mode 100644 index 00000000..6e99058e --- /dev/null +++ b/pkg/orm/db_mysql.go @@ -0,0 +1,183 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "reflect" + "strings" +) + +// mysql operators. +var mysqlOperators = map[string]string{ + "exact": "= ?", + "iexact": "LIKE ?", + "contains": "LIKE BINARY ?", + "icontains": "LIKE ?", + // "regex": "REGEXP BINARY ?", + // "iregex": "REGEXP ?", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "eq": "= ?", + "ne": "!= ?", + "startswith": "LIKE BINARY ?", + "endswith": "LIKE BINARY ?", + "istartswith": "LIKE ?", + "iendswith": "LIKE ?", +} + +// mysql column field types. +var mysqlTypes = map[string]string{ + "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "char(%d)", + "string-text": "longtext", + "time.Time-date": "date", + "time.Time": "datetime", + "int8": "tinyint", + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": "tinyint unsigned", + "uint16": "smallint unsigned", + "uint32": "integer unsigned", + "uint64": "bigint unsigned", + "float64": "double precision", + "float64-decimal": "numeric(%d, %d)", +} + +// mysql dbBaser implementation. +type dbBaseMysql struct { + dbBase +} + +var _ dbBaser = new(dbBaseMysql) + +// get mysql operator. +func (d *dbBaseMysql) OperatorSQL(operator string) string { + return mysqlOperators[operator] +} + +// get mysql table field types. +func (d *dbBaseMysql) DbTypes() map[string]string { + return mysqlTypes +} + +// show table sql for mysql. +func (d *dbBaseMysql) ShowTablesQuery() string { + return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" +} + +// show columns sql of table for mysql. +func (d *dbBaseMysql) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ + "WHERE table_schema = DATABASE() AND table_name = '%s'", table) +} + +// execute sql to check index exist. +func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { + row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ + "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + +// InsertOrUpdate a row +// If your primary key or unique column conflict will update +// If no will insert +// Add "`" for mysql sql building +func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { + var iouStr string + argsMap := map[string]string{} + + iouStr = "ON DUPLICATE KEY UPDATE" + + //Get on the key-value pairs + for _, v := range args { + kv := strings.Split(v, "=") + if len(kv) == 2 { + argsMap[strings.ToLower(kv[0])] = kv[1] + } + } + + isMulti := false + names := make([]string, 0, len(mi.fields.dbcols)-1) + Q := d.ins.TableQuote() + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) + + if err != nil { + return 0, err + } + + marks := make([]string, len(names)) + updateValues := make([]interface{}, 0) + updates := make([]string, len(names)) + + for i, v := range names { + marks[i] = "?" + valueStr := argsMap[strings.ToLower(v)] + if valueStr != "" { + updates[i] = "`" + v + "`" + "=" + valueStr + } else { + updates[i] = "`" + v + "`" + "=?" + updateValues = append(updateValues, values[i]) + } + } + + values = append(values, updateValues...) + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + qupdates := strings.Join(updates, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + //conflitValue maybe is a int,can`t use fmt.Sprintf + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + + row := q.QueryRow(query, values...) + var id int64 + err = row.Scan(&id) + return id, err +} + +// create new mysql dbBaser. +func newdbBaseMysql() dbBaser { + b := new(dbBaseMysql) + b.ins = b + return b +} diff --git a/pkg/orm/db_oracle.go b/pkg/orm/db_oracle.go new file mode 100644 index 00000000..5d121f83 --- /dev/null +++ b/pkg/orm/db_oracle.go @@ -0,0 +1,137 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strings" +) + +// oracle operators. +var oracleOperators = map[string]string{ + "exact": "= ?", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "//iendswith": "LIKE ?", +} + +// oracle column field types. +var oracleTypes = map[string]string{ + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "VARCHAR2(%d)", + "string-char": "CHAR(%d)", + "string-text": "VARCHAR2(%d)", + "time.Time-date": "DATE", + "time.Time": "TIMESTAMP", + "int8": "INTEGER", + "int16": "INTEGER", + "int32": "INTEGER", + "int64": "INTEGER", + "uint8": "INTEGER", + "uint16": "INTEGER", + "uint32": "INTEGER", + "uint64": "INTEGER", + "float64": "NUMBER", + "float64-decimal": "NUMBER(%d, %d)", +} + +// oracle dbBaser +type dbBaseOracle struct { + dbBase +} + +var _ dbBaser = new(dbBaseOracle) + +// create oracle dbBaser. +func newdbBaseOracle() dbBaser { + b := new(dbBaseOracle) + b.ins = b + return b +} + +// OperatorSQL get oracle operator. +func (d *dbBaseOracle) OperatorSQL(operator string) string { + return oracleOperators[operator] +} + +// DbTypes get oracle table field types. +func (d *dbBaseOracle) DbTypes() map[string]string { + return oracleTypes +} + +//ShowTablesQuery show all the tables in database +func (d *dbBaseOracle) ShowTablesQuery() string { + return "SELECT TABLE_NAME FROM USER_TABLES" +} + +// Oracle +func (d *dbBaseOracle) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT COLUMN_NAME FROM ALL_TAB_COLUMNS "+ + "WHERE TABLE_NAME ='%s'", strings.ToUpper(table)) +} + +// check index is exist +func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { + row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ + "WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ + "AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) + + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + +// execute insert sql with given struct and given values. +// insert the given values, not the field values in struct. +func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { + Q := d.ins.TableQuote() + + marks := make([]string, len(names)) + for i := range marks { + marks[i] = ":" + names[i] + } + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + row := q.QueryRow(query, values...) + var id int64 + err := row.Scan(&id) + return id, err +} diff --git a/pkg/orm/db_postgres.go b/pkg/orm/db_postgres.go new file mode 100644 index 00000000..c488fb38 --- /dev/null +++ b/pkg/orm/db_postgres.go @@ -0,0 +1,189 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strconv" +) + +// postgresql operators. +var postgresOperators = map[string]string{ + "exact": "= ?", + "iexact": "= UPPER(?)", + "contains": "LIKE ?", + "icontains": "LIKE UPPER(?)", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "eq": "= ?", + "ne": "!= ?", + "startswith": "LIKE ?", + "endswith": "LIKE ?", + "istartswith": "LIKE UPPER(?)", + "iendswith": "LIKE UPPER(?)", +} + +// postgresql column field types. +var postgresTypes = map[string]string{ + "auto": "serial NOT NULL PRIMARY KEY", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "char(%d)", + "string-text": "text", + "time.Time-date": "date", + "time.Time": "timestamp with time zone", + "int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`, + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`, + "uint16": `integer CHECK("%COL%" >= 0)`, + "uint32": `bigint CHECK("%COL%" >= 0)`, + "uint64": `bigint CHECK("%COL%" >= 0)`, + "float64": "double precision", + "float64-decimal": "numeric(%d, %d)", + "json": "json", + "jsonb": "jsonb", +} + +// postgresql dbBaser. +type dbBasePostgres struct { + dbBase +} + +var _ dbBaser = new(dbBasePostgres) + +// get postgresql operator. +func (d *dbBasePostgres) OperatorSQL(operator string) string { + return postgresOperators[operator] +} + +// generate functioned sql string, such as contains(text). +func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { + switch operator { + case "contains", "startswith", "endswith": + *leftCol = fmt.Sprintf("%s::text", *leftCol) + case "iexact", "icontains", "istartswith", "iendswith": + *leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol) + } +} + +// postgresql unsupports updating joined record. +func (d *dbBasePostgres) SupportUpdateJoin() bool { + return false +} + +func (d *dbBasePostgres) MaxLimit() uint64 { + return 0 +} + +// postgresql quote is ". +func (d *dbBasePostgres) TableQuote() string { + return `"` +} + +// postgresql value placeholder is $n. +// replace default ? to $n. +func (d *dbBasePostgres) ReplaceMarks(query *string) { + q := *query + num := 0 + for _, c := range q { + if c == '?' { + num++ + } + } + if num == 0 { + return + } + data := make([]byte, 0, len(q)+num) + num = 1 + for i := 0; i < len(q); i++ { + c := q[i] + if c == '?' { + data = append(data, '$') + data = append(data, []byte(strconv.Itoa(num))...) + num++ + } else { + data = append(data, c) + } + } + *query = string(data) +} + +// make returning sql support for postgresql. +func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { + fi := mi.fields.pk + if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 { + return false + } + + if query != nil { + *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column) + } + return true +} + +// sync auto key +func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { + if len(autoFields) == 0 { + return nil + } + + Q := d.ins.TableQuote() + for _, name := range autoFields { + query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));", + mi.table, name, + Q, name, Q, + Q, mi.table, Q) + if _, err := db.Exec(query); err != nil { + return err + } + } + return nil +} + +// show table sql for postgresql. +func (d *dbBasePostgres) ShowTablesQuery() string { + return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" +} + +// show table columns sql for postgresql. +func (d *dbBasePostgres) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) +} + +// get column types of postgresql. +func (d *dbBasePostgres) DbTypes() map[string]string { + return postgresTypes +} + +// check index exist in postgresql. +func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { + query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) + row := db.QueryRow(query) + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + +// create new postgresql dbBaser. +func newdbBasePostgres() dbBaser { + b := new(dbBasePostgres) + b.ins = b + return b +} diff --git a/pkg/orm/db_sqlite.go b/pkg/orm/db_sqlite.go new file mode 100644 index 00000000..1d62ee34 --- /dev/null +++ b/pkg/orm/db_sqlite.go @@ -0,0 +1,161 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "fmt" + "reflect" + "time" +) + +// sqlite operators. +var sqliteOperators = map[string]string{ + "exact": "= ?", + "iexact": "LIKE ? ESCAPE '\\'", + "contains": "LIKE ? ESCAPE '\\'", + "icontains": "LIKE ? ESCAPE '\\'", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "eq": "= ?", + "ne": "!= ?", + "startswith": "LIKE ? ESCAPE '\\'", + "endswith": "LIKE ? ESCAPE '\\'", + "istartswith": "LIKE ? ESCAPE '\\'", + "iendswith": "LIKE ? ESCAPE '\\'", +} + +// sqlite column types. +var sqliteTypes = map[string]string{ + "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "character(%d)", + "string-text": "text", + "time.Time-date": "date", + "time.Time": "datetime", + "int8": "tinyint", + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": "tinyint unsigned", + "uint16": "smallint unsigned", + "uint32": "integer unsigned", + "uint64": "bigint unsigned", + "float64": "real", + "float64-decimal": "decimal", +} + +// sqlite dbBaser. +type dbBaseSqlite struct { + dbBase +} + +var _ dbBaser = new(dbBaseSqlite) + +// override base db read for update behavior as SQlite does not support syntax +func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { + if isForUpdate { + DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") + } + return d.dbBase.Read(q, mi, ind, tz, cols, false) +} + +// get sqlite operator. +func (d *dbBaseSqlite) OperatorSQL(operator string) string { + return sqliteOperators[operator] +} + +// generate functioned sql for sqlite. +// only support DATE(text). +func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { + if fi.fieldType == TypeDateField { + *leftCol = fmt.Sprintf("DATE(%s)", *leftCol) + } +} + +// unable updating joined record in sqlite. +func (d *dbBaseSqlite) SupportUpdateJoin() bool { + return false +} + +// max int in sqlite. +func (d *dbBaseSqlite) MaxLimit() uint64 { + return 9223372036854775807 +} + +// get column types in sqlite. +func (d *dbBaseSqlite) DbTypes() map[string]string { + return sqliteTypes +} + +// get show tables sql in sqlite. +func (d *dbBaseSqlite) ShowTablesQuery() string { + return "SELECT name FROM sqlite_master WHERE type = 'table'" +} + +// get columns in sqlite. +func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { + query := d.ins.ShowColumnsQuery(table) + rows, err := db.Query(query) + if err != nil { + return nil, err + } + + columns := make(map[string][3]string) + for rows.Next() { + var tmp, name, typ, null sql.NullString + err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp) + if err != nil { + return nil, err + } + columns[name.String] = [3]string{name.String, typ.String, null.String} + } + + return columns, nil +} + +// get show columns sql in sqlite. +func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { + return fmt.Sprintf("pragma table_info('%s')", table) +} + +// check index exist in sqlite. +func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { + query := fmt.Sprintf("PRAGMA index_list('%s')", table) + rows, err := db.Query(query) + if err != nil { + panic(err) + } + defer rows.Close() + for rows.Next() { + var tmp, index sql.NullString + rows.Scan(&tmp, &index, &tmp, &tmp, &tmp) + if name == index.String { + return true + } + } + return false +} + +// create new sqlite dbBaser. +func newdbBaseSqlite() dbBaser { + b := new(dbBaseSqlite) + b.ins = b + return b +} diff --git a/pkg/orm/db_tables.go b/pkg/orm/db_tables.go new file mode 100644 index 00000000..4b21a6fc --- /dev/null +++ b/pkg/orm/db_tables.go @@ -0,0 +1,482 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strings" + "time" +) + +// table info struct. +type dbTable struct { + id int + index string + name string + names []string + sel bool + inner bool + mi *modelInfo + fi *fieldInfo + jtl *dbTable +} + +// tables collection struct, contains some tables. +type dbTables struct { + tablesM map[string]*dbTable + tables []*dbTable + mi *modelInfo + base dbBaser + skipEnd bool +} + +// set table info to collection. +// if not exist, create new. +func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { + name := strings.Join(names, ExprSep) + if j, ok := t.tablesM[name]; ok { + j.name = name + j.mi = mi + j.fi = fi + j.inner = inner + } else { + i := len(t.tables) + 1 + jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} + t.tablesM[name] = jt + t.tables = append(t.tables, jt) + } + return t.tablesM[name] +} + +// add table info to collection. +func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { + name := strings.Join(names, ExprSep) + if _, ok := t.tablesM[name]; !ok { + i := len(t.tables) + 1 + jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} + t.tablesM[name] = jt + t.tables = append(t.tables, jt) + return jt, true + } + return t.tablesM[name], false +} + +// get table info in collection. +func (t *dbTables) get(name string) (*dbTable, bool) { + j, ok := t.tablesM[name] + return j, ok +} + +// get related fields info in recursive depth loop. +// loop once, depth decreases one. +func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { + if depth < 0 || fi.fieldType == RelManyToMany { + return related + } + + if prefix == "" { + prefix = fi.name + } else { + prefix = prefix + ExprSep + fi.name + } + related = append(related, prefix) + + depth-- + for _, fi := range fi.relModelInfo.fields.fieldsRel { + related = t.loopDepth(depth, prefix, fi, related) + } + + return related +} + +// parse related fields. +func (t *dbTables) parseRelated(rels []string, depth int) { + + relsNum := len(rels) + related := make([]string, relsNum) + copy(related, rels) + + relDepth := depth + + if relsNum != 0 { + relDepth = 0 + } + + relDepth-- + for _, fi := range t.mi.fields.fieldsRel { + related = t.loopDepth(relDepth, "", fi, related) + } + + for i, s := range related { + var ( + exs = strings.Split(s, ExprSep) + names = make([]string, 0, len(exs)) + mmi = t.mi + cancel = true + jtl *dbTable + ) + + inner := true + + for _, ex := range exs { + if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { + names = append(names, fi.name) + mmi = fi.relModelInfo + + if fi.null || t.skipEnd { + inner = false + } + + jt := t.set(names, mmi, fi, inner) + jt.jtl = jtl + + if fi.reverse { + cancel = false + } + + if cancel { + jt.sel = depth > 0 + + if i < relsNum { + jt.sel = true + } + } + + jtl = jt + + } else { + panic(fmt.Errorf("unknown model/table name `%s`", ex)) + } + } + } +} + +// generate join string. +func (t *dbTables) getJoinSQL() (join string) { + Q := t.base.TableQuote() + + for _, jt := range t.tables { + if jt.inner { + join += "INNER JOIN " + } else { + join += "LEFT OUTER JOIN " + } + var ( + table string + t1, t2 string + c1, c2 string + ) + t1 = "T0" + if jt.jtl != nil { + t1 = jt.jtl.index + } + t2 = jt.index + table = jt.mi.table + + switch { + case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: + c1 = jt.fi.mi.fields.pk.column + for _, ffi := range jt.mi.fields.fieldsRel { + if jt.fi.mi == ffi.relModelInfo { + c2 = ffi.column + break + } + } + default: + c1 = jt.fi.column + c2 = jt.fi.relModelInfo.fields.pk.column + + if jt.fi.reverse { + c1 = jt.mi.fields.pk.column + c2 = jt.fi.reverseFieldInfo.column + } + } + + join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2, + t2, Q, c2, Q, t1, Q, c1, Q) + } + return +} + +// parse orm model struct field tag expression. +func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { + var ( + jtl *dbTable + fi *fieldInfo + fiN *fieldInfo + mmi = mi + ) + + num := len(exprs) - 1 + var names []string + + inner := true + +loopFor: + for i, ex := range exprs { + + var ok, okN bool + + if fiN != nil { + fi = fiN + ok = true + fiN = nil + } + + if i == 0 { + fi, ok = mmi.fields.GetByAny(ex) + } + + _ = okN + + if ok { + + isRel := fi.rel || fi.reverse + + names = append(names, fi.name) + + switch { + case fi.rel: + mmi = fi.relModelInfo + if fi.fieldType == RelManyToMany { + mmi = fi.relThroughModelInfo + } + case fi.reverse: + mmi = fi.reverseFieldInfo.mi + } + + if i < num { + fiN, okN = mmi.fields.GetByAny(exprs[i+1]) + } + + if isRel && (!fi.mi.isThrough || num != i) { + if fi.null || t.skipEnd { + inner = false + } + + if t.skipEnd && okN || !t.skipEnd { + if t.skipEnd && okN && fiN.pk { + goto loopEnd + } + + jt, _ := t.add(names, mmi, fi, inner) + jt.jtl = jtl + jtl = jt + } + + } + + if num != i { + continue + } + + loopEnd: + + if i == 0 || jtl == nil { + index = "T0" + } else { + index = jtl.index + } + + info = fi + + if jtl == nil { + name = fi.name + } else { + name = jtl.name + ExprSep + fi.name + } + + switch { + case fi.rel: + + case fi.reverse: + switch fi.reverseFieldInfo.fieldType { + case RelOneToOne, RelForeignKey: + index = jtl.index + info = fi.reverseFieldInfo.mi.fields.pk + name = info.name + } + } + + break loopFor + + } else { + index = "" + name = "" + info = nil + success = false + return + } + } + + success = index != "" && info != nil + return +} + +// generate condition sql. +func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { + if cond == nil || cond.IsEmpty() { + return + } + + Q := t.base.TableQuote() + + mi := t.mi + + for i, p := range cond.params { + if i > 0 { + if p.isOr { + where += "OR " + } else { + where += "AND " + } + } + if p.isNot { + where += "NOT " + } + if p.isCond { + w, ps := t.getCondSQL(p.cond, true, tz) + if w != "" { + w = fmt.Sprintf("( %s) ", w) + } + where += w + params = append(params, ps...) + } else { + exprs := p.exprs + + num := len(exprs) - 1 + operator := "" + if operators[exprs[num]] { + operator = exprs[num] + exprs = exprs[:num] + } + + index, _, fi, suc := t.parseExprs(mi, exprs) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) + } + + if operator == "" { + operator = "exact" + } + + var operSQL string + var args []interface{} + if p.isRaw { + operSQL = p.sql + } else { + operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) + } + + leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) + t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) + + where += fmt.Sprintf("%s %s ", leftCol, operSQL) + params = append(params, args...) + + } + } + + if !sub && where != "" { + where = "WHERE " + where + } + + return +} + +// generate group sql. +func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { + if len(groups) == 0 { + return + } + + Q := t.base.TableQuote() + + groupSqls := make([]string, 0, len(groups)) + for _, group := range groups { + exprs := strings.Split(group, ExprSep) + + index, _, fi, suc := t.parseExprs(t.mi, exprs) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) + } + + groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)) + } + + groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", ")) + return +} + +// generate order sql. +func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { + if len(orders) == 0 { + return + } + + Q := t.base.TableQuote() + + orderSqls := make([]string, 0, len(orders)) + for _, order := range orders { + asc := "ASC" + if order[0] == '-' { + asc = "DESC" + order = order[1:] + } + exprs := strings.Split(order, ExprSep) + + index, _, fi, suc := t.parseExprs(t.mi, exprs) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) + } + + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) + } + + orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) + return +} + +// generate limit sql. +func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) { + if limit == 0 { + limit = int64(DefaultRowsLimit) + } + if limit < 0 { + // no limit + if offset > 0 { + maxLimit := t.base.MaxLimit() + if maxLimit == 0 { + limits = fmt.Sprintf("OFFSET %d", offset) + } else { + limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset) + } + } + } else if offset <= 0 { + limits = fmt.Sprintf("LIMIT %d", limit) + } else { + limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) + } + return +} + +// crete new tables collection. +func newDbTables(mi *modelInfo, base dbBaser) *dbTables { + tables := &dbTables{} + tables.tablesM = make(map[string]*dbTable) + tables.mi = mi + tables.base = base + return tables +} diff --git a/pkg/orm/db_tidb.go b/pkg/orm/db_tidb.go new file mode 100644 index 00000000..6020a488 --- /dev/null +++ b/pkg/orm/db_tidb.go @@ -0,0 +1,63 @@ +// Copyright 2015 TiDB Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" +) + +// mysql dbBaser implementation. +type dbBaseTidb struct { + dbBase +} + +var _ dbBaser = new(dbBaseTidb) + +// get mysql operator. +func (d *dbBaseTidb) OperatorSQL(operator string) string { + return mysqlOperators[operator] +} + +// get mysql table field types. +func (d *dbBaseTidb) DbTypes() map[string]string { + return mysqlTypes +} + +// show table sql for mysql. +func (d *dbBaseTidb) ShowTablesQuery() string { + return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" +} + +// show columns sql of table for mysql. +func (d *dbBaseTidb) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ + "WHERE table_schema = DATABASE() AND table_name = '%s'", table) +} + +// execute sql to check index exist. +func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { + row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ + "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + +// create new mysql dbBaser. +func newdbBaseTidb() dbBaser { + b := new(dbBaseTidb) + b.ins = b + return b +} diff --git a/pkg/orm/db_utils.go b/pkg/orm/db_utils.go new file mode 100644 index 00000000..7ae10ca5 --- /dev/null +++ b/pkg/orm/db_utils.go @@ -0,0 +1,177 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "reflect" + "time" +) + +// get table alias. +func getDbAlias(name string) *alias { + if al, ok := dataBaseCache.get(name); ok { + return al + } + panic(fmt.Errorf("unknown DataBase alias name %s", name)) +} + +// get pk column info. +func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { + fi := mi.fields.pk + + v := ind.FieldByIndex(fi.fieldIndex) + if fi.fieldType&IsPositiveIntegerField > 0 { + vu := v.Uint() + exist = vu > 0 + value = vu + } else if fi.fieldType&IsIntegerField > 0 { + vu := v.Int() + exist = true + value = vu + } else if fi.fieldType&IsRelField > 0 { + _, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v)) + } else { + vu := v.String() + exist = vu != "" + value = vu + } + + column = fi.column + return +} + +// get fields description as flatted string. +func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { + +outFor: + for _, arg := range args { + val := reflect.ValueOf(arg) + + if arg == nil { + params = append(params, arg) + continue + } + + kind := val.Kind() + if kind == reflect.Ptr { + val = val.Elem() + kind = val.Kind() + arg = val.Interface() + } + + switch kind { + case reflect.String: + v := val.String() + if fi != nil { + if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { + var t time.Time + var err error + if len(v) >= 19 { + s := v[:19] + t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc) + } else if len(v) >= 10 { + s := v + if len(v) > 10 { + s = v[:10] + } + t, err = time.ParseInLocation(formatDate, s, tz) + } else { + s := v + if len(s) > 8 { + s = v[:8] + } + t, err = time.ParseInLocation(formatTime, s, tz) + } + if err == nil { + if fi.fieldType == TypeDateField { + v = t.In(tz).Format(formatDate) + } else if fi.fieldType == TypeDateTimeField { + v = t.In(tz).Format(formatDateTime) + } else { + v = t.In(tz).Format(formatTime) + } + } + } + } + arg = v + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + arg = val.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + arg = val.Uint() + case reflect.Float32: + arg, _ = StrTo(ToStr(arg)).Float64() + case reflect.Float64: + arg = val.Float() + case reflect.Bool: + arg = val.Bool() + case reflect.Slice, reflect.Array: + if _, ok := arg.([]byte); ok { + continue outFor + } + + var args []interface{} + for i := 0; i < val.Len(); i++ { + v := val.Index(i) + + var vu interface{} + if v.CanInterface() { + vu = v.Interface() + } + + if vu == nil { + continue + } + + args = append(args, vu) + } + + if len(args) > 0 { + p := getFlatParams(fi, args, tz) + params = append(params, p...) + } + continue outFor + case reflect.Struct: + if v, ok := arg.(time.Time); ok { + if fi != nil && fi.fieldType == TypeDateField { + arg = v.In(tz).Format(formatDate) + } else if fi != nil && fi.fieldType == TypeDateTimeField { + arg = v.In(tz).Format(formatDateTime) + } else if fi != nil && fi.fieldType == TypeTimeField { + arg = v.In(tz).Format(formatTime) + } else { + arg = v.In(tz).Format(formatDateTime) + } + } else { + typ := val.Type() + name := getFullName(typ) + var value interface{} + if mmi, ok := modelCache.getByFullName(name); ok { + if _, vu, exist := getExistPk(mmi, val); exist { + value = vu + } + } + arg = value + + if arg == nil { + panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name)) + } + } + } + + params = append(params, arg) + } + return +} diff --git a/pkg/orm/models.go b/pkg/orm/models.go new file mode 100644 index 00000000..4776bcba --- /dev/null +++ b/pkg/orm/models.go @@ -0,0 +1,99 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "sync" +) + +const ( + odCascade = "cascade" + odSetNULL = "set_null" + odSetDefault = "set_default" + odDoNothing = "do_nothing" + defaultStructTagName = "orm" + defaultStructTagDelim = ";" +) + +var ( + modelCache = &_modelCache{ + cache: make(map[string]*modelInfo), + cacheByFullName: make(map[string]*modelInfo), + } +) + +// model info collection +type _modelCache struct { + sync.RWMutex // only used outsite for bootStrap + orders []string + cache map[string]*modelInfo + cacheByFullName map[string]*modelInfo + done bool +} + +// get all model info +func (mc *_modelCache) all() map[string]*modelInfo { + m := make(map[string]*modelInfo, len(mc.cache)) + for k, v := range mc.cache { + m[k] = v + } + return m +} + +// get ordered model info +func (mc *_modelCache) allOrdered() []*modelInfo { + m := make([]*modelInfo, 0, len(mc.orders)) + for _, table := range mc.orders { + m = append(m, mc.cache[table]) + } + return m +} + +// get model info by table name +func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { + mi, ok = mc.cache[table] + return +} + +// get model info by full name +func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) { + mi, ok = mc.cacheByFullName[name] + return +} + +// set model info to collection +func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { + mii := mc.cache[table] + mc.cache[table] = mi + mc.cacheByFullName[mi.fullName] = mi + if mii == nil { + mc.orders = append(mc.orders, table) + } + return mii +} + +// clean all model info. +func (mc *_modelCache) clean() { + mc.orders = make([]string, 0) + mc.cache = make(map[string]*modelInfo) + mc.cacheByFullName = make(map[string]*modelInfo) + mc.done = false +} + +// ResetModelCache Clean model cache. Then you can re-RegisterModel. +// Common use this api for test case. +func ResetModelCache() { + modelCache.clean() +} diff --git a/pkg/orm/models_boot.go b/pkg/orm/models_boot.go new file mode 100644 index 00000000..8c56b3c4 --- /dev/null +++ b/pkg/orm/models_boot.go @@ -0,0 +1,347 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "os" + "reflect" + "runtime/debug" + "strings" +) + +// register models. +// PrefixOrSuffix means table name prefix or suffix. +// isPrefix whether the prefix is prefix or suffix +func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) { + val := reflect.ValueOf(model) + typ := reflect.Indirect(val).Type() + + if val.Kind() != reflect.Ptr { + panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) + } + // For this case: + // u := &User{} + // registerModel(&u) + if typ.Kind() == reflect.Ptr { + panic(fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)) + } + + table := getTableName(val) + + if PrefixOrSuffix != "" { + if isPrefix { + table = PrefixOrSuffix + table + } else { + table = table + PrefixOrSuffix + } + } + // models's fullname is pkgpath + struct name + name := getFullName(typ) + if _, ok := modelCache.getByFullName(name); ok { + fmt.Printf(" model `%s` repeat register, must be unique\n", name) + os.Exit(2) + } + + if _, ok := modelCache.get(table); ok { + fmt.Printf(" table name `%s` repeat register, must be unique\n", table) + os.Exit(2) + } + + mi := newModelInfo(val) + if mi.fields.pk == nil { + outFor: + for _, fi := range mi.fields.fieldsDB { + if strings.ToLower(fi.name) == "id" { + switch fi.addrValue.Elem().Kind() { + case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + fi.auto = true + fi.pk = true + mi.fields.pk = fi + break outFor + } + } + } + + if mi.fields.pk == nil { + fmt.Printf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) + os.Exit(2) + } + + } + + mi.table = table + mi.pkg = typ.PkgPath() + mi.model = model + mi.manual = true + + modelCache.set(table, mi) +} + +// bootstrap models +func bootStrap() { + if modelCache.done { + return + } + var ( + err error + models map[string]*modelInfo + ) + if dataBaseCache.getDefault() == nil { + err = fmt.Errorf("must have one register DataBase alias named `default`") + goto end + } + + // set rel and reverse model + // RelManyToMany set the relTable + models = modelCache.all() + for _, mi := range models { + for _, fi := range mi.fields.columns { + if fi.rel || fi.reverse { + elm := fi.addrValue.Type().Elem() + if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { + elm = elm.Elem() + } + // check the rel or reverse model already register + name := getFullName(elm) + mii, ok := modelCache.getByFullName(name) + if !ok || mii.pkg != elm.PkgPath() { + err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) + goto end + } + fi.relModelInfo = mii + + switch fi.fieldType { + case RelManyToMany: + if fi.relThrough != "" { + if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { + pn := fi.relThrough[:i] + rmi, ok := modelCache.getByFullName(fi.relThrough) + if !ok || pn != rmi.pkg { + err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) + goto end + } + fi.relThroughModelInfo = rmi + fi.relTable = rmi.table + } else { + err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) + goto end + } + } else { + i := newM2MModelInfo(mi, mii) + if fi.relTable != "" { + i.table = fi.relTable + } + if v := modelCache.set(i.table, i); v != nil { + err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) + goto end + } + fi.relTable = i.table + fi.relThroughModelInfo = i + } + + fi.relThroughModelInfo.isThrough = true + } + } + } + } + + // check the rel filed while the relModelInfo also has filed point to current model + // if not exist, add a new field to the relModelInfo + models = modelCache.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsRel { + switch fi.fieldType { + case RelForeignKey, RelOneToOne, RelManyToMany: + inModel := false + for _, ffi := range fi.relModelInfo.fields.fieldsReverse { + if ffi.relModelInfo == mi { + inModel = true + break + } + } + if !inModel { + rmi := fi.relModelInfo + ffi := new(fieldInfo) + ffi.name = mi.name + ffi.column = ffi.name + ffi.fullName = rmi.fullName + "." + ffi.name + ffi.reverse = true + ffi.relModelInfo = mi + ffi.mi = rmi + if fi.fieldType == RelOneToOne { + ffi.fieldType = RelReverseOne + } else { + ffi.fieldType = RelReverseMany + } + if !rmi.fields.Add(ffi) { + added := false + for cnt := 0; cnt < 5; cnt++ { + ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) + ffi.column = ffi.name + ffi.fullName = rmi.fullName + "." + ffi.name + if added = rmi.fields.Add(ffi); added { + break + } + } + if !added { + panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) + } + } + } + } + } + } + + models = modelCache.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsRel { + switch fi.fieldType { + case RelManyToMany: + for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { + switch ffi.fieldType { + case RelOneToOne, RelForeignKey: + if ffi.relModelInfo == fi.relModelInfo { + fi.reverseFieldInfoTwo = ffi + } + if ffi.relModelInfo == mi { + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + } + } + } + if fi.reverseFieldInfoTwo == nil { + err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", + fi.relThroughModelInfo.fullName) + goto end + } + } + } + } + + models = modelCache.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsReverse { + switch fi.fieldType { + case RelReverseOne: + found := false + mForA: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + + ffi.reverseField = fi.name + ffi.reverseFieldInfo = fi + break mForA + } + } + if !found { + err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + goto end + } + case RelReverseMany: + found := false + mForB: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + + ffi.reverseField = fi.name + ffi.reverseFieldInfo = fi + + break mForB + } + } + if !found { + mForC: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { + conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || + fi.relTable != "" && fi.relTable == ffi.relTable || + fi.relThrough == "" && fi.relTable == "" + if ffi.relModelInfo == mi && conditions { + found = true + + fi.reverseField = ffi.reverseFieldInfoTwo.name + fi.reverseFieldInfo = ffi.reverseFieldInfoTwo + fi.relThroughModelInfo = ffi.relThroughModelInfo + fi.reverseFieldInfoTwo = ffi.reverseFieldInfo + fi.reverseFieldInfoM2M = ffi + ffi.reverseFieldInfoM2M = fi + + break mForC + } + } + } + if !found { + err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + goto end + } + } + } + } + +end: + if err != nil { + fmt.Println(err) + debug.PrintStack() + os.Exit(2) + } +} + +// RegisterModel register models +func RegisterModel(models ...interface{}) { + if modelCache.done { + panic(fmt.Errorf("RegisterModel must be run before BootStrap")) + } + RegisterModelWithPrefix("", models...) +} + +// RegisterModelWithPrefix register models with a prefix +func RegisterModelWithPrefix(prefix string, models ...interface{}) { + if modelCache.done { + panic(fmt.Errorf("RegisterModelWithPrefix must be run before BootStrap")) + } + + for _, model := range models { + registerModel(prefix, model, true) + } +} + +// RegisterModelWithSuffix register models with a suffix +func RegisterModelWithSuffix(suffix string, models ...interface{}) { + if modelCache.done { + panic(fmt.Errorf("RegisterModelWithSuffix must be run before BootStrap")) + } + + for _, model := range models { + registerModel(suffix, model, false) + } +} + +// BootStrap bootstrap models. +// make all model parsed and can not add more models +func BootStrap() { + modelCache.Lock() + defer modelCache.Unlock() + if modelCache.done { + return + } + bootStrap() + modelCache.done = true +} diff --git a/pkg/orm/models_fields.go b/pkg/orm/models_fields.go new file mode 100644 index 00000000..b4fad94f --- /dev/null +++ b/pkg/orm/models_fields.go @@ -0,0 +1,783 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strconv" + "time" +) + +// Define the Type enum +const ( + TypeBooleanField = 1 << iota + TypeVarCharField + TypeCharField + TypeTextField + TypeTimeField + TypeDateField + TypeDateTimeField + TypeBitField + TypeSmallIntegerField + TypeIntegerField + TypeBigIntegerField + TypePositiveBitField + TypePositiveSmallIntegerField + TypePositiveIntegerField + TypePositiveBigIntegerField + TypeFloatField + TypeDecimalField + TypeJSONField + TypeJsonbField + RelForeignKey + RelOneToOne + RelManyToMany + RelReverseOne + RelReverseMany +) + +// Define some logic enum +const ( + IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7 + IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11 + IsRelField = ^-RelReverseMany >> 18 << 19 + IsFieldType = ^-RelReverseMany<<1 + 1 +) + +// BooleanField A true/false field. +type BooleanField bool + +// Value return the BooleanField +func (e BooleanField) Value() bool { + return bool(e) +} + +// Set will set the BooleanField +func (e *BooleanField) Set(d bool) { + *e = BooleanField(d) +} + +// String format the Bool to string +func (e *BooleanField) String() string { + return strconv.FormatBool(e.Value()) +} + +// FieldType return BooleanField the type +func (e *BooleanField) FieldType() int { + return TypeBooleanField +} + +// SetRaw set the interface to bool +func (e *BooleanField) SetRaw(value interface{}) error { + switch d := value.(type) { + case bool: + e.Set(d) + case string: + v, err := StrTo(d).Bool() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the current value +func (e *BooleanField) RawValue() interface{} { + return e.Value() +} + +// verify the BooleanField implement the Fielder interface +var _ Fielder = new(BooleanField) + +// CharField A string field +// required values tag: size +// The size is enforced at the database level and in models’s validation. +// eg: `orm:"size(120)"` +type CharField string + +// Value return the CharField's Value +func (e CharField) Value() string { + return string(e) +} + +// Set CharField value +func (e *CharField) Set(d string) { + *e = CharField(d) +} + +// String return the CharField +func (e *CharField) String() string { + return e.Value() +} + +// FieldType return the enum type +func (e *CharField) FieldType() int { + return TypeVarCharField +} + +// SetRaw set the interface to string +func (e *CharField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + e.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the CharField value +func (e *CharField) RawValue() interface{} { + return e.Value() +} + +// verify CharField implement Fielder +var _ Fielder = new(CharField) + +// TimeField A time, represented in go by a time.Time instance. +// only time values like 10:00:00 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type TimeField time.Time + +// Value return the time.Time +func (e TimeField) Value() time.Time { + return time.Time(e) +} + +// Set set the TimeField's value +func (e *TimeField) Set(d time.Time) { + *e = TimeField(d) +} + +// String convert time to string +func (e *TimeField) String() string { + return e.Value().String() +} + +// FieldType return enum type Date +func (e *TimeField) FieldType() int { + return TypeDateField +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *TimeField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := timeParse(d, formatTime) + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return time value +func (e *TimeField) RawValue() interface{} { + return e.Value() +} + +var _ Fielder = new(TimeField) + +// DateField A date, represented in go by a time.Time instance. +// only date values like 2006-01-02 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type DateField time.Time + +// Value return the time.Time +func (e DateField) Value() time.Time { + return time.Time(e) +} + +// Set set the DateField's value +func (e *DateField) Set(d time.Time) { + *e = DateField(d) +} + +// String convert datetime to string +func (e *DateField) String() string { + return e.Value().String() +} + +// FieldType return enum type Date +func (e *DateField) FieldType() int { + return TypeDateField +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *DateField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := timeParse(d, formatDate) + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return Date value +func (e *DateField) RawValue() interface{} { + return e.Value() +} + +// verify DateField implement fielder interface +var _ Fielder = new(DateField) + +// DateTimeField A date, represented in go by a time.Time instance. +// datetime values like 2006-01-02 15:04:05 +// Takes the same extra arguments as DateField. +type DateTimeField time.Time + +// Value return the datetime value +func (e DateTimeField) Value() time.Time { + return time.Time(e) +} + +// Set set the time.Time to datetime +func (e *DateTimeField) Set(d time.Time) { + *e = DateTimeField(d) +} + +// String return the time's String +func (e *DateTimeField) String() string { + return e.Value().String() +} + +// FieldType return the enum TypeDateTimeField +func (e *DateTimeField) FieldType() int { + return TypeDateTimeField +} + +// SetRaw convert the string or time.Time to DateTimeField +func (e *DateTimeField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := timeParse(d, formatDateTime) + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the datetime value +func (e *DateTimeField) RawValue() interface{} { + return e.Value() +} + +// verify datetime implement fielder +var _ Fielder = new(DateTimeField) + +// FloatField A floating-point number represented in go by a float32 value. +type FloatField float64 + +// Value return the FloatField value +func (e FloatField) Value() float64 { + return float64(e) +} + +// Set the Float64 +func (e *FloatField) Set(d float64) { + *e = FloatField(d) +} + +// String return the string +func (e *FloatField) String() string { + return ToStr(e.Value(), -1, 32) +} + +// FieldType return the enum type +func (e *FloatField) FieldType() int { + return TypeFloatField +} + +// SetRaw converter interface Float64 float32 or string to FloatField +func (e *FloatField) SetRaw(value interface{}) error { + switch d := value.(type) { + case float32: + e.Set(float64(d)) + case float64: + e.Set(d) + case string: + v, err := StrTo(d).Float64() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the FloatField value +func (e *FloatField) RawValue() interface{} { + return e.Value() +} + +// verify FloatField implement Fielder +var _ Fielder = new(FloatField) + +// SmallIntegerField -32768 to 32767 +type SmallIntegerField int16 + +// Value return int16 value +func (e SmallIntegerField) Value() int16 { + return int16(e) +} + +// Set the SmallIntegerField value +func (e *SmallIntegerField) Set(d int16) { + *e = SmallIntegerField(d) +} + +// String convert smallint to string +func (e *SmallIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type SmallIntegerField +func (e *SmallIntegerField) FieldType() int { + return TypeSmallIntegerField +} + +// SetRaw convert interface int16/string to int16 +func (e *SmallIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int16: + e.Set(d) + case string: + v, err := StrTo(d).Int16() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return smallint value +func (e *SmallIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify SmallIntegerField implement Fielder +var _ Fielder = new(SmallIntegerField) + +// IntegerField -2147483648 to 2147483647 +type IntegerField int32 + +// Value return the int32 +func (e IntegerField) Value() int32 { + return int32(e) +} + +// Set IntegerField value +func (e *IntegerField) Set(d int32) { + *e = IntegerField(d) +} + +// String convert Int32 to string +func (e *IntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return the enum type +func (e *IntegerField) FieldType() int { + return TypeIntegerField +} + +// SetRaw convert interface int32/string to int32 +func (e *IntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int32: + e.Set(d) + case string: + v, err := StrTo(d).Int32() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return IntegerField value +func (e *IntegerField) RawValue() interface{} { + return e.Value() +} + +// verify IntegerField implement Fielder +var _ Fielder = new(IntegerField) + +// BigIntegerField -9223372036854775808 to 9223372036854775807. +type BigIntegerField int64 + +// Value return int64 +func (e BigIntegerField) Value() int64 { + return int64(e) +} + +// Set the BigIntegerField value +func (e *BigIntegerField) Set(d int64) { + *e = BigIntegerField(d) +} + +// String convert BigIntegerField to string +func (e *BigIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type +func (e *BigIntegerField) FieldType() int { + return TypeBigIntegerField +} + +// SetRaw convert interface int64/string to int64 +func (e *BigIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int64: + e.Set(d) + case string: + v, err := StrTo(d).Int64() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return BigIntegerField value +func (e *BigIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify BigIntegerField implement Fielder +var _ Fielder = new(BigIntegerField) + +// PositiveSmallIntegerField 0 to 65535 +type PositiveSmallIntegerField uint16 + +// Value return uint16 +func (e PositiveSmallIntegerField) Value() uint16 { + return uint16(e) +} + +// Set PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) Set(d uint16) { + *e = PositiveSmallIntegerField(d) +} + +// String convert uint16 to string +func (e *PositiveSmallIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type +func (e *PositiveSmallIntegerField) FieldType() int { + return TypePositiveSmallIntegerField +} + +// SetRaw convert Interface uint16/string to uint16 +func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint16: + e.Set(d) + case string: + v, err := StrTo(d).Uint16() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue returns PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify PositiveSmallIntegerField implement Fielder +var _ Fielder = new(PositiveSmallIntegerField) + +// PositiveIntegerField 0 to 4294967295 +type PositiveIntegerField uint32 + +// Value return PositiveIntegerField value. Uint32 +func (e PositiveIntegerField) Value() uint32 { + return uint32(e) +} + +// Set the PositiveIntegerField value +func (e *PositiveIntegerField) Set(d uint32) { + *e = PositiveIntegerField(d) +} + +// String convert PositiveIntegerField to string +func (e *PositiveIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type +func (e *PositiveIntegerField) FieldType() int { + return TypePositiveIntegerField +} + +// SetRaw convert interface uint32/string to Uint32 +func (e *PositiveIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint32: + e.Set(d) + case string: + v, err := StrTo(d).Uint32() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the PositiveIntegerField Value +func (e *PositiveIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify PositiveIntegerField implement Fielder +var _ Fielder = new(PositiveIntegerField) + +// PositiveBigIntegerField 0 to 18446744073709551615 +type PositiveBigIntegerField uint64 + +// Value return uint64 +func (e PositiveBigIntegerField) Value() uint64 { + return uint64(e) +} + +// Set PositiveBigIntegerField value +func (e *PositiveBigIntegerField) Set(d uint64) { + *e = PositiveBigIntegerField(d) +} + +// String convert PositiveBigIntegerField to string +func (e *PositiveBigIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type +func (e *PositiveBigIntegerField) FieldType() int { + return TypePositiveIntegerField +} + +// SetRaw convert interface uint64/string to Uint64 +func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint64: + e.Set(d) + case string: + v, err := StrTo(d).Uint64() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return PositiveBigIntegerField value +func (e *PositiveBigIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify PositiveBigIntegerField implement Fielder +var _ Fielder = new(PositiveBigIntegerField) + +// TextField A large text field. +type TextField string + +// Value return TextField value +func (e TextField) Value() string { + return string(e) +} + +// Set the TextField value +func (e *TextField) Set(d string) { + *e = TextField(d) +} + +// String convert TextField to string +func (e *TextField) String() string { + return e.Value() +} + +// FieldType return enum type +func (e *TextField) FieldType() int { + return TypeTextField +} + +// SetRaw convert interface string to string +func (e *TextField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + e.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return TextField value +func (e *TextField) RawValue() interface{} { + return e.Value() +} + +// verify TextField implement Fielder +var _ Fielder = new(TextField) + +// JSONField postgres json field. +type JSONField string + +// Value return JSONField value +func (j JSONField) Value() string { + return string(j) +} + +// Set the JSONField value +func (j *JSONField) Set(d string) { + *j = JSONField(d) +} + +// String convert JSONField to string +func (j *JSONField) String() string { + return j.Value() +} + +// FieldType return enum type +func (j *JSONField) FieldType() int { + return TypeJSONField +} + +// SetRaw convert interface string to string +func (j *JSONField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + j.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return JSONField value +func (j *JSONField) RawValue() interface{} { + return j.Value() +} + +// verify JSONField implement Fielder +var _ Fielder = new(JSONField) + +// JsonbField postgres json field. +type JsonbField string + +// Value return JsonbField value +func (j JsonbField) Value() string { + return string(j) +} + +// Set the JsonbField value +func (j *JsonbField) Set(d string) { + *j = JsonbField(d) +} + +// String convert JsonbField to string +func (j *JsonbField) String() string { + return j.Value() +} + +// FieldType return enum type +func (j *JsonbField) FieldType() int { + return TypeJsonbField +} + +// SetRaw convert interface string to string +func (j *JsonbField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + j.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return JsonbField value +func (j *JsonbField) RawValue() interface{} { + return j.Value() +} + +// verify JsonbField implement Fielder +var _ Fielder = new(JsonbField) diff --git a/pkg/orm/models_info_f.go b/pkg/orm/models_info_f.go new file mode 100644 index 00000000..7044b0bd --- /dev/null +++ b/pkg/orm/models_info_f.go @@ -0,0 +1,473 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +var errSkipField = errors.New("skip field") + +// field info collection +type fields struct { + pk *fieldInfo + columns map[string]*fieldInfo + fields map[string]*fieldInfo + fieldsLow map[string]*fieldInfo + fieldsByType map[int][]*fieldInfo + fieldsRel []*fieldInfo + fieldsReverse []*fieldInfo + fieldsDB []*fieldInfo + rels []*fieldInfo + orders []string + dbcols []string +} + +// add field info +func (f *fields) Add(fi *fieldInfo) (added bool) { + if f.fields[fi.name] == nil && f.columns[fi.column] == nil { + f.columns[fi.column] = fi + f.fields[fi.name] = fi + f.fieldsLow[strings.ToLower(fi.name)] = fi + } else { + return + } + if _, ok := f.fieldsByType[fi.fieldType]; !ok { + f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0) + } + f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi) + f.orders = append(f.orders, fi.column) + if fi.dbcol { + f.dbcols = append(f.dbcols, fi.column) + f.fieldsDB = append(f.fieldsDB, fi) + } + if fi.rel { + f.fieldsRel = append(f.fieldsRel, fi) + } + if fi.reverse { + f.fieldsReverse = append(f.fieldsReverse, fi) + } + return true +} + +// get field info by name +func (f *fields) GetByName(name string) *fieldInfo { + return f.fields[name] +} + +// get field info by column name +func (f *fields) GetByColumn(column string) *fieldInfo { + return f.columns[column] +} + +// get field info by string, name is prior +func (f *fields) GetByAny(name string) (*fieldInfo, bool) { + if fi, ok := f.fields[name]; ok { + return fi, ok + } + if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok { + return fi, ok + } + if fi, ok := f.columns[name]; ok { + return fi, ok + } + return nil, false +} + +// create new field info collection +func newFields() *fields { + f := new(fields) + f.fields = make(map[string]*fieldInfo) + f.fieldsLow = make(map[string]*fieldInfo) + f.columns = make(map[string]*fieldInfo) + f.fieldsByType = make(map[int][]*fieldInfo) + return f +} + +// single field info +type fieldInfo struct { + mi *modelInfo + fieldIndex []int + fieldType int + dbcol bool // table column fk and onetoone + inModel bool + name string + fullName string + column string + addrValue reflect.Value + sf reflect.StructField + auto bool + pk bool + null bool + index bool + unique bool + colDefault bool // whether has default tag + initial StrTo // store the default value + size int + toText bool + autoNow bool + autoNowAdd bool + rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true + reverse bool + reverseField string + reverseFieldInfo *fieldInfo + reverseFieldInfoTwo *fieldInfo + reverseFieldInfoM2M *fieldInfo + relTable string + relThrough string + relThroughModelInfo *modelInfo + relModelInfo *modelInfo + digits int + decimals int + isFielder bool // implement Fielder interface + onDelete string + description string +} + +// new field info +func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *fieldInfo, err error) { + var ( + tag string + tagValue string + initial StrTo // store the default value + fieldType int + attrs map[string]bool + tags map[string]string + addrField reflect.Value + ) + + fi = new(fieldInfo) + + // if field which CanAddr is the follow type + // A value is addressable if it is an element of a slice, + // an element of an addressable array, a field of an + // addressable struct, or the result of dereferencing a pointer. + addrField = field + if field.CanAddr() && field.Kind() != reflect.Ptr { + addrField = field.Addr() + if _, ok := addrField.Interface().(Fielder); !ok { + if field.Kind() == reflect.Slice { + addrField = field + } + } + } + + attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName)) + + if _, ok := attrs["-"]; ok { + return nil, errSkipField + } + + digits := tags["digits"] + decimals := tags["decimals"] + size := tags["size"] + onDelete := tags["on_delete"] + + initial.Clear() + if v, ok := tags["default"]; ok { + initial.Set(v) + } + +checkType: + switch f := addrField.Interface().(type) { + case Fielder: + fi.isFielder = true + if field.Kind() == reflect.Ptr { + err = fmt.Errorf("the model Fielder can not be use ptr") + goto end + } + fieldType = f.FieldType() + if fieldType&IsRelField > 0 { + err = fmt.Errorf("unsupport type custom field, please refer to https://github.com/astaxie/beego/blob/master/orm/models_fields.go#L24-L42") + goto end + } + default: + tag = "rel" + tagValue = tags[tag] + if tagValue != "" { + switch tagValue { + case "fk": + fieldType = RelForeignKey + break checkType + case "one": + fieldType = RelOneToOne + break checkType + case "m2m": + fieldType = RelManyToMany + if tv := tags["rel_table"]; tv != "" { + fi.relTable = tv + } else if tv := tags["rel_through"]; tv != "" { + fi.relThrough = tv + } + break checkType + default: + err = fmt.Errorf("rel only allow these value: fk, one, m2m") + goto wrongTag + } + } + tag = "reverse" + tagValue = tags[tag] + if tagValue != "" { + switch tagValue { + case "one": + fieldType = RelReverseOne + break checkType + case "many": + fieldType = RelReverseMany + if tv := tags["rel_table"]; tv != "" { + fi.relTable = tv + } else if tv := tags["rel_through"]; tv != "" { + fi.relThrough = tv + } + break checkType + default: + err = fmt.Errorf("reverse only allow these value: one, many") + goto wrongTag + } + } + + fieldType, err = getFieldType(addrField) + if err != nil { + goto end + } + if fieldType == TypeVarCharField { + switch tags["type"] { + case "char": + fieldType = TypeCharField + case "text": + fieldType = TypeTextField + case "json": + fieldType = TypeJSONField + case "jsonb": + fieldType = TypeJsonbField + } + } + if fieldType == TypeFloatField && (digits != "" || decimals != "") { + fieldType = TypeDecimalField + } + if fieldType == TypeDateTimeField && tags["type"] == "date" { + fieldType = TypeDateField + } + if fieldType == TypeTimeField && tags["type"] == "time" { + fieldType = TypeTimeField + } + } + + // check the rel and reverse type + // rel should Ptr + // reverse should slice []*struct + switch fieldType { + case RelForeignKey, RelOneToOne, RelReverseOne: + if field.Kind() != reflect.Ptr { + err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name()) + goto end + } + case RelManyToMany, RelReverseMany: + if field.Kind() != reflect.Slice { + err = fmt.Errorf("rel/reverse:many field must be slice") + goto end + } else { + if field.Type().Elem().Kind() != reflect.Ptr { + err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name()) + goto end + } + } + } + + if fieldType&IsFieldType == 0 { + err = fmt.Errorf("wrong field type") + goto end + } + + fi.fieldType = fieldType + fi.name = sf.Name + fi.column = getColumnName(fieldType, addrField, sf, tags["column"]) + fi.addrValue = addrField + fi.sf = sf + fi.fullName = mi.fullName + mName + "." + sf.Name + + fi.description = tags["description"] + fi.null = attrs["null"] + fi.index = attrs["index"] + fi.auto = attrs["auto"] + fi.pk = attrs["pk"] + fi.unique = attrs["unique"] + + // Mark object property if there is attribute "default" in the orm configuration + if _, ok := tags["default"]; ok { + fi.colDefault = true + } + + switch fieldType { + case RelManyToMany, RelReverseMany, RelReverseOne: + fi.null = false + fi.index = false + fi.auto = false + fi.pk = false + fi.unique = false + default: + fi.dbcol = true + } + + switch fieldType { + case RelForeignKey, RelOneToOne, RelManyToMany: + fi.rel = true + if fieldType == RelOneToOne { + fi.unique = true + } + case RelReverseMany, RelReverseOne: + fi.reverse = true + } + + if fi.rel && fi.dbcol { + switch onDelete { + case odCascade, odDoNothing: + case odSetDefault: + if !initial.Exist() { + err = errors.New("on_delete: set_default need set field a default value") + goto end + } + case odSetNULL: + if !fi.null { + err = errors.New("on_delete: set_null need set field null") + goto end + } + default: + if onDelete == "" { + onDelete = odCascade + } else { + err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete) + goto end + } + } + + fi.onDelete = onDelete + } + + switch fieldType { + case TypeBooleanField: + case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField: + if size != "" { + v, e := StrTo(size).Int32() + if e != nil { + err = fmt.Errorf("wrong size value `%s`", size) + } else { + fi.size = int(v) + } + } else { + fi.size = 255 + fi.toText = true + } + case TypeTextField: + fi.index = false + fi.unique = false + case TypeTimeField, TypeDateField, TypeDateTimeField: + if attrs["auto_now"] { + fi.autoNow = true + } else if attrs["auto_now_add"] { + fi.autoNowAdd = true + } + case TypeFloatField: + case TypeDecimalField: + d1 := digits + d2 := decimals + v1, er1 := StrTo(d1).Int8() + v2, er2 := StrTo(d2).Int8() + if er1 != nil || er2 != nil { + err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1) + goto end + } + fi.digits = int(v1) + fi.decimals = int(v2) + default: + switch { + case fieldType&IsIntegerField > 0: + case fieldType&IsRelField > 0: + } + } + + if fieldType&IsIntegerField == 0 { + if fi.auto { + err = fmt.Errorf("non-integer type cannot set auto") + goto end + } + } + + if fi.auto || fi.pk { + if fi.auto { + switch addrField.Elem().Kind() { + case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + default: + err = fmt.Errorf("auto primary key only support int, int32, int64, uint, uint32, uint64 but found `%s`", addrField.Elem().Kind()) + goto end + } + fi.pk = true + } + fi.null = false + fi.index = false + fi.unique = false + } + + if fi.unique { + fi.index = false + } + + // can not set default for these type + if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField { + initial.Clear() + } + + if initial.Exist() { + v := initial + switch fieldType { + case TypeBooleanField: + _, err = v.Bool() + case TypeFloatField, TypeDecimalField: + _, err = v.Float64() + case TypeBitField: + _, err = v.Int8() + case TypeSmallIntegerField: + _, err = v.Int16() + case TypeIntegerField: + _, err = v.Int32() + case TypeBigIntegerField: + _, err = v.Int64() + case TypePositiveBitField: + _, err = v.Uint8() + case TypePositiveSmallIntegerField: + _, err = v.Uint16() + case TypePositiveIntegerField: + _, err = v.Uint32() + case TypePositiveBigIntegerField: + _, err = v.Uint64() + } + if err != nil { + tag, tagValue = "default", tags["default"] + goto wrongTag + } + } + + fi.initial = initial +end: + if err != nil { + return nil, err + } + return +wrongTag: + return nil, fmt.Errorf("wrong tag format: `%s:\"%s\"`, %s", tag, tagValue, err) +} diff --git a/pkg/orm/models_info_m.go b/pkg/orm/models_info_m.go new file mode 100644 index 00000000..a4d733b6 --- /dev/null +++ b/pkg/orm/models_info_m.go @@ -0,0 +1,148 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "os" + "reflect" +) + +// single model info +type modelInfo struct { + pkg string + name string + fullName string + table string + model interface{} + fields *fields + manual bool + addrField reflect.Value //store the original struct value + uniques []string + isThrough bool +} + +// new model info +func newModelInfo(val reflect.Value) (mi *modelInfo) { + mi = &modelInfo{} + mi.fields = newFields() + ind := reflect.Indirect(val) + mi.addrField = val + mi.name = ind.Type().Name() + mi.fullName = getFullName(ind.Type()) + addModelFields(mi, ind, "", []int{}) + return +} + +// index: FieldByIndex returns the nested field corresponding to index +func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) { + var ( + err error + fi *fieldInfo + sf reflect.StructField + ) + + for i := 0; i < ind.NumField(); i++ { + field := ind.Field(i) + sf = ind.Type().Field(i) + // if the field is unexported skip + if sf.PkgPath != "" { + continue + } + // add anonymous struct fields + if sf.Anonymous { + addModelFields(mi, field, mName+"."+sf.Name, append(index, i)) + continue + } + + fi, err = newFieldInfo(mi, field, sf, mName) + if err == errSkipField { + err = nil + continue + } else if err != nil { + break + } + //record current field index + fi.fieldIndex = append(fi.fieldIndex, index...) + fi.fieldIndex = append(fi.fieldIndex, i) + fi.mi = mi + fi.inModel = true + if !mi.fields.Add(fi) { + err = fmt.Errorf("duplicate column name: %s", fi.column) + break + } + if fi.pk { + if mi.fields.pk != nil { + err = fmt.Errorf("one model must have one pk field only") + break + } else { + mi.fields.pk = fi + } + } + } + + if err != nil { + fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err)) + os.Exit(2) + } +} + +// combine related model info to new model info. +// prepare for relation models query. +func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) { + mi = new(modelInfo) + mi.fields = newFields() + mi.table = m1.table + "_" + m2.table + "s" + mi.name = camelString(mi.table) + mi.fullName = m1.pkg + "." + mi.name + + fa := new(fieldInfo) // pk + f1 := new(fieldInfo) // m1 table RelForeignKey + f2 := new(fieldInfo) // m2 table RelForeignKey + fa.fieldType = TypeBigIntegerField + fa.auto = true + fa.pk = true + fa.dbcol = true + fa.name = "Id" + fa.column = "id" + fa.fullName = mi.fullName + "." + fa.name + + f1.dbcol = true + f2.dbcol = true + f1.fieldType = RelForeignKey + f2.fieldType = RelForeignKey + f1.name = camelString(m1.table) + f2.name = camelString(m2.table) + f1.fullName = mi.fullName + "." + f1.name + f2.fullName = mi.fullName + "." + f2.name + f1.column = m1.table + "_id" + f2.column = m2.table + "_id" + f1.rel = true + f2.rel = true + f1.relTable = m1.table + f2.relTable = m2.table + f1.relModelInfo = m1 + f2.relModelInfo = m2 + f1.mi = mi + f2.mi = mi + + mi.fields.Add(fa) + mi.fields.Add(f1) + mi.fields.Add(f2) + mi.fields.pk = fa + + mi.uniques = []string{f1.column, f2.column} + return +} diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go new file mode 100644 index 00000000..e3a635f2 --- /dev/null +++ b/pkg/orm/models_test.go @@ -0,0 +1,497 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "encoding/json" + "fmt" + "os" + "strings" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + // As tidb can't use go get, so disable the tidb testing now + // _ "github.com/pingcap/tidb" +) + +// A slice string field. +type SliceStringField []string + +func (e SliceStringField) Value() []string { + return []string(e) +} + +func (e *SliceStringField) Set(d []string) { + *e = SliceStringField(d) +} + +func (e *SliceStringField) Add(v string) { + *e = append(*e, v) +} + +func (e *SliceStringField) String() string { + return strings.Join(e.Value(), ",") +} + +func (e *SliceStringField) FieldType() int { + return TypeVarCharField +} + +func (e *SliceStringField) SetRaw(value interface{}) error { + switch d := value.(type) { + case []string: + e.Set(d) + case string: + if len(d) > 0 { + parts := strings.Split(d, ",") + v := make([]string, 0, len(parts)) + for _, p := range parts { + v = append(v, strings.TrimSpace(p)) + } + e.Set(v) + } + default: + return fmt.Errorf(" unknown value `%v`", value) + } + return nil +} + +func (e *SliceStringField) RawValue() interface{} { + return e.String() +} + +var _ Fielder = new(SliceStringField) + +// A json field. +type JSONFieldTest struct { + Name string + Data string +} + +func (e *JSONFieldTest) String() string { + data, _ := json.Marshal(e) + return string(data) +} + +func (e *JSONFieldTest) FieldType() int { + return TypeTextField +} + +func (e *JSONFieldTest) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + return json.Unmarshal([]byte(d), e) + default: + return fmt.Errorf(" unknown value `%v`", value) + } +} + +func (e *JSONFieldTest) RawValue() interface{} { + return e.String() +} + +var _ Fielder = new(JSONFieldTest) + +type Data struct { + ID int `orm:"column(id)"` + Boolean bool + Char string `orm:"size(50)"` + Text string `orm:"type(text)"` + JSON string `orm:"type(json);default({\"name\":\"json\"})"` + Jsonb string `orm:"type(jsonb)"` + Time time.Time `orm:"type(time)"` + Date time.Time `orm:"type(date)"` + DateTime time.Time `orm:"column(datetime)"` + Byte byte + Rune rune + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint uint + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Float32 float32 + Float64 float64 + Decimal float64 `orm:"digits(8);decimals(4)"` +} + +type DataNull struct { + ID int `orm:"column(id)"` + Boolean bool `orm:"null"` + Char string `orm:"null;size(50)"` + Text string `orm:"null;type(text)"` + JSON string `orm:"type(json);null"` + Jsonb string `orm:"type(jsonb);null"` + Time time.Time `orm:"null;type(time)"` + Date time.Time `orm:"null;type(date)"` + DateTime time.Time `orm:"null;column(datetime)"` + Byte byte `orm:"null"` + Rune rune `orm:"null"` + Int int `orm:"null"` + Int8 int8 `orm:"null"` + Int16 int16 `orm:"null"` + Int32 int32 `orm:"null"` + Int64 int64 `orm:"null"` + Uint uint `orm:"null"` + Uint8 uint8 `orm:"null"` + Uint16 uint16 `orm:"null"` + Uint32 uint32 `orm:"null"` + Uint64 uint64 `orm:"null"` + Float32 float32 `orm:"null"` + Float64 float64 `orm:"null"` + Decimal float64 `orm:"digits(8);decimals(4);null"` + NullString sql.NullString `orm:"null"` + NullBool sql.NullBool `orm:"null"` + NullFloat64 sql.NullFloat64 `orm:"null"` + NullInt64 sql.NullInt64 `orm:"null"` + BooleanPtr *bool `orm:"null"` + CharPtr *string `orm:"null;size(50)"` + TextPtr *string `orm:"null;type(text)"` + BytePtr *byte `orm:"null"` + RunePtr *rune `orm:"null"` + IntPtr *int `orm:"null"` + Int8Ptr *int8 `orm:"null"` + Int16Ptr *int16 `orm:"null"` + Int32Ptr *int32 `orm:"null"` + Int64Ptr *int64 `orm:"null"` + UintPtr *uint `orm:"null"` + Uint8Ptr *uint8 `orm:"null"` + Uint16Ptr *uint16 `orm:"null"` + Uint32Ptr *uint32 `orm:"null"` + Uint64Ptr *uint64 `orm:"null"` + Float32Ptr *float32 `orm:"null"` + Float64Ptr *float64 `orm:"null"` + DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` + TimePtr *time.Time `orm:"null;type(time)"` + DatePtr *time.Time `orm:"null;type(date)"` + DateTimePtr *time.Time `orm:"null"` +} + +type String string +type Boolean bool +type Byte byte +type Rune rune +type Int int +type Int8 int8 +type Int16 int16 +type Int32 int32 +type Int64 int64 +type Uint uint +type Uint8 uint8 +type Uint16 uint16 +type Uint32 uint32 +type Uint64 uint64 +type Float32 float64 +type Float64 float64 + +type DataCustom struct { + ID int `orm:"column(id)"` + Boolean Boolean + Char string `orm:"size(50)"` + Text string `orm:"type(text)"` + Byte Byte + Rune Rune + Int Int + Int8 Int8 + Int16 Int16 + Int32 Int32 + Int64 Int64 + Uint Uint + Uint8 Uint8 + Uint16 Uint16 + Uint32 Uint32 + Uint64 Uint64 + Float32 Float32 + Float64 Float64 + Decimal Float64 `orm:"digits(8);decimals(4)"` +} + +// only for mysql +type UserBig struct { + ID uint64 `orm:"column(id)"` + Name string +} + +type User struct { + ID int `orm:"column(id)"` + UserName string `orm:"size(30);unique"` + Email string `orm:"size(100)"` + Password string `orm:"size(100)"` + Status int16 `orm:"column(Status)"` + IsStaff bool + IsActive bool `orm:"default(true)"` + Created time.Time `orm:"auto_now_add;type(date)"` + Updated time.Time `orm:"auto_now"` + Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` + Posts []*Post `orm:"reverse(many)" json:"-"` + ShouldSkip string `orm:"-"` + Nums int + Langs SliceStringField `orm:"size(100)"` + Extra JSONFieldTest `orm:"type(text)"` + unexport bool `orm:"-"` + unexportBool bool +} + +func (u *User) TableIndex() [][]string { + return [][]string{ + {"Id", "UserName"}, + {"Id", "Created"}, + } +} + +func (u *User) TableUnique() [][]string { + return [][]string{ + {"UserName", "Email"}, + } +} + +func NewUser() *User { + obj := new(User) + return obj +} + +type Profile struct { + ID int `orm:"column(id)"` + Age int16 + Money float64 + User *User `orm:"reverse(one)" json:"-"` + BestPost *Post `orm:"rel(one);null"` +} + +func (u *Profile) TableName() string { + return "user_profile" +} + +func NewProfile() *Profile { + obj := new(Profile) + return obj +} + +type Post struct { + ID int `orm:"column(id)"` + User *User `orm:"rel(fk)"` + Title string `orm:"size(60)"` + Content string `orm:"type(text)"` + Created time.Time `orm:"auto_now_add"` + Updated time.Time `orm:"auto_now"` + Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.PostTags)"` +} + +func (u *Post) TableIndex() [][]string { + return [][]string{ + {"Id", "Created"}, + } +} + +func NewPost() *Post { + obj := new(Post) + return obj +} + +type Tag struct { + ID int `orm:"column(id)"` + Name string `orm:"size(30)"` + BestPost *Post `orm:"rel(one);null"` + Posts []*Post `orm:"reverse(many)" json:"-"` +} + +func NewTag() *Tag { + obj := new(Tag) + return obj +} + +type PostTags struct { + ID int `orm:"column(id)"` + Post *Post `orm:"rel(fk)"` + Tag *Tag `orm:"rel(fk)"` +} + +func (m *PostTags) TableName() string { + return "prefix_post_tags" +} + +type Comment struct { + ID int `orm:"column(id)"` + Post *Post `orm:"rel(fk);column(post)"` + Content string `orm:"type(text)"` + Parent *Comment `orm:"null;rel(fk)"` + Created time.Time `orm:"auto_now_add"` +} + +func NewComment() *Comment { + obj := new(Comment) + return obj +} + +type Group struct { + ID int `orm:"column(gid);size(32)"` + Name string + Permissions []*Permission `orm:"reverse(many)" json:"-"` +} + +type Permission struct { + ID int `orm:"column(id)"` + Name string + Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"` +} + +type GroupPermissions struct { + ID int `orm:"column(id)"` + Group *Group `orm:"rel(fk)"` + Permission *Permission `orm:"rel(fk)"` +} + +type ModelID struct { + ID int64 +} + +type ModelBase struct { + ModelID + + Created time.Time `orm:"auto_now_add;type(datetime)"` + Updated time.Time `orm:"auto_now;type(datetime)"` +} + +type InLine struct { + // Common Fields + ModelBase + + // Other Fields + Name string `orm:"unique"` + Email string +} + +func NewInLine() *InLine { + return new(InLine) +} + +type InLineOneToOne struct { + // Common Fields + ModelBase + + Note string + InLine *InLine `orm:"rel(fk);column(inline)"` +} + +func NewInLineOneToOne() *InLineOneToOne { + return new(InLineOneToOne) +} + +type IntegerPk struct { + ID int64 `orm:"pk"` + Value string +} + +type UintPk struct { + ID uint32 `orm:"pk"` + Name string +} + +type PtrPk struct { + ID *IntegerPk `orm:"pk;rel(one)"` + Positive bool +} + +var DBARGS = struct { + Driver string + Source string + Debug string +}{ + os.Getenv("ORM_DRIVER"), + os.Getenv("ORM_SOURCE"), + os.Getenv("ORM_DEBUG"), +} + +var ( + IsMysql = DBARGS.Driver == "mysql" + IsSqlite = DBARGS.Driver == "sqlite3" + IsPostgres = DBARGS.Driver == "postgres" + IsTidb = DBARGS.Driver == "tidb" +) + +var ( + dORM Ormer + dDbBaser dbBaser +) + +var ( + helpinfo = `need driver and source! + + Default DB Drivers. + + driver: url + mysql: https://github.com/go-sql-driver/mysql + sqlite3: https://github.com/mattn/go-sqlite3 + postgres: https://github.com/lib/pq + tidb: https://github.com/pingcap/tidb + + usage: + + go get -u github.com/astaxie/beego/orm + go get -u github.com/go-sql-driver/mysql + go get -u github.com/mattn/go-sqlite3 + go get -u github.com/lib/pq + go get -u github.com/pingcap/tidb + + #### MySQL + mysql -u root -e 'create database orm_test;' + export ORM_DRIVER=mysql + export ORM_SOURCE="root:@/orm_test?charset=utf8" + go test -v github.com/astaxie/beego/orm + + + #### Sqlite3 + export ORM_DRIVER=sqlite3 + export ORM_SOURCE='file:memory_test?mode=memory' + go test -v github.com/astaxie/beego/orm + + + #### PostgreSQL + psql -c 'create database orm_test;' -U postgres + export ORM_DRIVER=postgres + export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" + go test -v github.com/astaxie/beego/orm + + #### TiDB + export ORM_DRIVER=tidb + export ORM_SOURCE='memory://test/test' + go test -v github.com/astaxie/beego/orm + + ` +) + +func init() { + Debug, _ = StrTo(DBARGS.Debug).Bool() + + if DBARGS.Driver == "" || DBARGS.Source == "" { + fmt.Println(helpinfo) + os.Exit(2) + } + + RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) + + alias := getDbAlias("default") + if alias.Driver == DRMySQL { + alias.Engine = "INNODB" + } + +} diff --git a/pkg/orm/models_utils.go b/pkg/orm/models_utils.go new file mode 100644 index 00000000..71127a6b --- /dev/null +++ b/pkg/orm/models_utils.go @@ -0,0 +1,227 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "fmt" + "reflect" + "strings" + "time" +) + +// 1 is attr +// 2 is tag +var supportTag = map[string]int{ + "-": 1, + "null": 1, + "index": 1, + "unique": 1, + "pk": 1, + "auto": 1, + "auto_now": 1, + "auto_now_add": 1, + "size": 2, + "column": 2, + "default": 2, + "rel": 2, + "reverse": 2, + "rel_table": 2, + "rel_through": 2, + "digits": 2, + "decimals": 2, + "on_delete": 2, + "type": 2, + "description": 2, +} + +// get reflect.Type name with package path. +func getFullName(typ reflect.Type) string { + return typ.PkgPath() + "." + typ.Name() +} + +// getTableName get struct table name. +// If the struct implement the TableName, then get the result as tablename +// else use the struct name which will apply snakeString. +func getTableName(val reflect.Value) string { + if fun := val.MethodByName("TableName"); fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + // has return and the first val is string + if len(vals) > 0 && vals[0].Kind() == reflect.String { + return vals[0].String() + } + } + return snakeString(reflect.Indirect(val).Type().Name()) +} + +// get table engine, myisam or innodb. +func getTableEngine(val reflect.Value) string { + fun := val.MethodByName("TableEngine") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + if len(vals) > 0 && vals[0].Kind() == reflect.String { + return vals[0].String() + } + } + return "" +} + +// get table index from method. +func getTableIndex(val reflect.Value) [][]string { + fun := val.MethodByName("TableIndex") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + if len(vals) > 0 && vals[0].CanInterface() { + if d, ok := vals[0].Interface().([][]string); ok { + return d + } + } + } + return nil +} + +// get table unique from method +func getTableUnique(val reflect.Value) [][]string { + fun := val.MethodByName("TableUnique") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + if len(vals) > 0 && vals[0].CanInterface() { + if d, ok := vals[0].Interface().([][]string); ok { + return d + } + } + } + return nil +} + +// get snaked column name +func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { + column := col + if col == "" { + column = nameStrategyMap[nameStrategy](sf.Name) + } + switch ft { + case RelForeignKey, RelOneToOne: + if len(col) == 0 { + column = column + "_id" + } + case RelManyToMany, RelReverseMany, RelReverseOne: + column = sf.Name + } + return column +} + +// return field type as type constant from reflect.Value +func getFieldType(val reflect.Value) (ft int, err error) { + switch val.Type() { + case reflect.TypeOf(new(int8)): + ft = TypeBitField + case reflect.TypeOf(new(int16)): + ft = TypeSmallIntegerField + case reflect.TypeOf(new(int32)), + reflect.TypeOf(new(int)): + ft = TypeIntegerField + case reflect.TypeOf(new(int64)): + ft = TypeBigIntegerField + case reflect.TypeOf(new(uint8)): + ft = TypePositiveBitField + case reflect.TypeOf(new(uint16)): + ft = TypePositiveSmallIntegerField + case reflect.TypeOf(new(uint32)), + reflect.TypeOf(new(uint)): + ft = TypePositiveIntegerField + case reflect.TypeOf(new(uint64)): + ft = TypePositiveBigIntegerField + case reflect.TypeOf(new(float32)), + reflect.TypeOf(new(float64)): + ft = TypeFloatField + case reflect.TypeOf(new(bool)): + ft = TypeBooleanField + case reflect.TypeOf(new(string)): + ft = TypeVarCharField + case reflect.TypeOf(new(time.Time)): + ft = TypeDateTimeField + default: + elm := reflect.Indirect(val) + switch elm.Kind() { + case reflect.Int8: + ft = TypeBitField + case reflect.Int16: + ft = TypeSmallIntegerField + case reflect.Int32, reflect.Int: + ft = TypeIntegerField + case reflect.Int64: + ft = TypeBigIntegerField + case reflect.Uint8: + ft = TypePositiveBitField + case reflect.Uint16: + ft = TypePositiveSmallIntegerField + case reflect.Uint32, reflect.Uint: + ft = TypePositiveIntegerField + case reflect.Uint64: + ft = TypePositiveBigIntegerField + case reflect.Float32, reflect.Float64: + ft = TypeFloatField + case reflect.Bool: + ft = TypeBooleanField + case reflect.String: + ft = TypeVarCharField + default: + if elm.Interface() == nil { + panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val)) + } + switch elm.Interface().(type) { + case sql.NullInt64: + ft = TypeBigIntegerField + case sql.NullFloat64: + ft = TypeFloatField + case sql.NullBool: + ft = TypeBooleanField + case sql.NullString: + ft = TypeVarCharField + case time.Time: + ft = TypeDateTimeField + } + } + } + if ft&IsFieldType == 0 { + err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val) + } + return +} + +// parse struct tag string +func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) { + attrs = make(map[string]bool) + tags = make(map[string]string) + for _, v := range strings.Split(data, defaultStructTagDelim) { + if v == "" { + continue + } + v = strings.TrimSpace(v) + if t := strings.ToLower(v); supportTag[t] == 1 { + attrs[t] = true + } else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 { + name := t[:i] + if supportTag[name] == 2 { + v = v[i+1 : len(v)-1] + tags[name] = v + } + } else { + DebugLog.Println("unsupport orm tag", v) + } + } + return +} diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go new file mode 100644 index 00000000..0551b1cd --- /dev/null +++ b/pkg/orm/orm.go @@ -0,0 +1,579 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +// Package orm provide ORM for MySQL/PostgreSQL/sqlite +// Simple Usage +// +// package main +// +// import ( +// "fmt" +// "github.com/astaxie/beego/orm" +// _ "github.com/go-sql-driver/mysql" // import your used driver +// ) +// +// // Model Struct +// type User struct { +// Id int `orm:"auto"` +// Name string `orm:"size(100)"` +// } +// +// func init() { +// orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) +// } +// +// func main() { +// o := orm.NewOrm() +// user := User{Name: "slene"} +// // insert +// id, err := o.Insert(&user) +// // update +// user.Name = "astaxie" +// num, err := o.Update(&user) +// // read one +// u := User{Id: user.Id} +// err = o.Read(&u) +// // delete +// num, err = o.Delete(&u) +// } +// +// more docs: http://beego.me/docs/mvc/model/overview.md +package orm + +import ( + "context" + "database/sql" + "errors" + "fmt" + "os" + "reflect" + "sync" + "time" +) + +// DebugQueries define the debug +const ( + DebugQueries = iota +) + +// Define common vars +var ( + Debug = false + DebugLog = NewLog(os.Stdout) + DefaultRowsLimit = -1 + DefaultRelsDepth = 2 + DefaultTimeLoc = time.Local + ErrTxHasBegan = errors.New(" transaction already begin") + ErrTxDone = errors.New(" transaction not begin") + ErrMultiRows = errors.New(" return multi rows") + ErrNoRows = errors.New(" no row found") + ErrStmtClosed = errors.New(" stmt already closed") + ErrArgs = errors.New(" args error may be empty") + ErrNotImplement = errors.New("have not implement") +) + +// Params stores the Params +type Params map[string]interface{} + +// ParamsList stores paramslist +type ParamsList []interface{} + +type orm struct { + alias *alias + db dbQuerier + isTx bool +} + +var _ Ormer = new(orm) + +// get model info and model reflect value +func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { + val := reflect.ValueOf(md) + ind = reflect.Indirect(val) + typ := ind.Type() + if needPtr && val.Kind() != reflect.Ptr { + panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) + } + name := getFullName(typ) + if mi, ok := modelCache.getByFullName(name); ok { + return mi, ind + } + panic(fmt.Errorf(" table: `%s` not found, make sure it was registered with `RegisterModel()`", name)) +} + +// get field info from model info by given field name +func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { + fi, ok := mi.fields.GetByAny(name) + if !ok { + panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.fullName)) + } + return fi +} + +// read data to model +func (o *orm) Read(md interface{}, cols ...string) error { + mi, ind := o.getMiInd(md, true) + return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) +} + +// read data to model, like Read(), but use "SELECT FOR UPDATE" form +func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { + mi, ind := o.getMiInd(md, true) + return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) +} + +// Try to read a row from the database, or insert one if it doesn't exist +func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + cols = append([]string{col1}, cols...) + mi, ind := o.getMiInd(md, true) + err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) + if err == ErrNoRows { + // Create + id, err := o.Insert(md) + return (err == nil), id, err + } + + id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + id = int64(vid.Uint()) + } else if mi.fields.pk.rel { + return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) + } else { + id = vid.Int() + } + + return false, id, err +} + +// insert model data to database +func (o *orm) Insert(md interface{}) (int64, error) { + mi, ind := o.getMiInd(md, true) + id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + if err != nil { + return id, err + } + + o.setPk(mi, ind, id) + + return id, nil +} + +// set auto pk field +func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { + if mi.fields.pk.auto { + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) + } else { + ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id) + } + } +} + +// insert some models to database +func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { + var cnt int64 + + sind := reflect.Indirect(reflect.ValueOf(mds)) + + switch sind.Kind() { + case reflect.Array, reflect.Slice: + if sind.Len() == 0 { + return cnt, ErrArgs + } + default: + return cnt, ErrArgs + } + + if bulk <= 1 { + for i := 0; i < sind.Len(); i++ { + ind := reflect.Indirect(sind.Index(i)) + mi, _ := o.getMiInd(ind.Interface(), false) + id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + if err != nil { + return cnt, err + } + + o.setPk(mi, ind, id) + + cnt++ + } + } else { + mi, _ := o.getMiInd(sind.Index(0).Interface(), false) + return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) + } + return cnt, nil +} + +// InsertOrUpdate data to database +func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { + mi, ind := o.getMiInd(md, true) + id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) + if err != nil { + return id, err + } + + o.setPk(mi, ind, id) + + return id, nil +} + +// update model to database. +// cols set the columns those want to update. +func (o *orm) Update(md interface{}, cols ...string) (int64, error) { + mi, ind := o.getMiInd(md, true) + return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) +} + +// delete model in database +// cols shows the delete conditions values read from. default is pk +func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { + mi, ind := o.getMiInd(md, true) + num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) + if err != nil { + return num, err + } + if num > 0 { + o.setPk(mi, ind, 0) + } + return num, nil +} + +// create a models to models queryer +func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { + mi, ind := o.getMiInd(md, true) + fi := o.getFieldInfo(mi, name) + + switch { + case fi.fieldType == RelManyToMany: + case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough: + default: + panic(fmt.Errorf(" model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName)) + } + + return newQueryM2M(md, o, mi, fi, ind) +} + +// load related models to md model. +// args are limit, offset int and order string. +// +// example: +// orm.LoadRelated(post,"Tags") +// for _,tag := range post.Tags{...} +// +// make sure the relation is defined in model struct tags. +func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { + _, fi, ind, qseter := o.queryRelated(md, name) + + qs := qseter.(*querySet) + + var relDepth int + var limit, offset int64 + var order string + for i, arg := range args { + switch i { + case 0: + if v, ok := arg.(bool); ok { + if v { + relDepth = DefaultRelsDepth + } + } else if v, ok := arg.(int); ok { + relDepth = v + } + case 1: + limit = ToInt64(arg) + case 2: + offset = ToInt64(arg) + case 3: + order, _ = arg.(string) + } + } + + switch fi.fieldType { + case RelOneToOne, RelForeignKey, RelReverseOne: + limit = 1 + offset = 0 + } + + qs.limit = limit + qs.offset = offset + qs.relDepth = relDepth + + if len(order) > 0 { + qs.orders = []string{order} + } + + find := ind.FieldByIndex(fi.fieldIndex) + + var nums int64 + var err error + switch fi.fieldType { + case RelOneToOne, RelForeignKey, RelReverseOne: + val := reflect.New(find.Type().Elem()) + container := val.Interface() + err = qs.One(container) + if err == nil { + find.Set(val) + nums = 1 + } + default: + nums, err = qs.All(find.Addr().Interface()) + } + + return nums, err +} + +// return a QuerySeter for related models to md model. +// it can do all, update, delete in QuerySeter. +// example: +// qs := orm.QueryRelated(post,"Tag") +// qs.All(&[]*Tag{}) +// +func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { + // is this api needed ? + _, _, _, qs := o.queryRelated(md, name) + return qs +} + +// get QuerySeter for related models to md model +func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { + mi, ind := o.getMiInd(md, true) + fi := o.getFieldInfo(mi, name) + + _, _, exist := getExistPk(mi, ind) + if !exist { + panic(ErrMissPK) + } + + var qs *querySet + + switch fi.fieldType { + case RelOneToOne, RelForeignKey, RelManyToMany: + if !fi.inModel { + break + } + qs = o.getRelQs(md, mi, fi) + case RelReverseOne, RelReverseMany: + if !fi.inModel { + break + } + qs = o.getReverseQs(md, mi, fi) + } + + if qs == nil { + panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel/reverse field", md, name)) + } + + return mi, fi, ind, qs +} + +// get reverse relation QuerySeter +func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { + switch fi.fieldType { + case RelReverseOne, RelReverseMany: + default: + panic(fmt.Errorf(" name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName)) + } + + var q *querySet + + if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough { + q = newQuerySet(o, fi.relModelInfo).(*querySet) + q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) + } else { + q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet) + q.cond = NewCondition().And(fi.reverseFieldInfo.column, md) + } + + return q +} + +// get relation QuerySeter +func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { + switch fi.fieldType { + case RelOneToOne, RelForeignKey, RelManyToMany: + default: + panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName)) + } + + q := newQuerySet(o, fi.relModelInfo).(*querySet) + q.cond = NewCondition() + + if fi.fieldType == RelManyToMany { + q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) + } else { + q.cond = q.cond.And(fi.reverseFieldInfo.column, md) + } + + return q +} + +// return a QuerySeter for table operations. +// table name can be string or struct. +// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), +func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { + var name string + if table, ok := ptrStructOrTableName.(string); ok { + name = nameStrategyMap[defaultNameStrategy](table) + if mi, ok := modelCache.get(name); ok { + qs = newQuerySet(o, mi) + } + } else { + name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) + if mi, ok := modelCache.getByFullName(name); ok { + qs = newQuerySet(o, mi) + } + } + if qs == nil { + panic(fmt.Errorf(" table name: `%s` not exists", name)) + } + return +} + +// switch to another registered database driver by given name. +func (o *orm) Using(name string) error { + if o.isTx { + panic(fmt.Errorf(" transaction has been start, cannot change db")) + } + if al, ok := dataBaseCache.get(name); ok { + o.alias = al + if Debug { + o.db = newDbQueryLog(al, al.DB) + } else { + o.db = al.DB + } + } else { + return fmt.Errorf(" unknown db alias name `%s`", name) + } + return nil +} + +// begin transaction +func (o *orm) Begin() error { + return o.BeginTx(context.Background(), nil) +} + +func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error { + if o.isTx { + return ErrTxHasBegan + } + var tx *sql.Tx + tx, err := o.db.(txer).BeginTx(ctx, opts) + if err != nil { + return err + } + o.isTx = true + if Debug { + o.db.(*dbQueryLog).SetDB(tx) + } else { + o.db = tx + } + return nil +} + +// commit transaction +func (o *orm) Commit() error { + if !o.isTx { + return ErrTxDone + } + err := o.db.(txEnder).Commit() + if err == nil { + o.isTx = false + o.Using(o.alias.Name) + } else if err == sql.ErrTxDone { + return ErrTxDone + } + return err +} + +// rollback transaction +func (o *orm) Rollback() error { + if !o.isTx { + return ErrTxDone + } + err := o.db.(txEnder).Rollback() + if err == nil { + o.isTx = false + o.Using(o.alias.Name) + } else if err == sql.ErrTxDone { + return ErrTxDone + } + return err +} + +// return a raw query seter for raw sql string. +func (o *orm) Raw(query string, args ...interface{}) RawSeter { + return newRawSet(o, query, args) +} + +// return current using database Driver +func (o *orm) Driver() Driver { + return driver(o.alias.Name) +} + +// return sql.DBStats for current database +func (o *orm) DBStats() *sql.DBStats { + if o.alias != nil && o.alias.DB != nil { + stats := o.alias.DB.DB.Stats() + return &stats + } + return nil +} + +// NewOrm create new orm +func NewOrm() Ormer { + BootStrap() // execute only once + + o := new(orm) + err := o.Using("default") + if err != nil { + panic(err) + } + return o +} + +// NewOrmWithDB create a new ormer object with specify *sql.DB for query +func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { + var al *alias + + if dr, ok := drivers[driverName]; ok { + al = new(alias) + al.DbBaser = dbBasers[dr] + al.Driver = dr + } else { + return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + } + + al.Name = aliasName + al.DriverName = driverName + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: newStmtDecoratorLruWithEvict(), + } + + detectTZ(al) + + o := new(orm) + o.alias = al + + if Debug { + o.db = newDbQueryLog(o.alias, db) + } else { + o.db = db + } + + return o, nil +} diff --git a/pkg/orm/orm_conds.go b/pkg/orm/orm_conds.go new file mode 100644 index 00000000..f3fd66f0 --- /dev/null +++ b/pkg/orm/orm_conds.go @@ -0,0 +1,153 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strings" +) + +// ExprSep define the expression separation +const ( + ExprSep = "__" +) + +type condValue struct { + exprs []string + args []interface{} + cond *Condition + isOr bool + isNot bool + isCond bool + isRaw bool + sql string +} + +// Condition struct. +// work for WHERE conditions. +type Condition struct { + params []condValue +} + +// NewCondition return new condition struct +func NewCondition() *Condition { + c := &Condition{} + return c +} + +// Raw add raw sql to condition +func (c Condition) Raw(expr string, sql string) *Condition { + if len(sql) == 0 { + panic(fmt.Errorf(" sql cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), sql: sql, isRaw: true}) + return &c +} + +// And add expression to condition +func (c Condition) And(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(fmt.Errorf(" args cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args}) + return &c +} + +// AndNot add NOT expression to condition +func (c Condition) AndNot(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(fmt.Errorf(" args cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true}) + return &c +} + +// AndCond combine a condition to current condition +func (c *Condition) AndCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true}) + } + return c +} + +// AndNotCond combine a AND NOT condition to current condition +func (c *Condition) AndNotCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true}) + } + return c +} + +// Or add OR expression to condition +func (c Condition) Or(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(fmt.Errorf(" args cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true}) + return &c +} + +// OrNot add OR NOT expression to condition +func (c Condition) OrNot(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(fmt.Errorf(" args cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true}) + return &c +} + +// OrCond combine a OR condition to current condition +func (c *Condition) OrCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true, isOr: true}) + } + return c +} + +// OrNotCond combine a OR NOT condition to current condition +func (c *Condition) OrNotCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true, isOr: true}) + } + return c +} + +// IsEmpty check the condition arguments are empty or not. +func (c *Condition) IsEmpty() bool { + return len(c.params) == 0 +} + +// clone clone a condition +func (c Condition) clone() *Condition { + return &c +} diff --git a/pkg/orm/orm_log.go b/pkg/orm/orm_log.go new file mode 100644 index 00000000..f107bb59 --- /dev/null +++ b/pkg/orm/orm_log.go @@ -0,0 +1,222 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "fmt" + "io" + "log" + "strings" + "time" +) + +// Log implement the log.Logger +type Log struct { + *log.Logger +} + +//costomer log func +var LogFunc func(query map[string]interface{}) + +// NewLog set io.Writer to create a Logger. +func NewLog(out io.Writer) *Log { + d := new(Log) + d.Logger = log.New(out, "[ORM]", log.LstdFlags) + return d +} + +func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { + var logMap = make(map[string]interface{}) + sub := time.Now().Sub(t) / 1e5 + elsp := float64(int(sub)) / 10.0 + logMap["cost_time"] = elsp + flag := " OK" + if err != nil { + flag = "FAIL" + } + logMap["flag"] = flag + con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query) + cons := make([]string, 0, len(args)) + for _, arg := range args { + cons = append(cons, fmt.Sprintf("%v", arg)) + } + if len(cons) > 0 { + con += fmt.Sprintf(" - `%s`", strings.Join(cons, "`, `")) + } + if err != nil { + con += " - " + err.Error() + } + logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `")) + if LogFunc != nil{ + LogFunc(logMap) + } + DebugLog.Println(con) +} + +// statement query logger struct. +// if dev mode, use stmtQueryLog, or use stmtQuerier. +type stmtQueryLog struct { + alias *alias + query string + stmt stmtQuerier +} + +var _ stmtQuerier = new(stmtQueryLog) + +func (d *stmtQueryLog) Close() error { + a := time.Now() + err := d.stmt.Close() + debugLogQueies(d.alias, "st.Close", d.query, a, err) + return err +} + +func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) { + a := time.Now() + res, err := d.stmt.Exec(args...) + debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...) + return res, err +} + +func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) { + a := time.Now() + res, err := d.stmt.Query(args...) + debugLogQueies(d.alias, "st.Query", d.query, a, err, args...) + return res, err +} + +func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row { + a := time.Now() + res := d.stmt.QueryRow(args...) + debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...) + return res +} + +func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier { + d := new(stmtQueryLog) + d.stmt = stmt + d.alias = alias + d.query = query + return d +} + +// database query logger struct. +// if dev mode, use dbQueryLog, or use dbQuerier. +type dbQueryLog struct { + alias *alias + db dbQuerier + tx txer + txe txEnder +} + +var _ dbQuerier = new(dbQueryLog) +var _ txer = new(dbQueryLog) +var _ txEnder = new(dbQueryLog) + +func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) { + a := time.Now() + stmt, err := d.db.Prepare(query) + debugLogQueies(d.alias, "db.Prepare", query, a, err) + return stmt, err +} + +func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + a := time.Now() + stmt, err := d.db.PrepareContext(ctx, query) + debugLogQueies(d.alias, "db.Prepare", query, a, err) + return stmt, err +} + +func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { + a := time.Now() + res, err := d.db.Exec(query, args...) + debugLogQueies(d.alias, "db.Exec", query, a, err, args...) + return res, err +} + +func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + a := time.Now() + res, err := d.db.ExecContext(ctx, query, args...) + debugLogQueies(d.alias, "db.Exec", query, a, err, args...) + return res, err +} + +func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { + a := time.Now() + res, err := d.db.Query(query, args...) + debugLogQueies(d.alias, "db.Query", query, a, err, args...) + return res, err +} + +func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + a := time.Now() + res, err := d.db.QueryContext(ctx, query, args...) + debugLogQueies(d.alias, "db.Query", query, a, err, args...) + return res, err +} + +func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { + a := time.Now() + res := d.db.QueryRow(query, args...) + debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) + return res +} + +func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + a := time.Now() + res := d.db.QueryRowContext(ctx, query, args...) + debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) + return res +} + +func (d *dbQueryLog) Begin() (*sql.Tx, error) { + a := time.Now() + tx, err := d.db.(txer).Begin() + debugLogQueies(d.alias, "db.Begin", "START TRANSACTION", a, err) + return tx, err +} + +func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + a := time.Now() + tx, err := d.db.(txer).BeginTx(ctx, opts) + debugLogQueies(d.alias, "db.BeginTx", "START TRANSACTION", a, err) + return tx, err +} + +func (d *dbQueryLog) Commit() error { + a := time.Now() + err := d.db.(txEnder).Commit() + debugLogQueies(d.alias, "tx.Commit", "COMMIT", a, err) + return err +} + +func (d *dbQueryLog) Rollback() error { + a := time.Now() + err := d.db.(txEnder).Rollback() + debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err) + return err +} + +func (d *dbQueryLog) SetDB(db dbQuerier) { + d.db = db +} + +func newDbQueryLog(alias *alias, db dbQuerier) dbQuerier { + d := new(dbQueryLog) + d.alias = alias + d.db = db + return d +} diff --git a/pkg/orm/orm_object.go b/pkg/orm/orm_object.go new file mode 100644 index 00000000..de3181ce --- /dev/null +++ b/pkg/orm/orm_object.go @@ -0,0 +1,87 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "reflect" +) + +// an insert queryer struct +type insertSet struct { + mi *modelInfo + orm *orm + stmt stmtQuerier + closed bool +} + +var _ Inserter = new(insertSet) + +// insert model ignore it's registered or not. +func (o *insertSet) Insert(md interface{}) (int64, error) { + if o.closed { + return 0, ErrStmtClosed + } + val := reflect.ValueOf(md) + ind := reflect.Indirect(val) + typ := ind.Type() + name := getFullName(typ) + if val.Kind() != reflect.Ptr { + panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", name)) + } + if name != o.mi.fullName { + panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.fullName, name)) + } + id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) + if err != nil { + return id, err + } + if id > 0 { + if o.mi.fields.pk.auto { + if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id)) + } else { + ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id) + } + } + } + return id, nil +} + +// close insert queryer statement +func (o *insertSet) Close() error { + if o.closed { + return ErrStmtClosed + } + o.closed = true + return o.stmt.Close() +} + +// create new insert queryer. +func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { + bi := new(insertSet) + bi.orm = orm + bi.mi = mi + st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) + if err != nil { + return nil, err + } + if Debug { + bi.stmt = newStmtQueryLog(orm.alias, st, query) + } else { + bi.stmt = st + } + return bi, nil +} diff --git a/pkg/orm/orm_querym2m.go b/pkg/orm/orm_querym2m.go new file mode 100644 index 00000000..6a270a0d --- /dev/null +++ b/pkg/orm/orm_querym2m.go @@ -0,0 +1,140 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import "reflect" + +// model to model struct +type queryM2M struct { + md interface{} + mi *modelInfo + fi *fieldInfo + qs *querySet + ind reflect.Value +} + +// add models to origin models when creating queryM2M. +// example: +// m2m := orm.QueryM2M(post,"Tag") +// m2m.Add(&Tag1{},&Tag2{}) +// for _,tag := range post.Tags{} +// +// make sure the relation is defined in post model struct tag. +func (o *queryM2M) Add(mds ...interface{}) (int64, error) { + fi := o.fi + mi := fi.relThroughModelInfo + mfi := fi.reverseFieldInfo + rfi := fi.reverseFieldInfoTwo + + orm := o.qs.orm + dbase := orm.alias.DbBaser + + var models []interface{} + var otherValues []interface{} + var otherNames []string + + for _, colname := range mi.fields.dbcols { + if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column && + mi.fields.columns[colname] != mi.fields.pk { + otherNames = append(otherNames, colname) + } + } + for i, md := range mds { + if reflect.Indirect(reflect.ValueOf(md)).Kind() != reflect.Struct && i > 0 { + otherValues = append(otherValues, md) + mds = append(mds[:i], mds[i+1:]...) + } + } + for _, md := range mds { + val := reflect.ValueOf(md) + if val.Kind() == reflect.Slice || val.Kind() == reflect.Array { + for i := 0; i < val.Len(); i++ { + v := val.Index(i) + if v.CanInterface() { + models = append(models, v.Interface()) + } + } + } else { + models = append(models, md) + } + } + + _, v1, exist := getExistPk(o.mi, o.ind) + if !exist { + panic(ErrMissPK) + } + + names := []string{mfi.column, rfi.column} + + values := make([]interface{}, 0, len(models)*2) + for _, md := range models { + + ind := reflect.Indirect(reflect.ValueOf(md)) + var v2 interface{} + if ind.Kind() != reflect.Struct { + v2 = ind.Interface() + } else { + _, v2, exist = getExistPk(fi.relModelInfo, ind) + if !exist { + panic(ErrMissPK) + } + } + values = append(values, v1, v2) + + } + names = append(names, otherNames...) + values = append(values, otherValues...) + return dbase.InsertValue(orm.db, mi, true, names, values) +} + +// remove models following the origin model relationship +func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { + fi := o.fi + qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) + + return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete() +} + +// check model is existed in relationship of origin model +func (o *queryM2M) Exist(md interface{}) bool { + fi := o.fi + return o.qs.Filter(fi.reverseFieldInfo.name, o.md). + Filter(fi.reverseFieldInfoTwo.name, md).Exist() +} + +// clean all models in related of origin model +func (o *queryM2M) Clear() (int64, error) { + fi := o.fi + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() +} + +// count all related models of origin model +func (o *queryM2M) Count() (int64, error) { + fi := o.fi + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() +} + +var _ QueryM2Mer = new(queryM2M) + +// create new M2M queryer. +func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { + qm2m := new(queryM2M) + qm2m.md = md + qm2m.mi = mi + qm2m.fi = fi + qm2m.ind = ind + qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet) + return qm2m +} diff --git a/pkg/orm/orm_queryset.go b/pkg/orm/orm_queryset.go new file mode 100644 index 00000000..878b836b --- /dev/null +++ b/pkg/orm/orm_queryset.go @@ -0,0 +1,300 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "fmt" +) + +type colValue struct { + value int64 + opt operator +} + +type operator int + +// define Col operations +const ( + ColAdd operator = iota + ColMinus + ColMultiply + ColExcept + ColBitAnd + ColBitRShift + ColBitLShift + ColBitXOR + ColBitOr +) + +// ColValue do the field raw changes. e.g Nums = Nums + 10. usage: +// Params{ +// "Nums": ColValue(Col_Add, 10), +// } +func ColValue(opt operator, value interface{}) interface{} { + switch opt { + case ColAdd, ColMinus, ColMultiply, ColExcept, ColBitAnd, ColBitRShift, + ColBitLShift, ColBitXOR, ColBitOr: + default: + panic(fmt.Errorf("orm.ColValue wrong operator")) + } + v, err := StrTo(ToStr(value)).Int64() + if err != nil { + panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err)) + } + var val colValue + val.value = v + val.opt = opt + return val +} + +// real query struct +type querySet struct { + mi *modelInfo + cond *Condition + related []string + relDepth int + limit int64 + offset int64 + groups []string + orders []string + distinct bool + forupdate bool + orm *orm + ctx context.Context + forContext bool +} + +var _ QuerySeter = new(querySet) + +// add condition expression to QuerySeter. +func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond = o.cond.And(expr, args...) + return &o +} + +// add raw sql to querySeter. +func (o querySet) FilterRaw(expr string, sql string) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond = o.cond.Raw(expr, sql) + return &o +} + +// add NOT condition to querySeter. +func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond = o.cond.AndNot(expr, args...) + return &o +} + +// set offset number +func (o *querySet) setOffset(num interface{}) { + o.offset = ToInt64(num) +} + +// add LIMIT value. +// args[0] means offset, e.g. LIMIT num,offset. +func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { + o.limit = ToInt64(limit) + if len(args) > 0 { + o.setOffset(args[0]) + } + return &o +} + +// add OFFSET value +func (o querySet) Offset(offset interface{}) QuerySeter { + o.setOffset(offset) + return &o +} + +// add GROUP expression +func (o querySet) GroupBy(exprs ...string) QuerySeter { + o.groups = exprs + return &o +} + +// add ORDER expression. +// "column" means ASC, "-column" means DESC. +func (o querySet) OrderBy(exprs ...string) QuerySeter { + o.orders = exprs + return &o +} + +// add DISTINCT to SELECT +func (o querySet) Distinct() QuerySeter { + o.distinct = true + return &o +} + +// add FOR UPDATE to SELECT +func (o querySet) ForUpdate() QuerySeter { + o.forupdate = true + return &o +} + +// set relation model to query together. +// it will query relation models and assign to parent model. +func (o querySet) RelatedSel(params ...interface{}) QuerySeter { + if len(params) == 0 { + o.relDepth = DefaultRelsDepth + } else { + for _, p := range params { + switch val := p.(type) { + case string: + o.related = append(o.related, val) + case int: + o.relDepth = val + default: + panic(fmt.Errorf(" wrong param kind: %v", val)) + } + } + } + return &o +} + +// set condition to QuerySeter. +func (o querySet) SetCond(cond *Condition) QuerySeter { + o.cond = cond + return &o +} + +// get condition from QuerySeter +func (o querySet) GetCond() *Condition { + return o.cond +} + +// return QuerySeter execution result number +func (o *querySet) Count() (int64, error) { + return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) +} + +// check result empty or not after QuerySeter executed +func (o *querySet) Exist() bool { + cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return cnt > 0 +} + +// execute update with parameters +func (o *querySet) Update(values Params) (int64, error) { + return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) +} + +// execute delete +func (o *querySet) Delete() (int64, error) { + return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) +} + +// return a insert queryer. +// it can be used in times. +// example: +// i,err := sq.PrepareInsert() +// i.Add(&user1{},&user2{}) +func (o *querySet) PrepareInsert() (Inserter, error) { + return newInsertSet(o.orm, o.mi) +} + +// query all data and map to containers. +// cols means the columns when querying. +func (o *querySet) All(container interface{}, cols ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) +} + +// query one row data and map to containers. +// cols means the columns when querying. +func (o *querySet) One(container interface{}, cols ...string) error { + o.limit = 1 + num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + if err != nil { + return err + } + if num == 0 { + return ErrNoRows + } + + if num > 1 { + return ErrMultiRows + } + return nil +} + +// query all data and map to []map[string]interface. +// expres means condition expression. +// it converts data to []map[column]value. +func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) +} + +// query all data and map to [][]interface +// it converts data to [][column_index]value +func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) +} + +// query all data and map to []interface. +// it's designed for one row record set, auto change to []value, not [][column]value. +func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) +} + +// query all rows into map[string]interface with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to map[string]interface{}{ +// "total": 100, +// "found": 200, +// } +func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { + panic(ErrNotImplement) +} + +// query all rows into struct with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to struct { +// Total int +// Found int +// } +func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + panic(ErrNotImplement) +} + +// set context to QuerySeter. +func (o querySet) WithContext(ctx context.Context) QuerySeter { + o.ctx = ctx + o.forContext = true + return &o +} + +// create new QuerySeter. +func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { + o := new(querySet) + o.mi = mi + o.orm = orm + return o +} diff --git a/pkg/orm/orm_raw.go b/pkg/orm/orm_raw.go new file mode 100644 index 00000000..3325a7ea --- /dev/null +++ b/pkg/orm/orm_raw.go @@ -0,0 +1,867 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "fmt" + "reflect" + "time" +) + +// raw sql string prepared statement +type rawPrepare struct { + rs *rawSet + stmt stmtQuerier + closed bool +} + +func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) { + if o.closed { + return nil, ErrStmtClosed + } + return o.stmt.Exec(args...) +} + +func (o *rawPrepare) Close() error { + o.closed = true + return o.stmt.Close() +} + +func newRawPreparer(rs *rawSet) (RawPreparer, error) { + o := new(rawPrepare) + o.rs = rs + + query := rs.query + rs.orm.alias.DbBaser.ReplaceMarks(&query) + + st, err := rs.orm.db.Prepare(query) + if err != nil { + return nil, err + } + if Debug { + o.stmt = newStmtQueryLog(rs.orm.alias, st, query) + } else { + o.stmt = st + } + return o, nil +} + +// raw query seter +type rawSet struct { + query string + args []interface{} + orm *orm +} + +var _ RawSeter = new(rawSet) + +// set args for every query +func (o rawSet) SetArgs(args ...interface{}) RawSeter { + o.args = args + return &o +} + +// execute raw sql and return sql.Result +func (o *rawSet) Exec() (sql.Result, error) { + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + return o.orm.db.Exec(query, args...) +} + +// set field value to row container +func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { + switch ind.Kind() { + case reflect.Bool: + if value == nil { + ind.SetBool(false) + } else if v, ok := value.(bool); ok { + ind.SetBool(v) + } else { + v, _ := StrTo(ToStr(value)).Bool() + ind.SetBool(v) + } + + case reflect.String: + if value == nil { + ind.SetString("") + } else { + ind.SetString(ToStr(value)) + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if value == nil { + ind.SetInt(0) + } else { + val := reflect.ValueOf(value) + switch val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + ind.SetInt(val.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + ind.SetInt(int64(val.Uint())) + default: + v, _ := StrTo(ToStr(value)).Int64() + ind.SetInt(v) + } + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if value == nil { + ind.SetUint(0) + } else { + val := reflect.ValueOf(value) + switch val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + ind.SetUint(uint64(val.Int())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + ind.SetUint(val.Uint()) + default: + v, _ := StrTo(ToStr(value)).Uint64() + ind.SetUint(v) + } + } + case reflect.Float64, reflect.Float32: + if value == nil { + ind.SetFloat(0) + } else { + val := reflect.ValueOf(value) + switch val.Kind() { + case reflect.Float64: + ind.SetFloat(val.Float()) + default: + v, _ := StrTo(ToStr(value)).Float64() + ind.SetFloat(v) + } + } + + case reflect.Struct: + if value == nil { + ind.Set(reflect.Zero(ind.Type())) + return + } + switch ind.Interface().(type) { + case time.Time: + var str string + switch d := value.(type) { + case time.Time: + o.orm.alias.DbBaser.TimeFromDB(&d, o.orm.alias.TZ) + ind.Set(reflect.ValueOf(d)) + case []byte: + str = string(d) + case string: + str = d + } + if str != "" { + if len(str) >= 19 { + str = str[:19] + t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ) + if err == nil { + t = t.In(DefaultTimeLoc) + ind.Set(reflect.ValueOf(t)) + } + } else if len(str) >= 10 { + str = str[:10] + t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc) + if err == nil { + ind.Set(reflect.ValueOf(t)) + } + } + } + case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool: + indi := reflect.New(ind.Type()).Interface() + sc, ok := indi.(sql.Scanner) + if !ok { + return + } + err := sc.Scan(value) + if err == nil { + ind.Set(reflect.Indirect(reflect.ValueOf(sc))) + } + } + + case reflect.Ptr: + if value == nil { + ind.Set(reflect.Zero(ind.Type())) + break + } + ind.Set(reflect.New(ind.Type().Elem())) + o.setFieldValue(reflect.Indirect(ind), value) + } +} + +// set field value in loop for slice container +func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) { + nInds := *nIndsPtr + + cur := 0 + for i := 0; i < len(sInds); i++ { + sInd := sInds[i] + eTyp := eTyps[i] + + typ := eTyp + isPtr := false + if typ.Kind() == reflect.Ptr { + isPtr = true + typ = typ.Elem() + } + if typ.Kind() == reflect.Ptr { + isPtr = true + typ = typ.Elem() + } + + var nInd reflect.Value + if init { + nInd = reflect.New(sInd.Type()).Elem() + } else { + nInd = nInds[i] + } + + val := reflect.New(typ) + ind := val.Elem() + + tpName := ind.Type().String() + + if ind.Kind() == reflect.Struct { + if tpName == "time.Time" { + value := reflect.ValueOf(refs[cur]).Elem().Interface() + if isPtr && value == nil { + val = reflect.New(val.Type()).Elem() + } else { + o.setFieldValue(ind, value) + } + cur++ + } + + } else { + value := reflect.ValueOf(refs[cur]).Elem().Interface() + if isPtr && value == nil { + val = reflect.New(val.Type()).Elem() + } else { + o.setFieldValue(ind, value) + } + cur++ + } + + if nInd.Kind() == reflect.Slice { + if isPtr { + nInd = reflect.Append(nInd, val) + } else { + nInd = reflect.Append(nInd, ind) + } + } else { + if isPtr { + nInd.Set(val) + } else { + nInd.Set(ind) + } + } + + nInds[i] = nInd + } +} + +// query data and map to container +func (o *rawSet) QueryRow(containers ...interface{}) error { + var ( + refs = make([]interface{}, 0, len(containers)) + sInds []reflect.Value + eTyps []reflect.Type + sMi *modelInfo + ) + structMode := false + for _, container := range containers { + val := reflect.ValueOf(container) + ind := reflect.Indirect(val) + + if val.Kind() != reflect.Ptr { + panic(fmt.Errorf(" all args must be use ptr")) + } + + etyp := ind.Type() + typ := etyp + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + sInds = append(sInds, ind) + eTyps = append(eTyps, etyp) + + if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { + if len(containers) > 1 { + panic(fmt.Errorf(" now support one struct only. see #384")) + } + + structMode = true + fn := getFullName(typ) + if mi, ok := modelCache.getByFullName(fn); ok { + sMi = mi + } + } else { + var ref interface{} + refs = append(refs, &ref) + } + } + + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + rows, err := o.orm.db.Query(query, args...) + if err != nil { + if err == sql.ErrNoRows { + return ErrNoRows + } + return err + } + + defer rows.Close() + + if rows.Next() { + if structMode { + columns, err := rows.Columns() + if err != nil { + return err + } + + columnsMp := make(map[string]interface{}, len(columns)) + + refs = make([]interface{}, 0, len(columns)) + for _, col := range columns { + var ref interface{} + columnsMp[col] = &ref + refs = append(refs, &ref) + } + + if err := rows.Scan(refs...); err != nil { + return err + } + + ind := sInds[0] + + if ind.Kind() == reflect.Ptr { + if ind.IsNil() || !ind.IsValid() { + ind.Set(reflect.New(eTyps[0].Elem())) + } + ind = ind.Elem() + } + + if sMi != nil { + for _, col := range columns { + if fi := sMi.fields.GetByColumn(col); fi != nil { + value := reflect.ValueOf(columnsMp[col]).Elem().Interface() + field := ind.FieldByIndex(fi.fieldIndex) + if fi.fieldType&IsRelField > 0 { + mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + field.Set(mf) + field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + } + o.setFieldValue(field, value) + } + } + } else { + for i := 0; i < ind.NumField(); i++ { + f := ind.Field(i) + fe := ind.Type().Field(i) + _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + var col string + if col = tags["column"]; col == "" { + col = nameStrategyMap[nameStrategy](fe.Name) + } + if v, ok := columnsMp[col]; ok { + value := reflect.ValueOf(v).Elem().Interface() + o.setFieldValue(f, value) + } + } + } + + } else { + if err := rows.Scan(refs...); err != nil { + return err + } + + nInds := make([]reflect.Value, len(sInds)) + o.loopSetRefs(refs, sInds, &nInds, eTyps, true) + for i, sInd := range sInds { + nInd := nInds[i] + sInd.Set(nInd) + } + } + + } else { + return ErrNoRows + } + + return nil +} + +// query data rows and map to container +func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { + var ( + refs = make([]interface{}, 0, len(containers)) + sInds []reflect.Value + eTyps []reflect.Type + sMi *modelInfo + ) + structMode := false + for _, container := range containers { + val := reflect.ValueOf(container) + sInd := reflect.Indirect(val) + if val.Kind() != reflect.Ptr || sInd.Kind() != reflect.Slice { + panic(fmt.Errorf(" all args must be use ptr slice")) + } + + etyp := sInd.Type().Elem() + typ := etyp + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + sInds = append(sInds, sInd) + eTyps = append(eTyps, etyp) + + if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { + if len(containers) > 1 { + panic(fmt.Errorf(" now support one struct only. see #384")) + } + + structMode = true + fn := getFullName(typ) + if mi, ok := modelCache.getByFullName(fn); ok { + sMi = mi + } + } else { + var ref interface{} + refs = append(refs, &ref) + } + } + + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + rows, err := o.orm.db.Query(query, args...) + if err != nil { + return 0, err + } + + defer rows.Close() + + var cnt int64 + nInds := make([]reflect.Value, len(sInds)) + sInd := sInds[0] + + for rows.Next() { + + if structMode { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + columnsMp := make(map[string]interface{}, len(columns)) + + refs = make([]interface{}, 0, len(columns)) + for _, col := range columns { + var ref interface{} + columnsMp[col] = &ref + refs = append(refs, &ref) + } + + if err := rows.Scan(refs...); err != nil { + return 0, err + } + + if cnt == 0 && !sInd.IsNil() { + sInd.Set(reflect.New(sInd.Type()).Elem()) + } + + var ind reflect.Value + if eTyps[0].Kind() == reflect.Ptr { + ind = reflect.New(eTyps[0].Elem()) + } else { + ind = reflect.New(eTyps[0]) + } + + if ind.Kind() == reflect.Ptr { + ind = ind.Elem() + } + + if sMi != nil { + for _, col := range columns { + if fi := sMi.fields.GetByColumn(col); fi != nil { + value := reflect.ValueOf(columnsMp[col]).Elem().Interface() + field := ind.FieldByIndex(fi.fieldIndex) + if fi.fieldType&IsRelField > 0 { + mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + field.Set(mf) + field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + } + o.setFieldValue(field, value) + } + } + } else { + // define recursive function + var recursiveSetField func(rv reflect.Value) + recursiveSetField = func(rv reflect.Value) { + for i := 0; i < rv.NumField(); i++ { + f := rv.Field(i) + fe := rv.Type().Field(i) + + // check if the field is a Struct + // recursive the Struct type + if fe.Type.Kind() == reflect.Struct { + recursiveSetField(f) + } + + _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + var col string + if col = tags["column"]; col == "" { + col = nameStrategyMap[nameStrategy](fe.Name) + } + if v, ok := columnsMp[col]; ok { + value := reflect.ValueOf(v).Elem().Interface() + o.setFieldValue(f, value) + } + } + } + + // init call the recursive function + recursiveSetField(ind) + } + + if eTyps[0].Kind() == reflect.Ptr { + ind = ind.Addr() + } + + sInd = reflect.Append(sInd, ind) + + } else { + if err := rows.Scan(refs...); err != nil { + return 0, err + } + + o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0) + } + + cnt++ + } + + if cnt > 0 { + + if structMode { + sInds[0].Set(sInd) + } else { + for i, sInd := range sInds { + nInd := nInds[i] + sInd.Set(nInd) + } + } + } + + return cnt, nil +} + +func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) { + var ( + maps []Params + lists []ParamsList + list ParamsList + ) + + typ := 0 + switch container.(type) { + case *[]Params: + typ = 1 + case *[]ParamsList: + typ = 2 + case *ParamsList: + typ = 3 + default: + panic(fmt.Errorf(" unsupport read values type `%T`", container)) + } + + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + + var rs *sql.Rows + rs, err := o.orm.db.Query(query, args...) + if err != nil { + return 0, err + } + + defer rs.Close() + + var ( + refs []interface{} + cnt int64 + cols []string + indexs []int + ) + + for rs.Next() { + if cnt == 0 { + columns, err := rs.Columns() + if err != nil { + return 0, err + } + if len(needCols) > 0 { + indexs = make([]int, 0, len(needCols)) + } else { + indexs = make([]int, 0, len(columns)) + } + + cols = columns + refs = make([]interface{}, len(cols)) + for i := range refs { + var ref sql.NullString + refs[i] = &ref + + if len(needCols) > 0 { + for _, c := range needCols { + if c == cols[i] { + indexs = append(indexs, i) + } + } + } else { + indexs = append(indexs, i) + } + } + } + + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + switch typ { + case 1: + params := make(Params, len(cols)) + for _, i := range indexs { + ref := refs[i] + value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) + if value.Valid { + params[cols[i]] = value.String + } else { + params[cols[i]] = nil + } + } + maps = append(maps, params) + case 2: + params := make(ParamsList, 0, len(cols)) + for _, i := range indexs { + ref := refs[i] + value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) + if value.Valid { + params = append(params, value.String) + } else { + params = append(params, nil) + } + } + lists = append(lists, params) + case 3: + for _, i := range indexs { + ref := refs[i] + value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) + if value.Valid { + list = append(list, value.String) + } else { + list = append(list, nil) + } + } + } + + cnt++ + } + + switch v := container.(type) { + case *[]Params: + *v = maps + case *[]ParamsList: + *v = lists + case *ParamsList: + *v = list + } + + return cnt, nil +} + +func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) { + var ( + maps Params + ind *reflect.Value + ) + + var typ int + switch container.(type) { + case *Params: + typ = 1 + default: + typ = 2 + vl := reflect.ValueOf(container) + id := reflect.Indirect(vl) + if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct { + panic(fmt.Errorf(" RowsTo unsupport type `%T` need ptr struct", container)) + } + + ind = &id + } + + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + + rs, err := o.orm.db.Query(query, args...) + if err != nil { + return 0, err + } + + defer rs.Close() + + var ( + refs []interface{} + cnt int64 + cols []string + ) + + var ( + keyIndex = -1 + valueIndex = -1 + ) + + for rs.Next() { + if cnt == 0 { + columns, err := rs.Columns() + if err != nil { + return 0, err + } + cols = columns + refs = make([]interface{}, len(cols)) + for i := range refs { + if keyCol == cols[i] { + keyIndex = i + } + if typ == 1 || keyIndex == i { + var ref sql.NullString + refs[i] = &ref + } else { + var ref interface{} + refs[i] = &ref + } + if valueCol == cols[i] { + valueIndex = i + } + } + if keyIndex == -1 || valueIndex == -1 { + panic(fmt.Errorf(" RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol)) + } + } + + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + if cnt == 0 { + switch typ { + case 1: + maps = make(Params) + } + } + + key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String + + switch typ { + case 1: + value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString) + if value.Valid { + maps[key] = value.String + } else { + maps[key] = nil + } + + default: + if id := ind.FieldByName(camelString(key)); id.IsValid() { + o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface()) + } + } + + cnt++ + } + + if typ == 1 { + v, _ := container.(*Params) + *v = maps + } + + return cnt, nil +} + +// query data to []map[string]interface +func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) { + return o.readValues(container, cols) +} + +// query data to [][]interface +func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) { + return o.readValues(container, cols) +} + +// query data to []interface +func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) { + return o.readValues(container, cols) +} + +// query all rows into map[string]interface with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to map[string]interface{}{ +// "total": 100, +// "found": 200, +// } +func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { + return o.queryRowsTo(result, keyCol, valueCol) +} + +// query all rows into struct with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to struct { +// Total int +// Found int +// } +func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + return o.queryRowsTo(ptrStruct, keyCol, valueCol) +} + +// return prepared raw statement for used in times. +func (o *rawSet) Prepare() (RawPreparer, error) { + return newRawPreparer(o) +} + +func newRawSet(orm *orm, query string, args []interface{}) RawSeter { + o := new(rawSet) + o.query = query + o.args = args + o.orm = orm + return o +} diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go new file mode 100644 index 00000000..bdb430b6 --- /dev/null +++ b/pkg/orm/orm_test.go @@ -0,0 +1,2494 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +package orm + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "io/ioutil" + "math" + "os" + "path/filepath" + "reflect" + "runtime" + "strings" + "testing" + "time" +) + +var _ = os.PathSeparator + +var ( + testDate = formatDate + " -0700" + testDateTime = formatDateTime + " -0700" + testTime = formatTime + " -0700" +) + +type argAny []interface{} + +// get interface by index from interface slice +func (a argAny) Get(i int, args ...interface{}) (r interface{}) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) { + if len(args) == 0 { + return false, fmt.Errorf("miss args") + } + b := args[0] + arg := argAny(args) + + switch v := a.(type) { + case reflect.Kind: + ok = reflect.ValueOf(b).Kind() == v + case time.Time: + if v2, vo := b.(time.Time); vo { + if arg.Get(1) != nil { + format := ToStr(arg.Get(1)) + a = v.Format(format) + b = v2.Format(format) + ok = a == b + } else { + err = fmt.Errorf("compare datetime miss format") + goto wrongArg + } + } + default: + ok = ToStr(a) == ToStr(b) + } + ok = is && ok || !is && !ok + if !ok { + if is { + err = fmt.Errorf("expected: `%v`, get `%v`", b, a) + } else { + err = fmt.Errorf("expected: `%v`, get `%v`", b, a) + } + } + +wrongArg: + if err != nil { + return false, err + } + + return true, nil +} + +func AssertIs(a interface{}, args ...interface{}) error { + if ok, err := ValuesCompare(true, a, args...); !ok { + return err + } + return nil +} + +func AssertNot(a interface{}, args ...interface{}) error { + if ok, err := ValuesCompare(false, a, args...); !ok { + return err + } + return nil +} + +func getCaller(skip int) string { + pc, file, line, _ := runtime.Caller(skip) + fun := runtime.FuncForPC(pc) + _, fn := filepath.Split(file) + data, err := ioutil.ReadFile(file) + var codes []string + if err == nil { + lines := bytes.Split(data, []byte{'\n'}) + n := 10 + for i := 0; i < n; i++ { + o := line - n + if o < 0 { + continue + } + cur := o + i + 1 + flag := " " + if cur == line { + flag = ">>" + } + code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1)) + if code != "" { + codes = append(codes, code) + } + } + } + funName := fun.Name() + if i := strings.LastIndex(funName, "."); i > -1 { + funName = funName[i+1:] + } + return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) +} + +func throwFail(t *testing.T, err error, args ...interface{}) { + if err != nil { + con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) + if len(args) > 0 { + parts := make([]string, 0, len(args)) + for _, arg := range args { + parts = append(parts, fmt.Sprintf("%v", arg)) + } + con += " " + strings.Join(parts, ", ") + } + t.Error(con) + t.Fail() + } +} + +func throwFailNow(t *testing.T, err error, args ...interface{}) { + if err != nil { + con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) + if len(args) > 0 { + parts := make([]string, 0, len(args)) + for _, arg := range args { + parts = append(parts, fmt.Sprintf("%v", arg)) + } + con += " " + strings.Join(parts, ", ") + } + t.Error(con) + t.FailNow() + } +} + +func TestGetDB(t *testing.T) { + if db, err := GetDB(); err != nil { + throwFailNow(t, err) + } else { + err = db.Ping() + throwFailNow(t, err) + } +} + +func TestSyncDb(t *testing.T) { + RegisterModel(new(Data), new(DataNull), new(DataCustom)) + RegisterModel(new(User)) + RegisterModel(new(Profile)) + RegisterModel(new(Post)) + RegisterModel(new(Tag)) + RegisterModel(new(Comment)) + RegisterModel(new(UserBig)) + RegisterModel(new(PostTags)) + RegisterModel(new(Group)) + RegisterModel(new(Permission)) + RegisterModel(new(GroupPermissions)) + RegisterModel(new(InLine)) + RegisterModel(new(InLineOneToOne)) + RegisterModel(new(IntegerPk)) + RegisterModel(new(UintPk)) + RegisterModel(new(PtrPk)) + + err := RunSyncdb("default", true, Debug) + throwFail(t, err) + + modelCache.clean() +} + +func TestRegisterModels(t *testing.T) { + RegisterModel(new(Data), new(DataNull), new(DataCustom)) + RegisterModel(new(User)) + RegisterModel(new(Profile)) + RegisterModel(new(Post)) + RegisterModel(new(Tag)) + RegisterModel(new(Comment)) + RegisterModel(new(UserBig)) + RegisterModel(new(PostTags)) + RegisterModel(new(Group)) + RegisterModel(new(Permission)) + RegisterModel(new(GroupPermissions)) + RegisterModel(new(InLine)) + RegisterModel(new(InLineOneToOne)) + RegisterModel(new(IntegerPk)) + RegisterModel(new(UintPk)) + RegisterModel(new(PtrPk)) + + BootStrap() + + dORM = NewOrm() + dDbBaser = getDbAlias("default").DbBaser +} + +func TestModelSyntax(t *testing.T) { + user := &User{} + ind := reflect.ValueOf(user).Elem() + fn := getFullName(ind.Type()) + mi, ok := modelCache.getByFullName(fn) + throwFail(t, AssertIs(ok, true)) + + mi, ok = modelCache.get("user") + throwFail(t, AssertIs(ok, true)) + if ok { + throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) + } +} + +var DataValues = map[string]interface{}{ + "Boolean": true, + "Char": "char", + "Text": "text", + "JSON": `{"name":"json"}`, + "Jsonb": `{"name": "jsonb"}`, + "Time": time.Now(), + "Date": time.Now(), + "DateTime": time.Now(), + "Byte": byte(1<<8 - 1), + "Rune": rune(1<<31 - 1), + "Int": int(1<<31 - 1), + "Int8": int8(1<<7 - 1), + "Int16": int16(1<<15 - 1), + "Int32": int32(1<<31 - 1), + "Int64": int64(1<<63 - 1), + "Uint": uint(1<<32 - 1), + "Uint8": uint8(1<<8 - 1), + "Uint16": uint16(1<<16 - 1), + "Uint32": uint32(1<<32 - 1), + "Uint64": uint64(1<<63 - 1), // uint64 values with high bit set are not supported + "Float32": float32(100.1234), + "Float64": float64(100.1234), + "Decimal": float64(100.1234), +} + +func TestDataTypes(t *testing.T) { + d := Data{} + ind := reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + if name == "JSON" { + continue + } + e := ind.FieldByName(name) + e.Set(reflect.ValueOf(value)) + } + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + d = Data{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + ind = reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } +} + +func TestNullDataTypes(t *testing.T) { + d := DataNull{} + + if IsPostgres { + // can removed when this fixed + // https://github.com/lib/pq/pull/125 + d.DateTime = time.Now() + } + + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + data := `{"ok":1,"data":{"arr":[1,2],"msg":"gopher"}}` + d = DataNull{ID: 1, JSON: data} + num, err := dORM.Update(&d) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + d = DataNull{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + throwFail(t, AssertIs(d.JSON, data)) + + throwFail(t, AssertIs(d.NullBool.Valid, false)) + throwFail(t, AssertIs(d.NullString.Valid, false)) + throwFail(t, AssertIs(d.NullInt64.Valid, false)) + throwFail(t, AssertIs(d.NullFloat64.Valid, false)) + + throwFail(t, AssertIs(d.BooleanPtr, nil)) + throwFail(t, AssertIs(d.CharPtr, nil)) + throwFail(t, AssertIs(d.TextPtr, nil)) + throwFail(t, AssertIs(d.BytePtr, nil)) + throwFail(t, AssertIs(d.RunePtr, nil)) + throwFail(t, AssertIs(d.IntPtr, nil)) + throwFail(t, AssertIs(d.Int8Ptr, nil)) + throwFail(t, AssertIs(d.Int16Ptr, nil)) + throwFail(t, AssertIs(d.Int32Ptr, nil)) + throwFail(t, AssertIs(d.Int64Ptr, nil)) + throwFail(t, AssertIs(d.UintPtr, nil)) + throwFail(t, AssertIs(d.Uint8Ptr, nil)) + throwFail(t, AssertIs(d.Uint16Ptr, nil)) + throwFail(t, AssertIs(d.Uint32Ptr, nil)) + throwFail(t, AssertIs(d.Uint64Ptr, nil)) + throwFail(t, AssertIs(d.Float32Ptr, nil)) + throwFail(t, AssertIs(d.Float64Ptr, nil)) + throwFail(t, AssertIs(d.DecimalPtr, nil)) + throwFail(t, AssertIs(d.TimePtr, nil)) + throwFail(t, AssertIs(d.DatePtr, nil)) + throwFail(t, AssertIs(d.DateTimePtr, nil)) + + _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() + throwFail(t, err) + + d = DataNull{ID: 2} + err = dORM.Read(&d) + throwFail(t, err) + + booleanPtr := true + charPtr := string("test") + textPtr := string("test") + bytePtr := byte('t') + runePtr := rune('t') + intPtr := int(42) + int8Ptr := int8(42) + int16Ptr := int16(42) + int32Ptr := int32(42) + int64Ptr := int64(42) + uintPtr := uint(42) + uint8Ptr := uint8(42) + uint16Ptr := uint16(42) + uint32Ptr := uint32(42) + uint64Ptr := uint64(42) + float32Ptr := float32(42.0) + float64Ptr := float64(42.0) + decimalPtr := float64(42.0) + timePtr := time.Now() + datePtr := time.Now() + dateTimePtr := time.Now() + + d = DataNull{ + DateTime: time.Now(), + NullString: sql.NullString{String: "test", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + BooleanPtr: &booleanPtr, + CharPtr: &charPtr, + TextPtr: &textPtr, + BytePtr: &bytePtr, + RunePtr: &runePtr, + IntPtr: &intPtr, + Int8Ptr: &int8Ptr, + Int16Ptr: &int16Ptr, + Int32Ptr: &int32Ptr, + Int64Ptr: &int64Ptr, + UintPtr: &uintPtr, + Uint8Ptr: &uint8Ptr, + Uint16Ptr: &uint16Ptr, + Uint32Ptr: &uint32Ptr, + Uint64Ptr: &uint64Ptr, + Float32Ptr: &float32Ptr, + Float64Ptr: &float64Ptr, + DecimalPtr: &decimalPtr, + TimePtr: &timePtr, + DatePtr: &datePtr, + DateTimePtr: &dateTimePtr, + } + + id, err = dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + d = DataNull{ID: 3} + err = dORM.Read(&d) + throwFail(t, err) + + throwFail(t, AssertIs(d.NullBool.Valid, true)) + throwFail(t, AssertIs(d.NullBool.Bool, true)) + + throwFail(t, AssertIs(d.NullString.Valid, true)) + throwFail(t, AssertIs(d.NullString.String, "test")) + + throwFail(t, AssertIs(d.NullInt64.Valid, true)) + throwFail(t, AssertIs(d.NullInt64.Int64, 42)) + + throwFail(t, AssertIs(d.NullFloat64.Valid, true)) + throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) + + throwFail(t, AssertIs(*d.BooleanPtr, booleanPtr)) + throwFail(t, AssertIs(*d.CharPtr, charPtr)) + throwFail(t, AssertIs(*d.TextPtr, textPtr)) + throwFail(t, AssertIs(*d.BytePtr, bytePtr)) + throwFail(t, AssertIs(*d.RunePtr, runePtr)) + throwFail(t, AssertIs(*d.IntPtr, intPtr)) + throwFail(t, AssertIs(*d.Int8Ptr, int8Ptr)) + throwFail(t, AssertIs(*d.Int16Ptr, int16Ptr)) + throwFail(t, AssertIs(*d.Int32Ptr, int32Ptr)) + throwFail(t, AssertIs(*d.Int64Ptr, int64Ptr)) + throwFail(t, AssertIs(*d.UintPtr, uintPtr)) + throwFail(t, AssertIs(*d.Uint8Ptr, uint8Ptr)) + throwFail(t, AssertIs(*d.Uint16Ptr, uint16Ptr)) + throwFail(t, AssertIs(*d.Uint32Ptr, uint32Ptr)) + throwFail(t, AssertIs(*d.Uint64Ptr, uint64Ptr)) + throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) + throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) + throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) + throwFail(t, AssertIs((*d.TimePtr).UTC().Format(testTime), timePtr.UTC().Format(testTime))) + throwFail(t, AssertIs((*d.DatePtr).UTC().Format(testDate), datePtr.UTC().Format(testDate))) + throwFail(t, AssertIs((*d.DateTimePtr).UTC().Format(testDateTime), dateTimePtr.UTC().Format(testDateTime))) + + // test support for pointer fields using RawSeter.QueryRows() + var dnList []*DataNull + Q := dDbBaser.TableQuote() + num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + equal := reflect.DeepEqual(*dnList[0], d) + throwFailNow(t, AssertIs(equal, true)) +} + +func TestDataCustomTypes(t *testing.T) { + d := DataCustom{} + ind := reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + if !e.IsValid() { + continue + } + e.Set(reflect.ValueOf(value).Convert(e.Type())) + } + + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + d = DataCustom{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + ind = reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + if !e.IsValid() { + continue + } + vu := e.Interface() + value = reflect.ValueOf(value).Convert(e.Type()).Interface() + throwFail(t, AssertIs(vu == value, true), value, vu) + } +} + +func TestCRUD(t *testing.T) { + profile := NewProfile() + profile.Age = 30 + profile.Money = 1234.12 + id, err := dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + user := NewUser() + user.UserName = "slene" + user.Email = "vslene@gmail.com" + user.Password = "pass" + user.Status = 3 + user.IsStaff = true + user.IsActive = true + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + u := &User{ID: user.ID} + err = dORM.Read(u) + throwFail(t, err) + + throwFail(t, AssertIs(u.UserName, "slene")) + throwFail(t, AssertIs(u.Email, "vslene@gmail.com")) + throwFail(t, AssertIs(u.Password, "pass")) + throwFail(t, AssertIs(u.Status, 3)) + throwFail(t, AssertIs(u.IsStaff, true)) + throwFail(t, AssertIs(u.IsActive, true)) + throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime)) + + user.UserName = "astaxie" + user.Profile = profile + num, err := dORM.Update(user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFailNow(t, err) + throwFail(t, AssertIs(u.UserName, "astaxie")) + throwFail(t, AssertIs(u.Profile.ID, profile.ID)) + + u = &User{UserName: "astaxie", Password: "pass"} + err = dORM.Read(u, "UserName") + throwFailNow(t, err) + throwFailNow(t, AssertIs(id, 1)) + + u.UserName = "QQ" + u.Password = "111" + num, err = dORM.Update(u, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFailNow(t, err) + throwFail(t, AssertIs(u.UserName, "QQ")) + throwFail(t, AssertIs(u.Password, "pass")) + + num, err = dORM.Delete(profile) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFail(t, err) + throwFail(t, AssertIs(true, u.Profile == nil)) + + num, err = dORM.Delete(user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: 100} + err = dORM.Read(u) + throwFail(t, AssertIs(err, ErrNoRows)) + + ub := UserBig{} + ub.Name = "name" + id, err = dORM.Insert(&ub) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + ub = UserBig{ID: 1} + err = dORM.Read(&ub) + throwFail(t, err) + throwFail(t, AssertIs(ub.Name, "name")) + + num, err = dORM.Delete(&ub, "name") + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestInsertTestData(t *testing.T) { + var users []*User + + profile := NewProfile() + profile.Age = 28 + profile.Money = 1234.12 + + id, err := dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + user := NewUser() + user.UserName = "slene" + user.Email = "vslene@gmail.com" + user.Password = "pass" + user.Status = 1 + user.IsStaff = false + user.IsActive = true + user.Profile = profile + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + profile = NewProfile() + profile.Age = 30 + profile.Money = 4321.09 + + id, err = dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + user = NewUser() + user.UserName = "astaxie" + user.Email = "astaxie@gmail.com" + user.Password = "password" + user.Status = 2 + user.IsStaff = true + user.IsActive = false + user.Profile = profile + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + user = NewUser() + user.UserName = "nobody" + user.Email = "nobody@gmail.com" + user.Password = "nobody" + user.Status = 3 + user.IsStaff = false + user.IsActive = false + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 4)) + + tags := []*Tag{ + {Name: "golang", BestPost: &Post{ID: 2}}, + {Name: "example"}, + {Name: "format"}, + {Name: "c++"}, + } + + posts := []*Post{ + {User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory result—Java programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand. +This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, all of which you should read first.`}, + {User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`}, + {User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide. +With Go we take an unusual approach and let the machine take care of most formatting issues. The gofmt program (also available as go fmt, which operates at the package level rather than source file level) reads a Go program and emits the source in a standard style of indentation and vertical alignment, retaining and if necessary reformatting comments. If you want to know how to handle some new layout situation, run gofmt; if the answer doesn't seem right, rearrange your program (or file a bug about gofmt), don't work around it.`}, + {User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code. +The program—and web server—godoc processes Go source files to extract documentation about the contents of the package. Comments that appear before top-level declarations, with no intervening newlines, are extracted along with the declaration to serve as explanatory text for the item. The nature and style of these comments determines the quality of the documentation godoc produces.`}, + } + + comments := []*Comment{ + {Post: posts[0], Content: "a comment"}, + {Post: posts[1], Content: "yes"}, + {Post: posts[1]}, + {Post: posts[1]}, + {Post: posts[2]}, + {Post: posts[2]}, + } + + for _, tag := range tags { + id, err := dORM.Insert(tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + for _, post := range posts { + id, err := dORM.Insert(post) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num := len(post.Tags) + if num > 0 { + nums, err := dORM.QueryM2M(post, "tags").Add(post.Tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(nums, num)) + } + } + + for _, comment := range comments { + id, err := dORM.Insert(comment) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + permissions := []*Permission{ + {Name: "writePosts"}, + {Name: "readComments"}, + {Name: "readPosts"}, + } + + groups := []*Group{ + { + Name: "admins", + Permissions: []*Permission{permissions[0], permissions[1], permissions[2]}, + }, + { + Name: "users", + Permissions: []*Permission{permissions[1], permissions[2]}, + }, + } + + for _, permission := range permissions { + id, err := dORM.Insert(permission) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + for _, group := range groups { + _, err := dORM.Insert(group) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num := len(group.Permissions) + if num > 0 { + nums, err := dORM.QueryM2M(group, "permissions").Add(group.Permissions) + throwFailNow(t, err) + throwFailNow(t, AssertIs(nums, num)) + } + } + +} + +func TestCustomField(t *testing.T) { + user := User{ID: 2} + err := dORM.Read(&user) + throwFailNow(t, err) + + user.Langs = append(user.Langs, "zh-CN", "en-US") + user.Extra.Name = "beego" + user.Extra.Data = "orm" + _, err = dORM.Update(&user, "Langs", "Extra") + throwFailNow(t, err) + + user = User{ID: 2} + err = dORM.Read(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(len(user.Langs), 2)) + throwFailNow(t, AssertIs(user.Langs[0], "zh-CN")) + throwFailNow(t, AssertIs(user.Langs[1], "en-US")) + + throwFailNow(t, AssertIs(user.Extra.Name, "beego")) + throwFailNow(t, AssertIs(user.Extra.Data, "orm")) +} + +func TestExpr(t *testing.T) { + user := &User{} + qs := dORM.QueryTable(user) + qs = dORM.QueryTable((*User)(nil)) + qs = dORM.QueryTable("User") + qs = dORM.QueryTable("user") + num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("created", time.Now()).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + // num, err = qs.Filter("created", time.Now().Format(format_Date)).Count() + // throwFail(t, err) + // throwFail(t, AssertIs(num, 3)) +} + +func TestOperators(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.Filter("user_name", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__exact", String("slene")).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__exact", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__iexact", "Slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__contains", "e").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + var shouldNum int + + if IsSqlite || IsTidb { + shouldNum = 2 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__contains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__icontains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("user_name__icontains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__gt", 1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__gte", 1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + num, err = qs.Filter("status__lt", Uint(3)).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__lte", Int(3)).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + num, err = qs.Filter("user_name__startswith", "s").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + if IsSqlite || IsTidb { + shouldNum = 1 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__startswith", "S").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__istartswith", "S").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__endswith", "e").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + if IsSqlite || IsTidb { + shouldNum = 2 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__endswith", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__iendswith", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("profile__isnull", true).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("status__in", 1, 2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__in", []int{1, 2}).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + n1, n2 := 1, 2 + num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("id__between", 2, 3).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("id__between", []int{2, 3}).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.FilterRaw("user_name", "= 'slene'").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.FilterRaw("status", "IN (1, 2)").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.FilterRaw("profile_id", "IN (SELECT id FROM user_profile WHERE age=30)").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestSetCond(t *testing.T) { + cond := NewCondition() + cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) + + qs := dORM.QueryTable("user") + num, err := qs.SetCond(cond1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + cond2 := cond.AndCond(cond1).OrCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + cond3 := cond.AndNotCond(cond.And("status__in", 1)) + num, err = qs.SetCond(cond3).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + cond4 := cond.And("user_name", "slene").OrNotCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond4).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + cond5 := cond.Raw("user_name", "= 'slene'").OrNotCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond5).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) +} + +func TestLimit(t *testing.T) { + var posts []*Post + qs := dORM.QueryTable("post") + num, err := qs.Limit(1).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Limit(-1).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 4)) + + num, err = qs.Limit(-1, 2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Limit(0, 2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) +} + +func TestOffset(t *testing.T) { + var posts []*Post + qs := dORM.QueryTable("post") + num, err := qs.Limit(1).Offset(2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Offset(2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) +} + +func TestOrderBy(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderBy("status").Filter("user_name", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestAll(t *testing.T) { + var users []*User + qs := dORM.QueryTable("user") + num, err := qs.OrderBy("Id").All(&users) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + + throwFail(t, AssertIs(users[0].UserName, "slene")) + throwFail(t, AssertIs(users[1].UserName, "astaxie")) + throwFail(t, AssertIs(users[2].UserName, "nobody")) + + var users2 []User + qs = dORM.QueryTable("user") + num, err = qs.OrderBy("Id").All(&users2) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + + throwFailNow(t, AssertIs(users2[0].UserName, "slene")) + throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) + + qs = dORM.QueryTable("user") + num, err = qs.OrderBy("Id").RelatedSel().All(&users2, "UserName") + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(len(users2), 3)) + throwFailNow(t, AssertIs(users2[0].UserName, "slene")) + throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) + throwFailNow(t, AssertIs(users2[0].ID, 0)) + throwFailNow(t, AssertIs(users2[1].ID, 0)) + throwFailNow(t, AssertIs(users2[2].ID, 0)) + throwFailNow(t, AssertIs(users2[0].Profile == nil, false)) + throwFailNow(t, AssertIs(users2[1].Profile == nil, false)) + throwFailNow(t, AssertIs(users2[2].Profile == nil, true)) + + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "nothing").All(&users) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + var users3 []*User + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "nothing").All(&users3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + throwFailNow(t, AssertIs(users3 == nil, false)) +} + +func TestOne(t *testing.T) { + var user User + qs := dORM.QueryTable("user") + err := qs.One(&user) + throwFail(t, err) + + user = User{} + err = qs.OrderBy("Id").Limit(1).One(&user) + throwFailNow(t, err) + throwFail(t, AssertIs(user.UserName, "slene")) + throwFail(t, AssertNot(err, ErrMultiRows)) + + user = User{} + err = qs.OrderBy("-Id").Limit(100).One(&user) + throwFailNow(t, err) + throwFail(t, AssertIs(user.UserName, "nobody")) + throwFail(t, AssertNot(err, ErrMultiRows)) + + err = qs.Filter("user_name", "nothing").One(&user) + throwFail(t, AssertIs(err, ErrNoRows)) + +} + +func TestValues(t *testing.T) { + var maps []Params + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("Id").Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[2]["Profile"], nil)) + } + + num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[0]["Profile__Age"], 28)) + throwFail(t, AssertIs(maps[2]["Profile__Age"], nil)) + } + + num, err = qs.Filter("UserName", "slene").Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestValuesList(t *testing.T) { + var list []ParamsList + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("Id").ValuesList(&list) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0][1], "slene")) + throwFail(t, AssertIs(list[2][9], nil)) + } + + num, err = qs.OrderBy("Id").ValuesList(&list, "UserName", "Profile__Age") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0][0], "slene")) + throwFail(t, AssertIs(list[0][1], 28)) + throwFail(t, AssertIs(list[2][1], nil)) + } +} + +func TestValuesFlat(t *testing.T) { + var list ParamsList + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("id").ValuesFlat(&list, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0], "slene")) + throwFail(t, AssertIs(list[1], "astaxie")) + throwFail(t, AssertIs(list[2], "nobody")) + } +} + +func TestRelatedSel(t *testing.T) { + if IsTidb { + // Skip it. TiDB does not support relation now. + return + } + qs := dORM.QueryTable("user") + num, err := qs.Filter("profile__age", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("profile__age__gt", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("profile__user__profile__age__gt", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + var user User + err = qs.Filter("user_name", "slene").RelatedSel("profile").One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertNot(user.Profile, nil)) + if user.Profile != nil { + throwFail(t, AssertIs(user.Profile.Age, 28)) + } + + err = qs.Filter("user_name", "slene").RelatedSel().One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertNot(user.Profile, nil)) + if user.Profile != nil { + throwFail(t, AssertIs(user.Profile.Age, 28)) + } + + err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(user.Profile, nil)) + + qs = dORM.QueryTable("user_profile") + num, err = qs.Filter("user__username", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + var posts []*Post + qs = dORM.QueryTable("post") + num, err = qs.RelatedSel().All(&posts) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 4)) + + throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) + throwFailNow(t, AssertIs(posts[1].User.UserName, "astaxie")) + throwFailNow(t, AssertIs(posts[2].User.UserName, "astaxie")) + throwFailNow(t, AssertIs(posts[3].User.UserName, "nobody")) +} + +func TestReverseQuery(t *testing.T) { + var profile Profile + err := dORM.QueryTable("user_profile").Filter("User", 3).One(&profile) + throwFailNow(t, err) + throwFailNow(t, AssertIs(profile.Age, 30)) + + profile = Profile{} + err = dORM.QueryTable("user_profile").Filter("User__UserName", "astaxie").One(&profile) + throwFailNow(t, err) + throwFailNow(t, AssertIs(profile.Age, 30)) + + var user User + err = dORM.QueryTable("user").Filter("Posts__Title", "Examples").One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + + user = User{} + err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").Limit(1).One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + + user = User{} + err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").RelatedSel().Limit(1).One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + + var posts []*Post + num, err := dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) + + posts = []*Post{} + num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").Filter("User__UserName", "slene").All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) + + posts = []*Post{} + num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang"). + Filter("User__UserName", "slene").RelatedSel().All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(posts[0].User == nil, false)) + throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) + + var tags []*Tag + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction").All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + tags = []*Tag{} + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). + Filter("BestPost__User__UserName", "astaxie").All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + tags = []*Tag{} + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). + Filter("BestPost__User__UserName", "astaxie").RelatedSel().All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + throwFailNow(t, AssertIs(tags[0].BestPost == nil, false)) + throwFailNow(t, AssertIs(tags[0].BestPost.Title, "Examples")) + throwFailNow(t, AssertIs(tags[0].BestPost.User == nil, false)) + throwFailNow(t, AssertIs(tags[0].BestPost.User.UserName, "astaxie")) +} + +func TestLoadRelated(t *testing.T) { + // load reverse foreign key + user := User{ID: 3} + + err := dORM.Read(&user) + throwFailNow(t, err) + + num, err := dORM.LoadRelated(&user, "Posts") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) + + num, err = dORM.LoadRelated(&user, "Posts", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&user, "Posts", true, 1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(user.Posts), 1)) + + num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) + + num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(user.Posts), 1)) + throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) + + // load reverse one to one + profile := Profile{ID: 3} + profile.BestPost = &Post{ID: 2} + num, err = dORM.Update(&profile, "BestPost") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + err = dORM.Read(&profile) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&profile, "User") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(profile.User == nil, false)) + throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&profile, "User", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(profile.User == nil, false)) + throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) + throwFailNow(t, AssertIs(profile.User.Profile.Age, profile.Age)) + + // load rel one to one + err = dORM.Read(&user) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&user, "Profile") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + + num, err = dORM.LoadRelated(&user, "Profile", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false)) + throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples")) + + post := Post{ID: 2} + + // load rel foreign key + err = dORM.Read(&post) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&post, "User") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(post.User == nil, false)) + throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&post, "User", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(post.User == nil, false)) + throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) + throwFailNow(t, AssertIs(post.User.Profile == nil, false)) + throwFailNow(t, AssertIs(post.User.Profile.Age, 30)) + + // load rel m2m + post = Post{ID: 2} + + err = dORM.Read(&post) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&post, "Tags") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(post.Tags), 2)) + throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) + + num, err = dORM.LoadRelated(&post, "Tags", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(post.Tags), 2)) + throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) + throwFailNow(t, AssertIs(post.Tags[0].BestPost == nil, false)) + throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie")) + + // load reverse m2m + tag := Tag{ID: 1} + + err = dORM.Read(&tag) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&tag, "Posts") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) + throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) + throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) + + num, err = dORM.LoadRelated(&tag, "Posts", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) + throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) + throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene")) +} + +func TestQueryM2M(t *testing.T) { + post := Post{ID: 4} + m2m := dORM.QueryM2M(&post, "Tags") + + tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}} + tag2 := &Tag{Name: "TestTag3"} + tag3 := []interface{}{&Tag{Name: "TestTag4"}} + + tags := []interface{}{tag1[0], tag1[1], tag2, tag3[0]} + + for _, tag := range tags { + _, err := dORM.Insert(tag) + throwFailNow(t, err) + } + + num, err := m2m.Add(tag1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Add(tag2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Add(tag3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 5)) + + num, err = m2m.Remove(tag3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 4)) + + exist := m2m.Exist(tag2) + throwFailNow(t, AssertIs(exist, true)) + + num, err = m2m.Remove(tag2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + exist = m2m.Exist(tag2) + throwFailNow(t, AssertIs(exist, false)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + num, err = m2m.Clear() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + tag := Tag{Name: "test"} + _, err = dORM.Insert(&tag) + throwFailNow(t, err) + + m2m = dORM.QueryM2M(&tag, "Posts") + + post1 := []*Post{{Title: "TestPost1"}, {Title: "TestPost2"}} + post2 := &Post{Title: "TestPost3"} + post3 := []interface{}{&Post{Title: "TestPost4"}} + + posts := []interface{}{post1[0], post1[1], post2, post3[0]} + + for _, post := range posts { + p := post.(*Post) + p.User = &User{ID: 1} + _, err := dORM.Insert(post) + throwFailNow(t, err) + } + + num, err = m2m.Add(post1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Add(post2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Add(post3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 4)) + + num, err = m2m.Remove(post3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + exist = m2m.Exist(post2) + throwFailNow(t, AssertIs(exist, true)) + + num, err = m2m.Remove(post2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + exist = m2m.Exist(post2) + throwFailNow(t, AssertIs(exist, false)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Clear() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + num, err = dORM.Delete(&tag) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) +} + +func TestQueryRelate(t *testing.T) { + // post := &Post{Id: 2} + + // qs := dORM.QueryRelate(post, "Tags") + // num, err := qs.Count() + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) + + // var tags []*Tag + // num, err = qs.All(&tags) + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) + // throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + // num, err = dORM.QueryTable("Tag").Filter("Posts__Post", 2).Count() + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) +} + +func TestPkManyRelated(t *testing.T) { + permission := &Permission{Name: "readPosts"} + err := dORM.Read(permission, "Name") + throwFailNow(t, err) + + var groups []*Group + qs := dORM.QueryTable("Group") + num, err := qs.Filter("Permissions__Permission", permission.ID).All(&groups) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) +} + +func TestPrepareInsert(t *testing.T) { + qs := dORM.QueryTable("user") + i, err := qs.PrepareInsert() + throwFailNow(t, err) + + var user User + user.UserName = "testing1" + num, err := i.Insert(&user) + throwFail(t, err) + throwFail(t, AssertIs(num > 0, true)) + + user.UserName = "testing2" + num, err = i.Insert(&user) + throwFail(t, err) + throwFail(t, AssertIs(num > 0, true)) + + num, err = qs.Filter("user_name__in", "testing1", "testing2").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + err = i.Close() + throwFail(t, err) + err = i.Close() + throwFail(t, AssertIs(err, ErrStmtClosed)) +} + +func TestRawExec(t *testing.T) { + Q := dDbBaser.TableQuote() + + query := fmt.Sprintf("UPDATE %suser%s SET %suser_name%s = ? WHERE %suser_name%s = ?", Q, Q, Q, Q, Q, Q) + res, err := dORM.Raw(query, "testing", "slene").Exec() + throwFail(t, err) + num, err := res.RowsAffected() + throwFail(t, AssertIs(num, 1), err) + + res, err = dORM.Raw(query, "slene", "testing").Exec() + throwFail(t, err) + num, err = res.RowsAffected() + throwFail(t, AssertIs(num, 1), err) +} + +func TestRawQueryRow(t *testing.T) { + var ( + Boolean bool + Char string + Text string + Time time.Time + Date time.Time + DateTime time.Time + Byte byte + Rune rune + Int int + Int8 int + Int16 int16 + Int32 int32 + Int64 int64 + Uint uint + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Float32 float32 + Float64 float64 + Decimal float64 + ) + + dataValues := make(map[string]interface{}, len(DataValues)) + + for k, v := range DataValues { + dataValues[strings.ToLower(k)] = v + } + + Q := dDbBaser.TableQuote() + + cols := []string{ + "id", "boolean", "char", "text", "time", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", + "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", + } + sep := fmt.Sprintf("%s, %s", Q, Q) + query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q) + var id int + values := []interface{}{ + &id, &Boolean, &Char, &Text, &Time, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32, + &Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal, + } + err := dORM.Raw(query, 1).QueryRow(values...) + throwFailNow(t, err) + for i, col := range cols { + vu := values[i] + v := reflect.ValueOf(vu).Elem().Interface() + switch col { + case "id": + throwFail(t, AssertIs(id, 1)) + case "time": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testTime)) + case "date": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testDate)) + case "datetime": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testDateTime)) + default: + throwFail(t, AssertIs(v, dataValues[col])) + } + } + + var ( + uid int + status *int + pid *int + ) + + cols = []string{ + "id", "Status", "profile_id", + } + query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q) + err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid) + throwFail(t, err) + throwFail(t, AssertIs(uid, 4)) + throwFail(t, AssertIs(*status, 3)) + throwFail(t, AssertIs(pid, nil)) + + // test for sql.Null* fields + nData := &DataNull{ + NullString: sql.NullString{String: "test sql.null", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + } + newId, err := dORM.Insert(nData) + throwFailNow(t, err) + + var nd *DataNull + query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) + err = dORM.Raw(query, newId).QueryRow(&nd) + throwFailNow(t, err) + + throwFailNow(t, AssertNot(nd, nil)) + throwFail(t, AssertIs(nd.NullBool.Valid, true)) + throwFail(t, AssertIs(nd.NullBool.Bool, true)) + throwFail(t, AssertIs(nd.NullString.Valid, true)) + throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) + throwFail(t, AssertIs(nd.NullInt64.Valid, true)) + throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) + throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) + throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) +} + +// user_profile table +type userProfile struct { + User + Age int + Money float64 +} + +func TestQueryRows(t *testing.T) { + Q := dDbBaser.TableQuote() + + var datas []*Data + + query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) + num, err := dORM.Raw(query).QueryRows(&datas) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(datas), 1)) + + ind := reflect.Indirect(reflect.ValueOf(datas[0])) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } + + var datas2 []Data + + query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) + num, err = dORM.Raw(query).QueryRows(&datas2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(datas2), 1)) + + ind = reflect.Indirect(reflect.ValueOf(datas2[0])) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } + + var ids []int + var usernames []string + query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q) + num, err = dORM.Raw(query).QueryRows(&ids, &usernames) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(len(ids), 3)) + throwFailNow(t, AssertIs(ids[0], 2)) + throwFailNow(t, AssertIs(usernames[0], "slene")) + throwFailNow(t, AssertIs(ids[1], 3)) + throwFailNow(t, AssertIs(usernames[1], "astaxie")) + throwFailNow(t, AssertIs(ids[2], 4)) + throwFailNow(t, AssertIs(usernames[2], "nobody")) + + //test query rows by nested struct + var l []userProfile + query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q) + num, err = dORM.Raw(query).QueryRows(&l) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(l), 2)) + throwFailNow(t, AssertIs(l[0].UserName, "slene")) + throwFailNow(t, AssertIs(l[0].Age, 28)) + throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(l[1].Age, 30)) + + // test for sql.Null* fields + nData := &DataNull{ + NullString: sql.NullString{String: "test sql.null", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + } + newId, err := dORM.Insert(nData) + throwFailNow(t, err) + + var nDataList []*DataNull + query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) + num, err = dORM.Raw(query, newId).QueryRows(&nDataList) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + nd := nDataList[0] + throwFailNow(t, AssertNot(nd, nil)) + throwFail(t, AssertIs(nd.NullBool.Valid, true)) + throwFail(t, AssertIs(nd.NullBool.Bool, true)) + throwFail(t, AssertIs(nd.NullString.Valid, true)) + throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) + throwFail(t, AssertIs(nd.NullInt64.Valid, true)) + throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) + throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) + throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) +} + +func TestRawValues(t *testing.T) { + Q := dDbBaser.TableQuote() + + var maps []Params + query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sStatus%s = ?", Q, Q, Q, Q, Q, Q) + num, err := dORM.Raw(query, 1).Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + if num == 1 { + throwFail(t, AssertIs(maps[0]["user_name"], "slene")) + } + + var lists []ParamsList + num, err = dORM.Raw(query, 1).ValuesList(&lists) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + if num == 1 { + throwFail(t, AssertIs(lists[0][0], "slene")) + } + + query = fmt.Sprintf("SELECT %sprofile_id%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q) + var list ParamsList + num, err = dORM.Raw(query).ValuesFlat(&list) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0], "2")) + throwFail(t, AssertIs(list[1], "3")) + throwFail(t, AssertIs(list[2], nil)) + } +} + +func TestRawPrepare(t *testing.T) { + switch { + case IsMysql || IsSqlite: + + pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() + throwFail(t, err) + if pre != nil { + r, err := pre.Exec("name1") + throwFail(t, err) + + tid, err := r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(tid > 0, true)) + + r, err = pre.Exec("name2") + throwFail(t, err) + + id, err := r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id, tid+1)) + + r, err = pre.Exec("name3") + throwFail(t, err) + + id, err = r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id, tid+2)) + + err = pre.Close() + throwFail(t, err) + + res, err := dORM.Raw("DELETE FROM tag WHERE name IN (?, ?, ?)", []string{"name1", "name2", "name3"}).Exec() + throwFail(t, err) + + num, err := res.RowsAffected() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + } + + case IsPostgres: + + pre, err := dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() + throwFail(t, err) + if pre != nil { + _, err := pre.Exec("name1") + throwFail(t, err) + + _, err = pre.Exec("name2") + throwFail(t, err) + + _, err = pre.Exec("name3") + throwFail(t, err) + + err = pre.Close() + throwFail(t, err) + + res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec() + throwFail(t, err) + + if err == nil { + num, err := res.RowsAffected() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + } + } + } +} + +func TestUpdate(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.Filter("user_name", "slene").Filter("is_staff", false).Update(Params{ + "is_staff": true, + "is_active": true, + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + // with join + num, err = qs.Filter("user_name", "slene").Filter("profile__age", 28).Filter("is_staff", true).Update(Params{ + "is_staff": false, + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColAdd, 100), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColMinus, 50), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColMultiply, 3), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColExcept, 5), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + user := User{UserName: "slene"} + err = dORM.Read(&user, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(user.Nums, 30)) +} + +func TestDelete(t *testing.T) { + qs := dORM.QueryTable("user_profile") + num, err := qs.Filter("user__user_name", "slene").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 6)) + + qs = dORM.QueryTable("post") + num, err = qs.Filter("Id", 3).Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 4)) + + qs = dORM.QueryTable("comment") + num, err = qs.Filter("Post__User", 3).Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestTransaction(t *testing.T) { + // this test worked when database support transaction + + o := NewOrm() + err := o.Begin() + throwFail(t, err) + + var names = []string{"1", "2", "3"} + + var tag Tag + tag.Name = names[0] + id, err := o.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + switch { + case IsMysql || IsSqlite: + res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() + throwFail(t, err) + if err == nil { + id, err = res.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + } + + err = o.Rollback() + throwFail(t, err) + + num, err = o.QueryTable("tag").Filter("name__in", names).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + err = o.Begin() + throwFail(t, err) + + tag.Name = "commit" + id, err = o.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + o.Commit() + throwFail(t, err) + + num, err = o.QueryTable("tag").Filter("name", "commit").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + +} + +func TestTransactionIsolationLevel(t *testing.T) { + // this test worked when database support transaction isolation level + if IsSqlite { + return + } + + o1 := NewOrm() + o2 := NewOrm() + + // start two transaction with isolation level repeatable read + err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + throwFail(t, err) + err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + throwFail(t, err) + + // o1 insert tag + var tag Tag + tag.Name = "test-transaction" + id, err := o1.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + // o2 query tag table, no result + num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + // o1 commit + o1.Commit() + + // o2 query tag table, still no result + num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + // o2 commit and query tag table, get the result + o2.Commit() + num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = o1.QueryTable("tag").Filter("name", "test-transaction").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestBeginTxWithContextCanceled(t *testing.T) { + o := NewOrm() + ctx, cancel := context.WithCancel(context.Background()) + o.BeginTx(ctx, nil) + id, err := o.Insert(&Tag{Name: "test-context"}) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + // cancel the context before commit to make it error + cancel() + err = o.Commit() + throwFail(t, AssertIs(err, context.Canceled)) +} + +func TestReadOrCreate(t *testing.T) { + u := &User{ + UserName: "Kyle", + Email: "kylemcc@gmail.com", + Password: "other_pass", + Status: 7, + IsStaff: false, + IsActive: true, + } + + created, pk, err := dORM.ReadOrCreate(u, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(created, true)) + throwFail(t, AssertIs(u.ID, pk)) + throwFail(t, AssertIs(u.UserName, "Kyle")) + throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com")) + throwFail(t, AssertIs(u.Password, "other_pass")) + throwFail(t, AssertIs(u.Status, 7)) + throwFail(t, AssertIs(u.IsStaff, false)) + throwFail(t, AssertIs(u.IsActive, true)) + throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime)) + + nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"} + created, pk, err = dORM.ReadOrCreate(nu, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(nu.ID, u.ID)) + throwFail(t, AssertIs(pk, u.ID)) + throwFail(t, AssertIs(nu.UserName, u.UserName)) + throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above + throwFail(t, AssertIs(nu.Password, u.Password)) + throwFail(t, AssertIs(nu.Status, u.Status)) + throwFail(t, AssertIs(nu.IsStaff, u.IsStaff)) + throwFail(t, AssertIs(nu.IsActive, u.IsActive)) + + dORM.Delete(u) +} + +func TestInLine(t *testing.T) { + name := "inline" + email := "hello@go.com" + inline := NewInLine() + inline.Name = name + inline.Email = email + + id, err := dORM.Insert(inline) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + il := NewInLine() + il.ID = 1 + err = dORM.Read(il) + throwFail(t, err) + + throwFail(t, AssertIs(il.Name, name)) + throwFail(t, AssertIs(il.Email, email)) + throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) +} + +func TestInLineOneToOne(t *testing.T) { + name := "121" + email := "121@go.com" + inline := NewInLine() + inline.Name = name + inline.Email = email + + id, err := dORM.Insert(inline) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + note := "one2one" + il121 := NewInLineOneToOne() + il121.Note = note + il121.InLine = inline + _, err = dORM.Insert(il121) + throwFail(t, err) + throwFail(t, AssertIs(il121.ID, 1)) + + il := NewInLineOneToOne() + err = dORM.QueryTable(il).Filter("Id", 1).RelatedSel().One(il) + + throwFail(t, err) + throwFail(t, AssertIs(il.Note, note)) + throwFail(t, AssertIs(il.InLine.ID, id)) + throwFail(t, AssertIs(il.InLine.Name, name)) + throwFail(t, AssertIs(il.InLine.Email, email)) + + rinline := NewInLine() + err = dORM.QueryTable(rinline).Filter("InLineOneToOne__Id", 1).One(rinline) + + throwFail(t, err) + throwFail(t, AssertIs(rinline.ID, id)) + throwFail(t, AssertIs(rinline.Name, name)) + throwFail(t, AssertIs(rinline.Email, email)) +} + +func TestIntegerPk(t *testing.T) { + its := []IntegerPk{ + {ID: math.MinInt64, Value: "-"}, + {ID: 0, Value: "0"}, + {ID: math.MaxInt64, Value: "+"}, + } + + num, err := dORM.InsertMulti(len(its), its) + throwFail(t, err) + throwFail(t, AssertIs(num, len(its))) + + for _, intPk := range its { + out := IntegerPk{ID: intPk.ID} + err = dORM.Read(&out) + throwFail(t, err) + throwFail(t, AssertIs(out.Value, intPk.Value)) + } + + num, err = dORM.InsertMulti(1, []*IntegerPk{{ + ID: 1, Value: "ok", + }}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestInsertAuto(t *testing.T) { + u := &User{ + UserName: "autoPre", + Email: "autoPre@gmail.com", + } + + id, err := dORM.Insert(u) + throwFail(t, err) + + id += 100 + su := &User{ + ID: int(id), + UserName: "auto", + Email: "auto@gmail.com", + } + + nid, err := dORM.Insert(su) + throwFail(t, err) + throwFail(t, AssertIs(nid, id)) + + users := []User{ + {ID: int(id + 100), UserName: "auto_100"}, + {ID: int(id + 110), UserName: "auto_110"}, + {ID: int(id + 120), UserName: "auto_120"}, + } + num, err := dORM.InsertMulti(100, users) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + u = &User{ + UserName: "auto_121", + } + + nid, err = dORM.Insert(u) + throwFail(t, err) + throwFail(t, AssertIs(nid, id+120+1)) +} + +func TestUintPk(t *testing.T) { + name := "go" + u := &UintPk{ + ID: 8, + Name: name, + } + + created, _, err := dORM.ReadOrCreate(u, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, true)) + throwFail(t, AssertIs(u.Name, name)) + + nu := &UintPk{ID: 8} + created, pk, err := dORM.ReadOrCreate(nu, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(nu.ID, u.ID)) + throwFail(t, AssertIs(pk, u.ID)) + throwFail(t, AssertIs(nu.Name, name)) + + dORM.Delete(u) +} + +func TestPtrPk(t *testing.T) { + parent := &IntegerPk{ID: 10, Value: "10"} + + id, _ := dORM.Insert(parent) + if !IsMysql { + // MySql does not support last_insert_id in this case: see #2382 + throwFail(t, AssertIs(id, 10)) + } + + ptr := PtrPk{ID: parent, Positive: true} + num, err := dORM.InsertMulti(2, []PtrPk{ptr}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(ptr.ID, parent)) + + nptr := &PtrPk{ID: parent} + created, pk, err := dORM.ReadOrCreate(nptr, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(pk, 10)) + throwFail(t, AssertIs(nptr.ID, parent)) + throwFail(t, AssertIs(nptr.Positive, true)) + + nptr = &PtrPk{Positive: true} + created, pk, err = dORM.ReadOrCreate(nptr, "Positive") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(pk, 10)) + throwFail(t, AssertIs(nptr.ID, parent)) + + nptr.Positive = false + num, err = dORM.Update(nptr) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(nptr.ID, parent)) + throwFail(t, AssertIs(nptr.Positive, false)) + + num, err = dORM.Delete(nptr) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestSnake(t *testing.T) { + cases := map[string]string{ + "i": "i", + "I": "i", + "iD": "i_d", + "ID": "i_d", + "NO": "n_o", + "NOO": "n_o_o", + "NOOooOOoo": "n_o_ooo_o_ooo", + "OrderNO": "order_n_o", + "tagName": "tag_name", + "tag_Name": "tag__name", + "tag_name": "tag_name", + "_tag_name": "_tag_name", + "tag_666name": "tag_666name", + "tag_666Name": "tag_666_name", + } + for name, want := range cases { + got := snakeString(name) + throwFail(t, AssertIs(got, want)) + } +} + +func TestIgnoreCaseTag(t *testing.T) { + type testTagModel struct { + ID int `orm:"pk"` + NOO string `orm:"column(n)"` + Name01 string `orm:"NULL"` + Name02 string `orm:"COLUMN(Name)"` + Name03 string `orm:"Column(name)"` + } + modelCache.clean() + RegisterModel(&testTagModel{}) + info, ok := modelCache.get("test_tag_model") + throwFail(t, AssertIs(ok, true)) + throwFail(t, AssertNot(info, nil)) + if t == nil { + return + } + throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n")) + throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true)) + throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) + throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) +} + +func TestInsertOrUpdate(t *testing.T) { + RegisterModel(new(User)) + user := User{UserName: "unique_username133", Status: 1, Password: "o"} + user1 := User{UserName: "unique_username133", Status: 2, Password: "o"} + user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"} + dORM.Insert(&user) + test := User{UserName: "unique_username133"} + fmt.Println(dORM.Driver().Name()) + if dORM.Driver().Name() == "sqlite3" { + fmt.Println("sqlite3 is nonsupport") + return + } + //test1 + _, err := dORM.InsertOrUpdate(&user1, "user_name") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user1.Status, test.Status)) + } + //test2 + _, err = dORM.InsertOrUpdate(&user2, "user_name") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user2.Status, test.Status)) + throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) + } + + //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + if IsPostgres { + return + } + //test3 + + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user2.Status+1, test.Status)) + } + //test4 - + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status)) + } + //test5 * + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status)) + } + //test6 / + _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) + } +} diff --git a/pkg/orm/qb.go b/pkg/orm/qb.go new file mode 100644 index 00000000..e0655a17 --- /dev/null +++ b/pkg/orm/qb.go @@ -0,0 +1,62 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import "errors" + +// QueryBuilder is the Query builder interface +type QueryBuilder interface { + Select(fields ...string) QueryBuilder + ForUpdate() QueryBuilder + From(tables ...string) QueryBuilder + InnerJoin(table string) QueryBuilder + LeftJoin(table string) QueryBuilder + RightJoin(table string) QueryBuilder + On(cond string) QueryBuilder + Where(cond string) QueryBuilder + And(cond string) QueryBuilder + Or(cond string) QueryBuilder + In(vals ...string) QueryBuilder + OrderBy(fields ...string) QueryBuilder + Asc() QueryBuilder + Desc() QueryBuilder + Limit(limit int) QueryBuilder + Offset(offset int) QueryBuilder + GroupBy(fields ...string) QueryBuilder + Having(cond string) QueryBuilder + Update(tables ...string) QueryBuilder + Set(kv ...string) QueryBuilder + Delete(tables ...string) QueryBuilder + InsertInto(table string, fields ...string) QueryBuilder + Values(vals ...string) QueryBuilder + Subquery(sub string, alias string) string + String() string +} + +// NewQueryBuilder return the QueryBuilder +func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { + if driver == "mysql" { + qb = new(MySQLQueryBuilder) + } else if driver == "tidb" { + qb = new(TiDBQueryBuilder) + } else if driver == "postgres" { + err = errors.New("postgres query builder is not supported yet") + } else if driver == "sqlite" { + err = errors.New("sqlite query builder is not supported yet") + } else { + err = errors.New("unknown driver for query builder") + } + return +} diff --git a/pkg/orm/qb_mysql.go b/pkg/orm/qb_mysql.go new file mode 100644 index 00000000..23bdc9ee --- /dev/null +++ b/pkg/orm/qb_mysql.go @@ -0,0 +1,185 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strconv" + "strings" +) + +// CommaSpace is the separation +const CommaSpace = ", " + +// MySQLQueryBuilder is the SQL build +type MySQLQueryBuilder struct { + Tokens []string +} + +// Select will join the fields +func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) + return qb +} + +// ForUpdate add the FOR UPDATE clause +func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder { + qb.Tokens = append(qb.Tokens, "FOR UPDATE") + return qb +} + +// From join the tables +func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) + return qb +} + +// InnerJoin INNER JOIN the table +func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INNER JOIN", table) + return qb +} + +// LeftJoin LEFT JOIN the table +func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) + return qb +} + +// RightJoin RIGHT JOIN the table +func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) + return qb +} + +// On join with on cond +func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ON", cond) + return qb +} + +// Where join the Where cond +func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "WHERE", cond) + return qb +} + +// And join the and cond +func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "AND", cond) + return qb +} + +// Or join the or cond +func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OR", cond) + return qb +} + +// In join the IN (vals) +func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") + return qb +} + +// OrderBy join the Order by fields +func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) + return qb +} + +// Asc join the asc +func (qb *MySQLQueryBuilder) Asc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "ASC") + return qb +} + +// Desc join the desc +func (qb *MySQLQueryBuilder) Desc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "DESC") + return qb +} + +// Limit join the limit num +func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) + return qb +} + +// Offset join the offset num +func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) + return qb +} + +// GroupBy join the Group by fields +func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) + return qb +} + +// Having join the Having cond +func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "HAVING", cond) + return qb +} + +// Update join the update table +func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) + return qb +} + +// Set join the set kv +func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) + return qb +} + +// Delete join the Delete tables +func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "DELETE") + if len(tables) != 0 { + qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) + } + return qb +} + +// InsertInto join the insert SQL +func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INSERT INTO", table) + if len(fields) != 0 { + fieldsStr := strings.Join(fields, CommaSpace) + qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") + } + return qb +} + +// Values join the Values(vals) +func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { + valsStr := strings.Join(vals, CommaSpace) + qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") + return qb +} + +// Subquery join the sub as alias +func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { + return fmt.Sprintf("(%s) AS %s", sub, alias) +} + +// String join all Tokens +func (qb *MySQLQueryBuilder) String() string { + return strings.Join(qb.Tokens, " ") +} diff --git a/pkg/orm/qb_tidb.go b/pkg/orm/qb_tidb.go new file mode 100644 index 00000000..87b3ae84 --- /dev/null +++ b/pkg/orm/qb_tidb.go @@ -0,0 +1,182 @@ +// Copyright 2015 TiDB Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strconv" + "strings" +) + +// TiDBQueryBuilder is the SQL build +type TiDBQueryBuilder struct { + Tokens []string +} + +// Select will join the fields +func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) + return qb +} + +// ForUpdate add the FOR UPDATE clause +func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder { + qb.Tokens = append(qb.Tokens, "FOR UPDATE") + return qb +} + +// From join the tables +func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) + return qb +} + +// InnerJoin INNER JOIN the table +func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INNER JOIN", table) + return qb +} + +// LeftJoin LEFT JOIN the table +func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) + return qb +} + +// RightJoin RIGHT JOIN the table +func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) + return qb +} + +// On join with on cond +func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ON", cond) + return qb +} + +// Where join the Where cond +func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "WHERE", cond) + return qb +} + +// And join the and cond +func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "AND", cond) + return qb +} + +// Or join the or cond +func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OR", cond) + return qb +} + +// In join the IN (vals) +func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") + return qb +} + +// OrderBy join the Order by fields +func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) + return qb +} + +// Asc join the asc +func (qb *TiDBQueryBuilder) Asc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "ASC") + return qb +} + +// Desc join the desc +func (qb *TiDBQueryBuilder) Desc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "DESC") + return qb +} + +// Limit join the limit num +func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) + return qb +} + +// Offset join the offset num +func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) + return qb +} + +// GroupBy join the Group by fields +func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) + return qb +} + +// Having join the Having cond +func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "HAVING", cond) + return qb +} + +// Update join the update table +func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) + return qb +} + +// Set join the set kv +func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) + return qb +} + +// Delete join the Delete tables +func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "DELETE") + if len(tables) != 0 { + qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) + } + return qb +} + +// InsertInto join the insert SQL +func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INSERT INTO", table) + if len(fields) != 0 { + fieldsStr := strings.Join(fields, CommaSpace) + qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") + } + return qb +} + +// Values join the Values(vals) +func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder { + valsStr := strings.Join(vals, CommaSpace) + qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") + return qb +} + +// Subquery join the sub as alias +func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string { + return fmt.Sprintf("(%s) AS %s", sub, alias) +} + +// String join all Tokens +func (qb *TiDBQueryBuilder) String() string { + return strings.Join(qb.Tokens, " ") +} diff --git a/pkg/orm/types.go b/pkg/orm/types.go new file mode 100644 index 00000000..2fd10774 --- /dev/null +++ b/pkg/orm/types.go @@ -0,0 +1,473 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "reflect" + "time" +) + +// Driver define database driver +type Driver interface { + Name() string + Type() DriverType +} + +// Fielder define field info +type Fielder interface { + String() string + FieldType() int + SetRaw(interface{}) error + RawValue() interface{} +} + +// Ormer define the orm interface +type Ormer interface { + // read data to model + // for example: + // this will find User by Id field + // u = &User{Id: user.Id} + // err = Ormer.Read(u) + // this will find User by UserName field + // u = &User{UserName: "astaxie", Password: "pass"} + // err = Ormer.Read(u, "UserName") + Read(md interface{}, cols ...string) error + // Like Read(), but with "FOR UPDATE" clause, useful in transaction. + // Some databases are not support this feature. + ReadForUpdate(md interface{}, cols ...string) error + // Try to read a row from the database, or insert one if it doesn't exist + ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) + // insert model data to database + // for example: + // user := new(User) + // id, err = Ormer.Insert(user) + // user must be a pointer and Insert will set user's pk field + Insert(interface{}) (int64, error) + // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") + // if colu type is integer : can use(+-*/), string : convert(colu,"value") + // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") + // if colu type is integer : can use(+-*/), string : colu || "value" + InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) + // insert some models to database + InsertMulti(bulk int, mds interface{}) (int64, error) + // update model to database. + // cols set the columns those want to update. + // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns + // for example: + // user := User{Id: 2} + // user.Langs = append(user.Langs, "zh-CN", "en-US") + // user.Extra.Name = "beego" + // user.Extra.Data = "orm" + // num, err = Ormer.Update(&user, "Langs", "Extra") + Update(md interface{}, cols ...string) (int64, error) + // delete model in database + Delete(md interface{}, cols ...string) (int64, error) + // load related models to md model. + // args are limit, offset int and order string. + // + // example: + // Ormer.LoadRelated(post,"Tags") + // for _,tag := range post.Tags{...} + //args[0] bool true useDefaultRelsDepth ; false depth 0 + //args[0] int loadRelationDepth + //args[1] int limit default limit 1000 + //args[2] int offset default offset 0 + //args[3] string order for example : "-Id" + // make sure the relation is defined in model struct tags. + LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) + // create a models to models queryer + // for example: + // post := Post{Id: 4} + // m2m := Ormer.QueryM2M(&post, "Tags") + QueryM2M(md interface{}, name string) QueryM2Mer + // return a QuerySeter for table operations. + // table name can be string or struct. + // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), + QueryTable(ptrStructOrTableName interface{}) QuerySeter + // switch to another registered database driver by given name. + Using(name string) error + // begin transaction + // for example: + // o := NewOrm() + // err := o.Begin() + // ... + // err = o.Rollback() + Begin() error + // begin transaction with provided context and option + // the provided context is used until the transaction is committed or rolled back. + // if the context is canceled, the transaction will be rolled back. + // the provided TxOptions is optional and may be nil if defaults should be used. + // if a non-default isolation level is used that the driver doesn't support, an error will be returned. + // for example: + // o := NewOrm() + // err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + // ... + // err = o.Rollback() + BeginTx(ctx context.Context, opts *sql.TxOptions) error + // commit transaction + Commit() error + // rollback transaction + Rollback() error + // return a raw query seter for raw sql string. + // for example: + // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() + // // update user testing's name to slene + Raw(query string, args ...interface{}) RawSeter + Driver() Driver + DBStats() *sql.DBStats +} + +// Inserter insert prepared statement +type Inserter interface { + Insert(interface{}) (int64, error) + Close() error +} + +// QuerySeter query seter +type QuerySeter interface { + // add condition expression to QuerySeter. + // for example: + // filter by UserName == 'slene' + // qs.Filter("UserName", "slene") + // sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28 + // Filter("profile__Age", 28) + // // time compare + // qs.Filter("created", time.Now()) + Filter(string, ...interface{}) QuerySeter + // add raw sql to querySeter. + // for example: + // qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)") + // //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18) + FilterRaw(string, string) QuerySeter + // add NOT condition to querySeter. + // have the same usage as Filter + Exclude(string, ...interface{}) QuerySeter + // set condition to QuerySeter. + // sql's where condition + // cond := orm.NewCondition() + // cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) + // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 + // num, err := qs.SetCond(cond1).Count() + SetCond(*Condition) QuerySeter + // get condition from QuerySeter. + // sql's where condition + // cond := orm.NewCondition() + // cond = cond.And("profile__isnull", false).AndNot("status__in", 1) + // qs = qs.SetCond(cond) + // cond = qs.GetCond() + // cond := cond.Or("profile__age__gt", 2000) + // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 + // num, err := qs.SetCond(cond).Count() + GetCond() *Condition + // add LIMIT value. + // args[0] means offset, e.g. LIMIT num,offset. + // if Limit <= 0 then Limit will be set to default limit ,eg 1000 + // if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000 + // for example: + // qs.Limit(10, 2) + // // sql-> limit 10 offset 2 + Limit(limit interface{}, args ...interface{}) QuerySeter + // add OFFSET value + // same as Limit function's args[0] + Offset(offset interface{}) QuerySeter + // add GROUP BY expression + // for example: + // qs.GroupBy("id") + GroupBy(exprs ...string) QuerySeter + // add ORDER expression. + // "column" means ASC, "-column" means DESC. + // for example: + // qs.OrderBy("-status") + OrderBy(exprs ...string) QuerySeter + // set relation model to query together. + // it will query relation models and assign to parent model. + // for example: + // // will load all related fields use left join . + // qs.RelatedSel().One(&user) + // // will load related field only profile + // qs.RelatedSel("profile").One(&user) + // user.Profile.Age = 32 + RelatedSel(params ...interface{}) QuerySeter + // Set Distinct + // for example: + // o.QueryTable("policy").Filter("Groups__Group__Users__User", user). + // Distinct(). + // All(&permissions) + Distinct() QuerySeter + // set FOR UPDATE to query. + // for example: + // o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users) + ForUpdate() QuerySeter + // return QuerySeter execution result number + // for example: + // num, err = qs.Filter("profile__age__gt", 28).Count() + Count() (int64, error) + // check result empty or not after QuerySeter executed + // the same as QuerySeter.Count > 0 + Exist() bool + // execute update with parameters + // for example: + // num, err = qs.Filter("user_name", "slene").Update(Params{ + // "Nums": ColValue(Col_Minus, 50), + // }) // user slene's Nums will minus 50 + // num, err = qs.Filter("UserName", "slene").Update(Params{ + // "user_name": "slene2" + // }) // user slene's name will change to slene2 + Update(values Params) (int64, error) + // delete from table + //for example: + // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() + // //delete two user who's name is testing1 or testing2 + Delete() (int64, error) + // return a insert queryer. + // it can be used in times. + // example: + // i,err := sq.PrepareInsert() + // num, err = i.Insert(&user1) // user table will add one record user1 at once + // num, err = i.Insert(&user2) // user table will add one record user2 at once + // err = i.Close() //don't forget call Close + PrepareInsert() (Inserter, error) + // query all data and map to containers. + // cols means the columns when querying. + // for example: + // var users []*User + // qs.All(&users) // users[0],users[1],users[2] ... + All(container interface{}, cols ...string) (int64, error) + // query one row data and map to containers. + // cols means the columns when querying. + // for example: + // var user User + // qs.One(&user) //user.UserName == "slene" + One(container interface{}, cols ...string) error + // query all data and map to []map[string]interface. + // expres means condition expression. + // it converts data to []map[column]value. + // for example: + // var maps []Params + // qs.Values(&maps) //maps[0]["UserName"]=="slene" + Values(results *[]Params, exprs ...string) (int64, error) + // query all data and map to [][]interface + // it converts data to [][column_index]value + // for example: + // var list []ParamsList + // qs.ValuesList(&list) // list[0][1] == "slene" + ValuesList(results *[]ParamsList, exprs ...string) (int64, error) + // query all data and map to []interface. + // it's designed for one column record set, auto change to []value, not [][column]value. + // for example: + // var list ParamsList + // qs.ValuesFlat(&list, "UserName") // list[0] == "slene" + ValuesFlat(result *ParamsList, expr string) (int64, error) + // query all rows into map[string]interface with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to map[string]interface{}{ + // "total": 100, + // "found": 200, + // } + RowsToMap(result *Params, keyCol, valueCol string) (int64, error) + // query all rows into struct with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to struct { + // Total int + // Found int + // } + RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) +} + +// QueryM2Mer model to model query struct +// all operations are on the m2m table only, will not affect the origin model table +type QueryM2Mer interface { + // add models to origin models when creating queryM2M. + // example: + // m2m := orm.QueryM2M(post,"Tag") + // m2m.Add(&Tag1{},&Tag2{}) + // for _,tag := range post.Tags{}{ ... } + // param could also be any of the follow + // []*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}} + // &Tag{Id:5,Name: "TestTag3"} + // []interface{}{&Tag{Id:6,Name: "TestTag4"}} + // insert one or more rows to m2m table + // make sure the relation is defined in post model struct tag. + Add(...interface{}) (int64, error) + // remove models following the origin model relationship + // only delete rows from m2m table + // for example: + //tag3 := &Tag{Id:5,Name: "TestTag3"} + //num, err = m2m.Remove(tag3) + Remove(...interface{}) (int64, error) + // check model is existed in relationship of origin model + Exist(interface{}) bool + // clean all models in related of origin model + Clear() (int64, error) + // count all related models of origin model + Count() (int64, error) +} + +// RawPreparer raw query statement +type RawPreparer interface { + Exec(...interface{}) (sql.Result, error) + Close() error +} + +// RawSeter raw query seter +// create From Ormer.Raw +// for example: +// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) +// rs := Ormer.Raw(sql, 1) +type RawSeter interface { + //execute sql and get result + Exec() (sql.Result, error) + //query data and map to container + //for example: + // var name string + // var id int + // rs.QueryRow(&id,&name) // id==2 name=="slene" + QueryRow(containers ...interface{}) error + + // query data rows and map to container + // var ids []int + // var names []int + // query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q) + // num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"} + QueryRows(containers ...interface{}) (int64, error) + SetArgs(...interface{}) RawSeter + // query data to []map[string]interface + // see QuerySeter's Values + Values(container *[]Params, cols ...string) (int64, error) + // query data to [][]interface + // see QuerySeter's ValuesList + ValuesList(container *[]ParamsList, cols ...string) (int64, error) + // query data to []interface + // see QuerySeter's ValuesFlat + ValuesFlat(container *ParamsList, cols ...string) (int64, error) + // query all rows into map[string]interface with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to map[string]interface{}{ + // "total": 100, + // "found": 200, + // } + RowsToMap(result *Params, keyCol, valueCol string) (int64, error) + // query all rows into struct with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to struct { + // Total int + // Found int + // } + RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) + + // return prepared raw statement for used in times. + // for example: + // pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() + // r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`) + Prepare() (RawPreparer, error) +} + +// stmtQuerier statement querier +type stmtQuerier interface { + Close() error + Exec(args ...interface{}) (sql.Result, error) + //ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) + Query(args ...interface{}) (*sql.Rows, error) + //QueryContext(args ...interface{}) (*sql.Rows, error) + QueryRow(args ...interface{}) *sql.Row + //QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row +} + +// db querier +type dbQuerier interface { + Prepare(query string) (*sql.Stmt, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + Exec(query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +// type DB interface { +// Begin() (*sql.Tx, error) +// Prepare(query string) (stmtQuerier, error) +// Exec(query string, args ...interface{}) (sql.Result, error) +// Query(query string, args ...interface{}) (*sql.Rows, error) +// QueryRow(query string, args ...interface{}) *sql.Row +// } + +// transaction beginner +type txer interface { + Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +// transaction ending +type txEnder interface { + Commit() error + Rollback() error +} + +// base database struct +type dbBaser interface { + Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error + Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) + InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) + InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) + InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) + SupportUpdateJoin() bool + UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) + DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + OperatorSQL(string) string + GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) + GenerateOperatorLeftCol(*fieldInfo, string, *string) + PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) + ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) + RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) + MaxLimit() uint64 + TableQuote() string + ReplaceMarks(*string) + HasReturningID(*modelInfo, *string) bool + TimeFromDB(*time.Time, *time.Location) + TimeToDB(*time.Time, *time.Location) + DbTypes() map[string]string + GetTables(dbQuerier) (map[string]bool, error) + GetColumns(dbQuerier, string) (map[string][3]string, error) + ShowTablesQuery() string + ShowColumnsQuery(string) string + IndexExists(dbQuerier, string, string) bool + collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) + setval(dbQuerier, *modelInfo, []string) error +} diff --git a/pkg/orm/utils.go b/pkg/orm/utils.go new file mode 100644 index 00000000..3ff76772 --- /dev/null +++ b/pkg/orm/utils.go @@ -0,0 +1,319 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "math/big" + "reflect" + "strconv" + "strings" + "time" +) + +type fn func(string) string + +var ( + nameStrategyMap = map[string]fn{ + defaultNameStrategy: snakeString, + SnakeAcronymNameStrategy: snakeStringWithAcronym, + } + defaultNameStrategy = "snakeString" + SnakeAcronymNameStrategy = "snakeStringWithAcronym" + nameStrategy = defaultNameStrategy +) + +// StrTo is the target string +type StrTo string + +// Set string +func (f *StrTo) Set(v string) { + if v != "" { + *f = StrTo(v) + } else { + f.Clear() + } +} + +// Clear string +func (f *StrTo) Clear() { + *f = StrTo(0x1E) +} + +// Exist check string exist +func (f StrTo) Exist() bool { + return string(f) != string(0x1E) +} + +// Bool string to bool +func (f StrTo) Bool() (bool, error) { + return strconv.ParseBool(f.String()) +} + +// Float32 string to float32 +func (f StrTo) Float32() (float32, error) { + v, err := strconv.ParseFloat(f.String(), 32) + return float32(v), err +} + +// Float64 string to float64 +func (f StrTo) Float64() (float64, error) { + return strconv.ParseFloat(f.String(), 64) +} + +// Int string to int +func (f StrTo) Int() (int, error) { + v, err := strconv.ParseInt(f.String(), 10, 32) + return int(v), err +} + +// Int8 string to int8 +func (f StrTo) Int8() (int8, error) { + v, err := strconv.ParseInt(f.String(), 10, 8) + return int8(v), err +} + +// Int16 string to int16 +func (f StrTo) Int16() (int16, error) { + v, err := strconv.ParseInt(f.String(), 10, 16) + return int16(v), err +} + +// Int32 string to int32 +func (f StrTo) Int32() (int32, error) { + v, err := strconv.ParseInt(f.String(), 10, 32) + return int32(v), err +} + +// Int64 string to int64 +func (f StrTo) Int64() (int64, error) { + v, err := strconv.ParseInt(f.String(), 10, 64) + if err != nil { + i := new(big.Int) + ni, ok := i.SetString(f.String(), 10) // octal + if !ok { + return v, err + } + return ni.Int64(), nil + } + return v, err +} + +// Uint string to uint +func (f StrTo) Uint() (uint, error) { + v, err := strconv.ParseUint(f.String(), 10, 32) + return uint(v), err +} + +// Uint8 string to uint8 +func (f StrTo) Uint8() (uint8, error) { + v, err := strconv.ParseUint(f.String(), 10, 8) + return uint8(v), err +} + +// Uint16 string to uint16 +func (f StrTo) Uint16() (uint16, error) { + v, err := strconv.ParseUint(f.String(), 10, 16) + return uint16(v), err +} + +// Uint32 string to uint32 +func (f StrTo) Uint32() (uint32, error) { + v, err := strconv.ParseUint(f.String(), 10, 32) + return uint32(v), err +} + +// Uint64 string to uint64 +func (f StrTo) Uint64() (uint64, error) { + v, err := strconv.ParseUint(f.String(), 10, 64) + if err != nil { + i := new(big.Int) + ni, ok := i.SetString(f.String(), 10) + if !ok { + return v, err + } + return ni.Uint64(), nil + } + return v, err +} + +// String string to string +func (f StrTo) String() string { + if f.Exist() { + return string(f) + } + return "" +} + +// ToStr interface to string +func ToStr(value interface{}, args ...int) (s string) { + switch v := value.(type) { + case bool: + s = strconv.FormatBool(v) + case float32: + s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) + case float64: + s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) + case int: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int8: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int16: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int32: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int64: + s = strconv.FormatInt(v, argInt(args).Get(0, 10)) + case uint: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint8: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint16: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint32: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint64: + s = strconv.FormatUint(v, argInt(args).Get(0, 10)) + case string: + s = v + case []byte: + s = string(v) + default: + s = fmt.Sprintf("%v", v) + } + return s +} + +// ToInt64 interface to int64 +func ToInt64(value interface{}) (d int64) { + val := reflect.ValueOf(value) + switch value.(type) { + case int, int8, int16, int32, int64: + d = val.Int() + case uint, uint8, uint16, uint32, uint64: + d = int64(val.Uint()) + default: + panic(fmt.Errorf("ToInt64 need numeric not `%T`", value)) + } + return +} + +func snakeStringWithAcronym(s string) string { + data := make([]byte, 0, len(s)*2) + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + before := false + after := false + if i > 0 { + before = s[i-1] >= 'a' && s[i-1] <= 'z' + } + if i+1 < num { + after = s[i+1] >= 'a' && s[i+1] <= 'z' + } + if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { + data = append(data, '_') + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + +// snake string, XxYy to xx_yy , XxYY to xx_y_y +func snakeString(s string) string { + data := make([]byte, 0, len(s)*2) + j := false + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + if i > 0 && d >= 'A' && d <= 'Z' && j { + data = append(data, '_') + } + if d != '_' { + j = true + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + +// SetNameStrategy set different name strategy +func SetNameStrategy(s string) { + if SnakeAcronymNameStrategy != s { + nameStrategy = defaultNameStrategy + } + nameStrategy = s +} + +// camel string, xx_yy to XxYy +func camelString(s string) string { + data := make([]byte, 0, len(s)) + flag, num := true, len(s)-1 + for i := 0; i <= num; i++ { + d := s[i] + if d == '_' { + flag = true + continue + } else if flag { + if d >= 'a' && d <= 'z' { + d = d - 32 + } + flag = false + } + data = append(data, d) + } + return string(data[:]) +} + +type argString []string + +// get string by index from string slice +func (a argString) Get(i int, args ...string) (r string) { + if i >= 0 && i < len(a) { + r = a[i] + } else if len(args) > 0 { + r = args[0] + } + return +} + +type argInt []int + +// get int by index from int slice +func (a argInt) Get(i int, args ...int) (r int) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +// parse time to string with location +func timeParse(dateString, format string) (time.Time, error) { + tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) + return tp, err +} + +// get pointer indirect type +func indirectType(v reflect.Type) reflect.Type { + switch v.Kind() { + case reflect.Ptr: + return indirectType(v.Elem()) + default: + return v + } +} diff --git a/pkg/orm/utils_test.go b/pkg/orm/utils_test.go new file mode 100644 index 00000000..7d94cada --- /dev/null +++ b/pkg/orm/utils_test.go @@ -0,0 +1,70 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" +) + +func TestCamelString(t *testing.T) { + snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} + + answer := make(map[string]string) + for i, v := range snake { + answer[v] = camel[i] + } + + for _, v := range snake { + res := camelString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeString(t *testing.T) { + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeStringWithAcronym(t *testing.T) { + camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeStringWithAcronym(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} From b50fb44950a2687aa55652d0638d74a0e2967a6e Mon Sep 17 00:00:00 2001 From: playHing Date: Sat, 11 Jul 2020 02:00:48 +0800 Subject: [PATCH 025/207] Add bench test on context input query --- context/input_test.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/context/input_test.go b/context/input_test.go index db812a0f..3a6c2e7b 100644 --- a/context/input_test.go +++ b/context/input_test.go @@ -205,3 +205,13 @@ func TestParams(t *testing.T) { } } +func BenchmarkQuery(b *testing.B) { + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Request, _ = http.NewRequest("POST", "http://www.example.com/?q=foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + beegoInput.Query("q") + } + }) +} From 55e6298f290345f900bff9443cb8fb9a09e78965 Mon Sep 17 00:00:00 2001 From: playHing Date: Sat, 11 Jul 2020 02:06:09 +0800 Subject: [PATCH 026/207] Fix concurrent form parsing and getting --- context/input.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/context/input.go b/context/input.go index 7b522c36..46a18f03 100644 --- a/context/input.go +++ b/context/input.go @@ -333,7 +333,11 @@ func (input *BeegoInput) Query(key string) string { return val } if input.Context.Request.Form == nil { - input.Context.Request.ParseForm() + input.dataLock.Lock() + defer input.dataLock.Unlock() + if input.Context.Request.Form == nil { + input.Context.Request.ParseForm() + } } return input.Context.Request.Form.Get(key) } From 3e2c795410da9e912f9d4b46bf862a212bec27ca Mon Sep 17 00:00:00 2001 From: playHing Date: Mon, 13 Jul 2020 23:11:23 +0800 Subject: [PATCH 027/207] Rlock for form query --- context/input.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/context/input.go b/context/input.go index 46a18f03..385549c1 100644 --- a/context/input.go +++ b/context/input.go @@ -334,11 +334,13 @@ func (input *BeegoInput) Query(key string) string { } if input.Context.Request.Form == nil { input.dataLock.Lock() - defer input.dataLock.Unlock() if input.Context.Request.Form == nil { input.Context.Request.ParseForm() } + input.dataLock.Unlock() } + input.dataLock.RLock() + defer input.dataLock.RUnlock() return input.Context.Request.Form.Get(key) } From 192a278a2a7cd2d891d3ee25da1559702b1ec180 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 19 Jul 2020 12:56:58 +0000 Subject: [PATCH 028/207] Fix orm test when using Driver = mysql --- .travis.yml | 4 +- cache/redis/redis_test.go | 6 +- go.mod | 2 +- orm/cmd_utils.go | 10 +- orm/models_test.go | 497 -------- orm/orm_test.go | 2494 ------------------------------------- orm/utils_test.go | 70 -- pkg/orm/models_test.go | 20 +- pkg/orm/orm.go | 2 +- pkg/orm/orm_test.go | 32 +- 10 files changed, 43 insertions(+), 3094 deletions(-) delete mode 100644 orm/models_test.go delete mode 100644 orm/orm_test.go delete mode 100644 orm/utils_test.go diff --git a/.travis.yml b/.travis.yml index c019c999..26c3732e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ language: go go: - - "1.13.x" + - "1.14.x" services: - redis-server - mysql @@ -63,7 +63,7 @@ after_script: - killall -w ssdb-server - rm -rf ./res/var/* script: - - go test -v ./... + - go test ./... - staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" - unconvert $(go list ./... | grep -v /vendor/) - ineffassign . diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go index 7ac88f87..60a19180 100644 --- a/cache/redis/redis_test.go +++ b/cache/redis/redis_test.go @@ -21,6 +21,7 @@ import ( "github.com/astaxie/beego/cache" "github.com/gomodule/redigo/redis" + "github.com/stretchr/testify/assert" ) func TestRedisCache(t *testing.T) { @@ -124,9 +125,8 @@ func TestCache_Scan(t *testing.T) { if err != nil { t.Error("scan Error", err) } - if len(keys) != 10000 { - t.Error("scan all err") - } + + assert.Equal(t, 10000, len(keys), "scan all error") // clear all if err = bm.ClearAll(); err != nil { diff --git a/go.mod b/go.mod index ec500f51..adca28ad 100644 --- a/go.mod +++ b/go.mod @@ -37,4 +37,4 @@ replace golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 => github.com/gol replace gopkg.in/yaml.v2 v2.2.1 => github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d -go 1.13 +go 1.14 diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index 61f17346..eac85091 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -178,9 +178,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex column += " " + "NOT NULL" } - //if fi.initial.String() != "" { + // if fi.initial.String() != "" { // column += " DEFAULT " + fi.initial.String() - //} + // } // Append attribute DEFAULT column += getColumnDefault(fi) @@ -197,9 +197,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex if strings.Contains(column, "%COL%") { column = strings.Replace(column, "%COL%", fi.column, -1) } - - if fi.description != "" && al.Driver!=DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) + + if fi.description != "" && al.Driver != DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) } columns = append(columns, column) diff --git a/orm/models_test.go b/orm/models_test.go deleted file mode 100644 index e3a635f2..00000000 --- a/orm/models_test.go +++ /dev/null @@ -1,497 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "database/sql" - "encoding/json" - "fmt" - "os" - "strings" - "time" - - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - // As tidb can't use go get, so disable the tidb testing now - // _ "github.com/pingcap/tidb" -) - -// A slice string field. -type SliceStringField []string - -func (e SliceStringField) Value() []string { - return []string(e) -} - -func (e *SliceStringField) Set(d []string) { - *e = SliceStringField(d) -} - -func (e *SliceStringField) Add(v string) { - *e = append(*e, v) -} - -func (e *SliceStringField) String() string { - return strings.Join(e.Value(), ",") -} - -func (e *SliceStringField) FieldType() int { - return TypeVarCharField -} - -func (e *SliceStringField) SetRaw(value interface{}) error { - switch d := value.(type) { - case []string: - e.Set(d) - case string: - if len(d) > 0 { - parts := strings.Split(d, ",") - v := make([]string, 0, len(parts)) - for _, p := range parts { - v = append(v, strings.TrimSpace(p)) - } - e.Set(v) - } - default: - return fmt.Errorf(" unknown value `%v`", value) - } - return nil -} - -func (e *SliceStringField) RawValue() interface{} { - return e.String() -} - -var _ Fielder = new(SliceStringField) - -// A json field. -type JSONFieldTest struct { - Name string - Data string -} - -func (e *JSONFieldTest) String() string { - data, _ := json.Marshal(e) - return string(data) -} - -func (e *JSONFieldTest) FieldType() int { - return TypeTextField -} - -func (e *JSONFieldTest) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - return json.Unmarshal([]byte(d), e) - default: - return fmt.Errorf(" unknown value `%v`", value) - } -} - -func (e *JSONFieldTest) RawValue() interface{} { - return e.String() -} - -var _ Fielder = new(JSONFieldTest) - -type Data struct { - ID int `orm:"column(id)"` - Boolean bool - Char string `orm:"size(50)"` - Text string `orm:"type(text)"` - JSON string `orm:"type(json);default({\"name\":\"json\"})"` - Jsonb string `orm:"type(jsonb)"` - Time time.Time `orm:"type(time)"` - Date time.Time `orm:"type(date)"` - DateTime time.Time `orm:"column(datetime)"` - Byte byte - Rune rune - Int int - Int8 int8 - Int16 int16 - Int32 int32 - Int64 int64 - Uint uint - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - Float32 float32 - Float64 float64 - Decimal float64 `orm:"digits(8);decimals(4)"` -} - -type DataNull struct { - ID int `orm:"column(id)"` - Boolean bool `orm:"null"` - Char string `orm:"null;size(50)"` - Text string `orm:"null;type(text)"` - JSON string `orm:"type(json);null"` - Jsonb string `orm:"type(jsonb);null"` - Time time.Time `orm:"null;type(time)"` - Date time.Time `orm:"null;type(date)"` - DateTime time.Time `orm:"null;column(datetime)"` - Byte byte `orm:"null"` - Rune rune `orm:"null"` - Int int `orm:"null"` - Int8 int8 `orm:"null"` - Int16 int16 `orm:"null"` - Int32 int32 `orm:"null"` - Int64 int64 `orm:"null"` - Uint uint `orm:"null"` - Uint8 uint8 `orm:"null"` - Uint16 uint16 `orm:"null"` - Uint32 uint32 `orm:"null"` - Uint64 uint64 `orm:"null"` - Float32 float32 `orm:"null"` - Float64 float64 `orm:"null"` - Decimal float64 `orm:"digits(8);decimals(4);null"` - NullString sql.NullString `orm:"null"` - NullBool sql.NullBool `orm:"null"` - NullFloat64 sql.NullFloat64 `orm:"null"` - NullInt64 sql.NullInt64 `orm:"null"` - BooleanPtr *bool `orm:"null"` - CharPtr *string `orm:"null;size(50)"` - TextPtr *string `orm:"null;type(text)"` - BytePtr *byte `orm:"null"` - RunePtr *rune `orm:"null"` - IntPtr *int `orm:"null"` - Int8Ptr *int8 `orm:"null"` - Int16Ptr *int16 `orm:"null"` - Int32Ptr *int32 `orm:"null"` - Int64Ptr *int64 `orm:"null"` - UintPtr *uint `orm:"null"` - Uint8Ptr *uint8 `orm:"null"` - Uint16Ptr *uint16 `orm:"null"` - Uint32Ptr *uint32 `orm:"null"` - Uint64Ptr *uint64 `orm:"null"` - Float32Ptr *float32 `orm:"null"` - Float64Ptr *float64 `orm:"null"` - DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` - TimePtr *time.Time `orm:"null;type(time)"` - DatePtr *time.Time `orm:"null;type(date)"` - DateTimePtr *time.Time `orm:"null"` -} - -type String string -type Boolean bool -type Byte byte -type Rune rune -type Int int -type Int8 int8 -type Int16 int16 -type Int32 int32 -type Int64 int64 -type Uint uint -type Uint8 uint8 -type Uint16 uint16 -type Uint32 uint32 -type Uint64 uint64 -type Float32 float64 -type Float64 float64 - -type DataCustom struct { - ID int `orm:"column(id)"` - Boolean Boolean - Char string `orm:"size(50)"` - Text string `orm:"type(text)"` - Byte Byte - Rune Rune - Int Int - Int8 Int8 - Int16 Int16 - Int32 Int32 - Int64 Int64 - Uint Uint - Uint8 Uint8 - Uint16 Uint16 - Uint32 Uint32 - Uint64 Uint64 - Float32 Float32 - Float64 Float64 - Decimal Float64 `orm:"digits(8);decimals(4)"` -} - -// only for mysql -type UserBig struct { - ID uint64 `orm:"column(id)"` - Name string -} - -type User struct { - ID int `orm:"column(id)"` - UserName string `orm:"size(30);unique"` - Email string `orm:"size(100)"` - Password string `orm:"size(100)"` - Status int16 `orm:"column(Status)"` - IsStaff bool - IsActive bool `orm:"default(true)"` - Created time.Time `orm:"auto_now_add;type(date)"` - Updated time.Time `orm:"auto_now"` - Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` - Posts []*Post `orm:"reverse(many)" json:"-"` - ShouldSkip string `orm:"-"` - Nums int - Langs SliceStringField `orm:"size(100)"` - Extra JSONFieldTest `orm:"type(text)"` - unexport bool `orm:"-"` - unexportBool bool -} - -func (u *User) TableIndex() [][]string { - return [][]string{ - {"Id", "UserName"}, - {"Id", "Created"}, - } -} - -func (u *User) TableUnique() [][]string { - return [][]string{ - {"UserName", "Email"}, - } -} - -func NewUser() *User { - obj := new(User) - return obj -} - -type Profile struct { - ID int `orm:"column(id)"` - Age int16 - Money float64 - User *User `orm:"reverse(one)" json:"-"` - BestPost *Post `orm:"rel(one);null"` -} - -func (u *Profile) TableName() string { - return "user_profile" -} - -func NewProfile() *Profile { - obj := new(Profile) - return obj -} - -type Post struct { - ID int `orm:"column(id)"` - User *User `orm:"rel(fk)"` - Title string `orm:"size(60)"` - Content string `orm:"type(text)"` - Created time.Time `orm:"auto_now_add"` - Updated time.Time `orm:"auto_now"` - Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.PostTags)"` -} - -func (u *Post) TableIndex() [][]string { - return [][]string{ - {"Id", "Created"}, - } -} - -func NewPost() *Post { - obj := new(Post) - return obj -} - -type Tag struct { - ID int `orm:"column(id)"` - Name string `orm:"size(30)"` - BestPost *Post `orm:"rel(one);null"` - Posts []*Post `orm:"reverse(many)" json:"-"` -} - -func NewTag() *Tag { - obj := new(Tag) - return obj -} - -type PostTags struct { - ID int `orm:"column(id)"` - Post *Post `orm:"rel(fk)"` - Tag *Tag `orm:"rel(fk)"` -} - -func (m *PostTags) TableName() string { - return "prefix_post_tags" -} - -type Comment struct { - ID int `orm:"column(id)"` - Post *Post `orm:"rel(fk);column(post)"` - Content string `orm:"type(text)"` - Parent *Comment `orm:"null;rel(fk)"` - Created time.Time `orm:"auto_now_add"` -} - -func NewComment() *Comment { - obj := new(Comment) - return obj -} - -type Group struct { - ID int `orm:"column(gid);size(32)"` - Name string - Permissions []*Permission `orm:"reverse(many)" json:"-"` -} - -type Permission struct { - ID int `orm:"column(id)"` - Name string - Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"` -} - -type GroupPermissions struct { - ID int `orm:"column(id)"` - Group *Group `orm:"rel(fk)"` - Permission *Permission `orm:"rel(fk)"` -} - -type ModelID struct { - ID int64 -} - -type ModelBase struct { - ModelID - - Created time.Time `orm:"auto_now_add;type(datetime)"` - Updated time.Time `orm:"auto_now;type(datetime)"` -} - -type InLine struct { - // Common Fields - ModelBase - - // Other Fields - Name string `orm:"unique"` - Email string -} - -func NewInLine() *InLine { - return new(InLine) -} - -type InLineOneToOne struct { - // Common Fields - ModelBase - - Note string - InLine *InLine `orm:"rel(fk);column(inline)"` -} - -func NewInLineOneToOne() *InLineOneToOne { - return new(InLineOneToOne) -} - -type IntegerPk struct { - ID int64 `orm:"pk"` - Value string -} - -type UintPk struct { - ID uint32 `orm:"pk"` - Name string -} - -type PtrPk struct { - ID *IntegerPk `orm:"pk;rel(one)"` - Positive bool -} - -var DBARGS = struct { - Driver string - Source string - Debug string -}{ - os.Getenv("ORM_DRIVER"), - os.Getenv("ORM_SOURCE"), - os.Getenv("ORM_DEBUG"), -} - -var ( - IsMysql = DBARGS.Driver == "mysql" - IsSqlite = DBARGS.Driver == "sqlite3" - IsPostgres = DBARGS.Driver == "postgres" - IsTidb = DBARGS.Driver == "tidb" -) - -var ( - dORM Ormer - dDbBaser dbBaser -) - -var ( - helpinfo = `need driver and source! - - Default DB Drivers. - - driver: url - mysql: https://github.com/go-sql-driver/mysql - sqlite3: https://github.com/mattn/go-sqlite3 - postgres: https://github.com/lib/pq - tidb: https://github.com/pingcap/tidb - - usage: - - go get -u github.com/astaxie/beego/orm - go get -u github.com/go-sql-driver/mysql - go get -u github.com/mattn/go-sqlite3 - go get -u github.com/lib/pq - go get -u github.com/pingcap/tidb - - #### MySQL - mysql -u root -e 'create database orm_test;' - export ORM_DRIVER=mysql - export ORM_SOURCE="root:@/orm_test?charset=utf8" - go test -v github.com/astaxie/beego/orm - - - #### Sqlite3 - export ORM_DRIVER=sqlite3 - export ORM_SOURCE='file:memory_test?mode=memory' - go test -v github.com/astaxie/beego/orm - - - #### PostgreSQL - psql -c 'create database orm_test;' -U postgres - export ORM_DRIVER=postgres - export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - go test -v github.com/astaxie/beego/orm - - #### TiDB - export ORM_DRIVER=tidb - export ORM_SOURCE='memory://test/test' - go test -v github.com/astaxie/beego/orm - - ` -) - -func init() { - Debug, _ = StrTo(DBARGS.Debug).Bool() - - if DBARGS.Driver == "" || DBARGS.Source == "" { - fmt.Println(helpinfo) - os.Exit(2) - } - - RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) - - alias := getDbAlias("default") - if alias.Driver == DRMySQL { - alias.Engine = "INNODB" - } - -} diff --git a/orm/orm_test.go b/orm/orm_test.go deleted file mode 100644 index bdb430b6..00000000 --- a/orm/orm_test.go +++ /dev/null @@ -1,2494 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build go1.8 - -package orm - -import ( - "bytes" - "context" - "database/sql" - "fmt" - "io/ioutil" - "math" - "os" - "path/filepath" - "reflect" - "runtime" - "strings" - "testing" - "time" -) - -var _ = os.PathSeparator - -var ( - testDate = formatDate + " -0700" - testDateTime = formatDateTime + " -0700" - testTime = formatTime + " -0700" -) - -type argAny []interface{} - -// get interface by index from interface slice -func (a argAny) Get(i int, args ...interface{}) (r interface{}) { - if i >= 0 && i < len(a) { - r = a[i] - } - if len(args) > 0 { - r = args[0] - } - return -} - -func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) { - if len(args) == 0 { - return false, fmt.Errorf("miss args") - } - b := args[0] - arg := argAny(args) - - switch v := a.(type) { - case reflect.Kind: - ok = reflect.ValueOf(b).Kind() == v - case time.Time: - if v2, vo := b.(time.Time); vo { - if arg.Get(1) != nil { - format := ToStr(arg.Get(1)) - a = v.Format(format) - b = v2.Format(format) - ok = a == b - } else { - err = fmt.Errorf("compare datetime miss format") - goto wrongArg - } - } - default: - ok = ToStr(a) == ToStr(b) - } - ok = is && ok || !is && !ok - if !ok { - if is { - err = fmt.Errorf("expected: `%v`, get `%v`", b, a) - } else { - err = fmt.Errorf("expected: `%v`, get `%v`", b, a) - } - } - -wrongArg: - if err != nil { - return false, err - } - - return true, nil -} - -func AssertIs(a interface{}, args ...interface{}) error { - if ok, err := ValuesCompare(true, a, args...); !ok { - return err - } - return nil -} - -func AssertNot(a interface{}, args ...interface{}) error { - if ok, err := ValuesCompare(false, a, args...); !ok { - return err - } - return nil -} - -func getCaller(skip int) string { - pc, file, line, _ := runtime.Caller(skip) - fun := runtime.FuncForPC(pc) - _, fn := filepath.Split(file) - data, err := ioutil.ReadFile(file) - var codes []string - if err == nil { - lines := bytes.Split(data, []byte{'\n'}) - n := 10 - for i := 0; i < n; i++ { - o := line - n - if o < 0 { - continue - } - cur := o + i + 1 - flag := " " - if cur == line { - flag = ">>" - } - code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1)) - if code != "" { - codes = append(codes, code) - } - } - } - funName := fun.Name() - if i := strings.LastIndex(funName, "."); i > -1 { - funName = funName[i+1:] - } - return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) -} - -func throwFail(t *testing.T, err error, args ...interface{}) { - if err != nil { - con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) - if len(args) > 0 { - parts := make([]string, 0, len(args)) - for _, arg := range args { - parts = append(parts, fmt.Sprintf("%v", arg)) - } - con += " " + strings.Join(parts, ", ") - } - t.Error(con) - t.Fail() - } -} - -func throwFailNow(t *testing.T, err error, args ...interface{}) { - if err != nil { - con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) - if len(args) > 0 { - parts := make([]string, 0, len(args)) - for _, arg := range args { - parts = append(parts, fmt.Sprintf("%v", arg)) - } - con += " " + strings.Join(parts, ", ") - } - t.Error(con) - t.FailNow() - } -} - -func TestGetDB(t *testing.T) { - if db, err := GetDB(); err != nil { - throwFailNow(t, err) - } else { - err = db.Ping() - throwFailNow(t, err) - } -} - -func TestSyncDb(t *testing.T) { - RegisterModel(new(Data), new(DataNull), new(DataCustom)) - RegisterModel(new(User)) - RegisterModel(new(Profile)) - RegisterModel(new(Post)) - RegisterModel(new(Tag)) - RegisterModel(new(Comment)) - RegisterModel(new(UserBig)) - RegisterModel(new(PostTags)) - RegisterModel(new(Group)) - RegisterModel(new(Permission)) - RegisterModel(new(GroupPermissions)) - RegisterModel(new(InLine)) - RegisterModel(new(InLineOneToOne)) - RegisterModel(new(IntegerPk)) - RegisterModel(new(UintPk)) - RegisterModel(new(PtrPk)) - - err := RunSyncdb("default", true, Debug) - throwFail(t, err) - - modelCache.clean() -} - -func TestRegisterModels(t *testing.T) { - RegisterModel(new(Data), new(DataNull), new(DataCustom)) - RegisterModel(new(User)) - RegisterModel(new(Profile)) - RegisterModel(new(Post)) - RegisterModel(new(Tag)) - RegisterModel(new(Comment)) - RegisterModel(new(UserBig)) - RegisterModel(new(PostTags)) - RegisterModel(new(Group)) - RegisterModel(new(Permission)) - RegisterModel(new(GroupPermissions)) - RegisterModel(new(InLine)) - RegisterModel(new(InLineOneToOne)) - RegisterModel(new(IntegerPk)) - RegisterModel(new(UintPk)) - RegisterModel(new(PtrPk)) - - BootStrap() - - dORM = NewOrm() - dDbBaser = getDbAlias("default").DbBaser -} - -func TestModelSyntax(t *testing.T) { - user := &User{} - ind := reflect.ValueOf(user).Elem() - fn := getFullName(ind.Type()) - mi, ok := modelCache.getByFullName(fn) - throwFail(t, AssertIs(ok, true)) - - mi, ok = modelCache.get("user") - throwFail(t, AssertIs(ok, true)) - if ok { - throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) - } -} - -var DataValues = map[string]interface{}{ - "Boolean": true, - "Char": "char", - "Text": "text", - "JSON": `{"name":"json"}`, - "Jsonb": `{"name": "jsonb"}`, - "Time": time.Now(), - "Date": time.Now(), - "DateTime": time.Now(), - "Byte": byte(1<<8 - 1), - "Rune": rune(1<<31 - 1), - "Int": int(1<<31 - 1), - "Int8": int8(1<<7 - 1), - "Int16": int16(1<<15 - 1), - "Int32": int32(1<<31 - 1), - "Int64": int64(1<<63 - 1), - "Uint": uint(1<<32 - 1), - "Uint8": uint8(1<<8 - 1), - "Uint16": uint16(1<<16 - 1), - "Uint32": uint32(1<<32 - 1), - "Uint64": uint64(1<<63 - 1), // uint64 values with high bit set are not supported - "Float32": float32(100.1234), - "Float64": float64(100.1234), - "Decimal": float64(100.1234), -} - -func TestDataTypes(t *testing.T) { - d := Data{} - ind := reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - if name == "JSON" { - continue - } - e := ind.FieldByName(name) - e.Set(reflect.ValueOf(value)) - } - id, err := dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - d = Data{ID: 1} - err = dORM.Read(&d) - throwFail(t, err) - - ind = reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - e := ind.FieldByName(name) - vu := e.Interface() - switch name { - case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) - case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) - } - throwFail(t, AssertIs(vu == value, true), value, vu) - } -} - -func TestNullDataTypes(t *testing.T) { - d := DataNull{} - - if IsPostgres { - // can removed when this fixed - // https://github.com/lib/pq/pull/125 - d.DateTime = time.Now() - } - - id, err := dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - data := `{"ok":1,"data":{"arr":[1,2],"msg":"gopher"}}` - d = DataNull{ID: 1, JSON: data} - num, err := dORM.Update(&d) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - d = DataNull{ID: 1} - err = dORM.Read(&d) - throwFail(t, err) - - throwFail(t, AssertIs(d.JSON, data)) - - throwFail(t, AssertIs(d.NullBool.Valid, false)) - throwFail(t, AssertIs(d.NullString.Valid, false)) - throwFail(t, AssertIs(d.NullInt64.Valid, false)) - throwFail(t, AssertIs(d.NullFloat64.Valid, false)) - - throwFail(t, AssertIs(d.BooleanPtr, nil)) - throwFail(t, AssertIs(d.CharPtr, nil)) - throwFail(t, AssertIs(d.TextPtr, nil)) - throwFail(t, AssertIs(d.BytePtr, nil)) - throwFail(t, AssertIs(d.RunePtr, nil)) - throwFail(t, AssertIs(d.IntPtr, nil)) - throwFail(t, AssertIs(d.Int8Ptr, nil)) - throwFail(t, AssertIs(d.Int16Ptr, nil)) - throwFail(t, AssertIs(d.Int32Ptr, nil)) - throwFail(t, AssertIs(d.Int64Ptr, nil)) - throwFail(t, AssertIs(d.UintPtr, nil)) - throwFail(t, AssertIs(d.Uint8Ptr, nil)) - throwFail(t, AssertIs(d.Uint16Ptr, nil)) - throwFail(t, AssertIs(d.Uint32Ptr, nil)) - throwFail(t, AssertIs(d.Uint64Ptr, nil)) - throwFail(t, AssertIs(d.Float32Ptr, nil)) - throwFail(t, AssertIs(d.Float64Ptr, nil)) - throwFail(t, AssertIs(d.DecimalPtr, nil)) - throwFail(t, AssertIs(d.TimePtr, nil)) - throwFail(t, AssertIs(d.DatePtr, nil)) - throwFail(t, AssertIs(d.DateTimePtr, nil)) - - _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() - throwFail(t, err) - - d = DataNull{ID: 2} - err = dORM.Read(&d) - throwFail(t, err) - - booleanPtr := true - charPtr := string("test") - textPtr := string("test") - bytePtr := byte('t') - runePtr := rune('t') - intPtr := int(42) - int8Ptr := int8(42) - int16Ptr := int16(42) - int32Ptr := int32(42) - int64Ptr := int64(42) - uintPtr := uint(42) - uint8Ptr := uint8(42) - uint16Ptr := uint16(42) - uint32Ptr := uint32(42) - uint64Ptr := uint64(42) - float32Ptr := float32(42.0) - float64Ptr := float64(42.0) - decimalPtr := float64(42.0) - timePtr := time.Now() - datePtr := time.Now() - dateTimePtr := time.Now() - - d = DataNull{ - DateTime: time.Now(), - NullString: sql.NullString{String: "test", Valid: true}, - NullBool: sql.NullBool{Bool: true, Valid: true}, - NullInt64: sql.NullInt64{Int64: 42, Valid: true}, - NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, - BooleanPtr: &booleanPtr, - CharPtr: &charPtr, - TextPtr: &textPtr, - BytePtr: &bytePtr, - RunePtr: &runePtr, - IntPtr: &intPtr, - Int8Ptr: &int8Ptr, - Int16Ptr: &int16Ptr, - Int32Ptr: &int32Ptr, - Int64Ptr: &int64Ptr, - UintPtr: &uintPtr, - Uint8Ptr: &uint8Ptr, - Uint16Ptr: &uint16Ptr, - Uint32Ptr: &uint32Ptr, - Uint64Ptr: &uint64Ptr, - Float32Ptr: &float32Ptr, - Float64Ptr: &float64Ptr, - DecimalPtr: &decimalPtr, - TimePtr: &timePtr, - DatePtr: &datePtr, - DateTimePtr: &dateTimePtr, - } - - id, err = dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 3)) - - d = DataNull{ID: 3} - err = dORM.Read(&d) - throwFail(t, err) - - throwFail(t, AssertIs(d.NullBool.Valid, true)) - throwFail(t, AssertIs(d.NullBool.Bool, true)) - - throwFail(t, AssertIs(d.NullString.Valid, true)) - throwFail(t, AssertIs(d.NullString.String, "test")) - - throwFail(t, AssertIs(d.NullInt64.Valid, true)) - throwFail(t, AssertIs(d.NullInt64.Int64, 42)) - - throwFail(t, AssertIs(d.NullFloat64.Valid, true)) - throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) - - throwFail(t, AssertIs(*d.BooleanPtr, booleanPtr)) - throwFail(t, AssertIs(*d.CharPtr, charPtr)) - throwFail(t, AssertIs(*d.TextPtr, textPtr)) - throwFail(t, AssertIs(*d.BytePtr, bytePtr)) - throwFail(t, AssertIs(*d.RunePtr, runePtr)) - throwFail(t, AssertIs(*d.IntPtr, intPtr)) - throwFail(t, AssertIs(*d.Int8Ptr, int8Ptr)) - throwFail(t, AssertIs(*d.Int16Ptr, int16Ptr)) - throwFail(t, AssertIs(*d.Int32Ptr, int32Ptr)) - throwFail(t, AssertIs(*d.Int64Ptr, int64Ptr)) - throwFail(t, AssertIs(*d.UintPtr, uintPtr)) - throwFail(t, AssertIs(*d.Uint8Ptr, uint8Ptr)) - throwFail(t, AssertIs(*d.Uint16Ptr, uint16Ptr)) - throwFail(t, AssertIs(*d.Uint32Ptr, uint32Ptr)) - throwFail(t, AssertIs(*d.Uint64Ptr, uint64Ptr)) - throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) - throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) - throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) - throwFail(t, AssertIs((*d.TimePtr).UTC().Format(testTime), timePtr.UTC().Format(testTime))) - throwFail(t, AssertIs((*d.DatePtr).UTC().Format(testDate), datePtr.UTC().Format(testDate))) - throwFail(t, AssertIs((*d.DateTimePtr).UTC().Format(testDateTime), dateTimePtr.UTC().Format(testDateTime))) - - // test support for pointer fields using RawSeter.QueryRows() - var dnList []*DataNull - Q := dDbBaser.TableQuote() - num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - equal := reflect.DeepEqual(*dnList[0], d) - throwFailNow(t, AssertIs(equal, true)) -} - -func TestDataCustomTypes(t *testing.T) { - d := DataCustom{} - ind := reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - e := ind.FieldByName(name) - if !e.IsValid() { - continue - } - e.Set(reflect.ValueOf(value).Convert(e.Type())) - } - - id, err := dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - d = DataCustom{ID: 1} - err = dORM.Read(&d) - throwFail(t, err) - - ind = reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - e := ind.FieldByName(name) - if !e.IsValid() { - continue - } - vu := e.Interface() - value = reflect.ValueOf(value).Convert(e.Type()).Interface() - throwFail(t, AssertIs(vu == value, true), value, vu) - } -} - -func TestCRUD(t *testing.T) { - profile := NewProfile() - profile.Age = 30 - profile.Money = 1234.12 - id, err := dORM.Insert(profile) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - user := NewUser() - user.UserName = "slene" - user.Email = "vslene@gmail.com" - user.Password = "pass" - user.Status = 3 - user.IsStaff = true - user.IsActive = true - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - u := &User{ID: user.ID} - err = dORM.Read(u) - throwFail(t, err) - - throwFail(t, AssertIs(u.UserName, "slene")) - throwFail(t, AssertIs(u.Email, "vslene@gmail.com")) - throwFail(t, AssertIs(u.Password, "pass")) - throwFail(t, AssertIs(u.Status, 3)) - throwFail(t, AssertIs(u.IsStaff, true)) - throwFail(t, AssertIs(u.IsActive, true)) - throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime)) - - user.UserName = "astaxie" - user.Profile = profile - num, err := dORM.Update(user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: user.ID} - err = dORM.Read(u) - throwFailNow(t, err) - throwFail(t, AssertIs(u.UserName, "astaxie")) - throwFail(t, AssertIs(u.Profile.ID, profile.ID)) - - u = &User{UserName: "astaxie", Password: "pass"} - err = dORM.Read(u, "UserName") - throwFailNow(t, err) - throwFailNow(t, AssertIs(id, 1)) - - u.UserName = "QQ" - u.Password = "111" - num, err = dORM.Update(u, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: user.ID} - err = dORM.Read(u) - throwFailNow(t, err) - throwFail(t, AssertIs(u.UserName, "QQ")) - throwFail(t, AssertIs(u.Password, "pass")) - - num, err = dORM.Delete(profile) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: user.ID} - err = dORM.Read(u) - throwFail(t, err) - throwFail(t, AssertIs(true, u.Profile == nil)) - - num, err = dORM.Delete(user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: 100} - err = dORM.Read(u) - throwFail(t, AssertIs(err, ErrNoRows)) - - ub := UserBig{} - ub.Name = "name" - id, err = dORM.Insert(&ub) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - ub = UserBig{ID: 1} - err = dORM.Read(&ub) - throwFail(t, err) - throwFail(t, AssertIs(ub.Name, "name")) - - num, err = dORM.Delete(&ub, "name") - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestInsertTestData(t *testing.T) { - var users []*User - - profile := NewProfile() - profile.Age = 28 - profile.Money = 1234.12 - - id, err := dORM.Insert(profile) - throwFail(t, err) - throwFail(t, AssertIs(id, 2)) - - user := NewUser() - user.UserName = "slene" - user.Email = "vslene@gmail.com" - user.Password = "pass" - user.Status = 1 - user.IsStaff = false - user.IsActive = true - user.Profile = profile - - users = append(users, user) - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 2)) - - profile = NewProfile() - profile.Age = 30 - profile.Money = 4321.09 - - id, err = dORM.Insert(profile) - throwFail(t, err) - throwFail(t, AssertIs(id, 3)) - - user = NewUser() - user.UserName = "astaxie" - user.Email = "astaxie@gmail.com" - user.Password = "password" - user.Status = 2 - user.IsStaff = true - user.IsActive = false - user.Profile = profile - - users = append(users, user) - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 3)) - - user = NewUser() - user.UserName = "nobody" - user.Email = "nobody@gmail.com" - user.Password = "nobody" - user.Status = 3 - user.IsStaff = false - user.IsActive = false - - users = append(users, user) - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 4)) - - tags := []*Tag{ - {Name: "golang", BestPost: &Post{ID: 2}}, - {Name: "example"}, - {Name: "format"}, - {Name: "c++"}, - } - - posts := []*Post{ - {User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory result—Java programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand. -This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, all of which you should read first.`}, - {User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`}, - {User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide. -With Go we take an unusual approach and let the machine take care of most formatting issues. The gofmt program (also available as go fmt, which operates at the package level rather than source file level) reads a Go program and emits the source in a standard style of indentation and vertical alignment, retaining and if necessary reformatting comments. If you want to know how to handle some new layout situation, run gofmt; if the answer doesn't seem right, rearrange your program (or file a bug about gofmt), don't work around it.`}, - {User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code. -The program—and web server—godoc processes Go source files to extract documentation about the contents of the package. Comments that appear before top-level declarations, with no intervening newlines, are extracted along with the declaration to serve as explanatory text for the item. The nature and style of these comments determines the quality of the documentation godoc produces.`}, - } - - comments := []*Comment{ - {Post: posts[0], Content: "a comment"}, - {Post: posts[1], Content: "yes"}, - {Post: posts[1]}, - {Post: posts[1]}, - {Post: posts[2]}, - {Post: posts[2]}, - } - - for _, tag := range tags { - id, err := dORM.Insert(tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - - for _, post := range posts { - id, err := dORM.Insert(post) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - num := len(post.Tags) - if num > 0 { - nums, err := dORM.QueryM2M(post, "tags").Add(post.Tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(nums, num)) - } - } - - for _, comment := range comments { - id, err := dORM.Insert(comment) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - - permissions := []*Permission{ - {Name: "writePosts"}, - {Name: "readComments"}, - {Name: "readPosts"}, - } - - groups := []*Group{ - { - Name: "admins", - Permissions: []*Permission{permissions[0], permissions[1], permissions[2]}, - }, - { - Name: "users", - Permissions: []*Permission{permissions[1], permissions[2]}, - }, - } - - for _, permission := range permissions { - id, err := dORM.Insert(permission) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - - for _, group := range groups { - _, err := dORM.Insert(group) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - num := len(group.Permissions) - if num > 0 { - nums, err := dORM.QueryM2M(group, "permissions").Add(group.Permissions) - throwFailNow(t, err) - throwFailNow(t, AssertIs(nums, num)) - } - } - -} - -func TestCustomField(t *testing.T) { - user := User{ID: 2} - err := dORM.Read(&user) - throwFailNow(t, err) - - user.Langs = append(user.Langs, "zh-CN", "en-US") - user.Extra.Name = "beego" - user.Extra.Data = "orm" - _, err = dORM.Update(&user, "Langs", "Extra") - throwFailNow(t, err) - - user = User{ID: 2} - err = dORM.Read(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(len(user.Langs), 2)) - throwFailNow(t, AssertIs(user.Langs[0], "zh-CN")) - throwFailNow(t, AssertIs(user.Langs[1], "en-US")) - - throwFailNow(t, AssertIs(user.Extra.Name, "beego")) - throwFailNow(t, AssertIs(user.Extra.Data, "orm")) -} - -func TestExpr(t *testing.T) { - user := &User{} - qs := dORM.QueryTable(user) - qs = dORM.QueryTable((*User)(nil)) - qs = dORM.QueryTable("User") - qs = dORM.QueryTable("user") - num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("created", time.Now()).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - // num, err = qs.Filter("created", time.Now().Format(format_Date)).Count() - // throwFail(t, err) - // throwFail(t, AssertIs(num, 3)) -} - -func TestOperators(t *testing.T) { - qs := dORM.QueryTable("user") - num, err := qs.Filter("user_name", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__exact", String("slene")).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__exact", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__iexact", "Slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__contains", "e").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - var shouldNum int - - if IsSqlite || IsTidb { - shouldNum = 2 - } else { - shouldNum = 0 - } - - num, err = qs.Filter("user_name__contains", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, shouldNum)) - - num, err = qs.Filter("user_name__icontains", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("user_name__icontains", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__gt", 1).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__gte", 1).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - num, err = qs.Filter("status__lt", Uint(3)).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__lte", Int(3)).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - num, err = qs.Filter("user_name__startswith", "s").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - if IsSqlite || IsTidb { - shouldNum = 1 - } else { - shouldNum = 0 - } - - num, err = qs.Filter("user_name__startswith", "S").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, shouldNum)) - - num, err = qs.Filter("user_name__istartswith", "S").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__endswith", "e").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - if IsSqlite || IsTidb { - shouldNum = 2 - } else { - shouldNum = 0 - } - - num, err = qs.Filter("user_name__endswith", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, shouldNum)) - - num, err = qs.Filter("user_name__iendswith", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("profile__isnull", true).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("status__in", 1, 2).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__in", []int{1, 2}).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - n1, n2 := 1, 2 - num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("id__between", 2, 3).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("id__between", []int{2, 3}).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.FilterRaw("user_name", "= 'slene'").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.FilterRaw("status", "IN (1, 2)").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.FilterRaw("profile_id", "IN (SELECT id FROM user_profile WHERE age=30)").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestSetCond(t *testing.T) { - cond := NewCondition() - cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) - - qs := dORM.QueryTable("user") - num, err := qs.SetCond(cond1).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - cond2 := cond.AndCond(cond1).OrCond(cond.And("user_name", "slene")) - num, err = qs.SetCond(cond2).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - cond3 := cond.AndNotCond(cond.And("status__in", 1)) - num, err = qs.SetCond(cond3).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - cond4 := cond.And("user_name", "slene").OrNotCond(cond.And("user_name", "slene")) - num, err = qs.SetCond(cond4).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - cond5 := cond.Raw("user_name", "= 'slene'").OrNotCond(cond.And("user_name", "slene")) - num, err = qs.SetCond(cond5).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) -} - -func TestLimit(t *testing.T) { - var posts []*Post - qs := dORM.QueryTable("post") - num, err := qs.Limit(1).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Limit(-1).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 4)) - - num, err = qs.Limit(-1, 2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Limit(0, 2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) -} - -func TestOffset(t *testing.T) { - var posts []*Post - qs := dORM.QueryTable("post") - num, err := qs.Limit(1).Offset(2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Offset(2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) -} - -func TestOrderBy(t *testing.T) { - qs := dORM.QueryTable("user") - num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.OrderBy("status").Filter("user_name", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestAll(t *testing.T) { - var users []*User - qs := dORM.QueryTable("user") - num, err := qs.OrderBy("Id").All(&users) - throwFail(t, err) - throwFailNow(t, AssertIs(num, 3)) - - throwFail(t, AssertIs(users[0].UserName, "slene")) - throwFail(t, AssertIs(users[1].UserName, "astaxie")) - throwFail(t, AssertIs(users[2].UserName, "nobody")) - - var users2 []User - qs = dORM.QueryTable("user") - num, err = qs.OrderBy("Id").All(&users2) - throwFail(t, err) - throwFailNow(t, AssertIs(num, 3)) - - throwFailNow(t, AssertIs(users2[0].UserName, "slene")) - throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) - throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) - - qs = dORM.QueryTable("user") - num, err = qs.OrderBy("Id").RelatedSel().All(&users2, "UserName") - throwFail(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(len(users2), 3)) - throwFailNow(t, AssertIs(users2[0].UserName, "slene")) - throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) - throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) - throwFailNow(t, AssertIs(users2[0].ID, 0)) - throwFailNow(t, AssertIs(users2[1].ID, 0)) - throwFailNow(t, AssertIs(users2[2].ID, 0)) - throwFailNow(t, AssertIs(users2[0].Profile == nil, false)) - throwFailNow(t, AssertIs(users2[1].Profile == nil, false)) - throwFailNow(t, AssertIs(users2[2].Profile == nil, true)) - - qs = dORM.QueryTable("user") - num, err = qs.Filter("user_name", "nothing").All(&users) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - - var users3 []*User - qs = dORM.QueryTable("user") - num, err = qs.Filter("user_name", "nothing").All(&users3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - throwFailNow(t, AssertIs(users3 == nil, false)) -} - -func TestOne(t *testing.T) { - var user User - qs := dORM.QueryTable("user") - err := qs.One(&user) - throwFail(t, err) - - user = User{} - err = qs.OrderBy("Id").Limit(1).One(&user) - throwFailNow(t, err) - throwFail(t, AssertIs(user.UserName, "slene")) - throwFail(t, AssertNot(err, ErrMultiRows)) - - user = User{} - err = qs.OrderBy("-Id").Limit(100).One(&user) - throwFailNow(t, err) - throwFail(t, AssertIs(user.UserName, "nobody")) - throwFail(t, AssertNot(err, ErrMultiRows)) - - err = qs.Filter("user_name", "nothing").One(&user) - throwFail(t, AssertIs(err, ErrNoRows)) - -} - -func TestValues(t *testing.T) { - var maps []Params - qs := dORM.QueryTable("user") - - num, err := qs.OrderBy("Id").Values(&maps) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(maps[0]["UserName"], "slene")) - throwFail(t, AssertIs(maps[2]["Profile"], nil)) - } - - num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(maps[0]["UserName"], "slene")) - throwFail(t, AssertIs(maps[0]["Profile__Age"], 28)) - throwFail(t, AssertIs(maps[2]["Profile__Age"], nil)) - } - - num, err = qs.Filter("UserName", "slene").Values(&maps) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestValuesList(t *testing.T) { - var list []ParamsList - qs := dORM.QueryTable("user") - - num, err := qs.OrderBy("Id").ValuesList(&list) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0][1], "slene")) - throwFail(t, AssertIs(list[2][9], nil)) - } - - num, err = qs.OrderBy("Id").ValuesList(&list, "UserName", "Profile__Age") - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0][0], "slene")) - throwFail(t, AssertIs(list[0][1], 28)) - throwFail(t, AssertIs(list[2][1], nil)) - } -} - -func TestValuesFlat(t *testing.T) { - var list ParamsList - qs := dORM.QueryTable("user") - - num, err := qs.OrderBy("id").ValuesFlat(&list, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0], "slene")) - throwFail(t, AssertIs(list[1], "astaxie")) - throwFail(t, AssertIs(list[2], "nobody")) - } -} - -func TestRelatedSel(t *testing.T) { - if IsTidb { - // Skip it. TiDB does not support relation now. - return - } - qs := dORM.QueryTable("user") - num, err := qs.Filter("profile__age", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("profile__age__gt", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("profile__user__profile__age__gt", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - var user User - err = qs.Filter("user_name", "slene").RelatedSel("profile").One(&user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertNot(user.Profile, nil)) - if user.Profile != nil { - throwFail(t, AssertIs(user.Profile.Age, 28)) - } - - err = qs.Filter("user_name", "slene").RelatedSel().One(&user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertNot(user.Profile, nil)) - if user.Profile != nil { - throwFail(t, AssertIs(user.Profile.Age, 28)) - } - - err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertIs(user.Profile, nil)) - - qs = dORM.QueryTable("user_profile") - num, err = qs.Filter("user__username", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - var posts []*Post - qs = dORM.QueryTable("post") - num, err = qs.RelatedSel().All(&posts) - throwFail(t, err) - throwFailNow(t, AssertIs(num, 4)) - - throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) - throwFailNow(t, AssertIs(posts[1].User.UserName, "astaxie")) - throwFailNow(t, AssertIs(posts[2].User.UserName, "astaxie")) - throwFailNow(t, AssertIs(posts[3].User.UserName, "nobody")) -} - -func TestReverseQuery(t *testing.T) { - var profile Profile - err := dORM.QueryTable("user_profile").Filter("User", 3).One(&profile) - throwFailNow(t, err) - throwFailNow(t, AssertIs(profile.Age, 30)) - - profile = Profile{} - err = dORM.QueryTable("user_profile").Filter("User__UserName", "astaxie").One(&profile) - throwFailNow(t, err) - throwFailNow(t, AssertIs(profile.Age, 30)) - - var user User - err = dORM.QueryTable("user").Filter("Posts__Title", "Examples").One(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(user.UserName, "astaxie")) - - user = User{} - err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").Limit(1).One(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(user.UserName, "astaxie")) - - user = User{} - err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").RelatedSel().Limit(1).One(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(user.UserName, "astaxie")) - throwFailNow(t, AssertIs(user.Profile == nil, false)) - throwFailNow(t, AssertIs(user.Profile.Age, 30)) - - var posts []*Post - num, err := dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").All(&posts) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) - - posts = []*Post{} - num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").Filter("User__UserName", "slene").All(&posts) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) - - posts = []*Post{} - num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang"). - Filter("User__UserName", "slene").RelatedSel().All(&posts) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(posts[0].User == nil, false)) - throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) - - var tags []*Tag - num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction").All(&tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(tags[0].Name, "golang")) - - tags = []*Tag{} - num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). - Filter("BestPost__User__UserName", "astaxie").All(&tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(tags[0].Name, "golang")) - - tags = []*Tag{} - num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). - Filter("BestPost__User__UserName", "astaxie").RelatedSel().All(&tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(tags[0].Name, "golang")) - throwFailNow(t, AssertIs(tags[0].BestPost == nil, false)) - throwFailNow(t, AssertIs(tags[0].BestPost.Title, "Examples")) - throwFailNow(t, AssertIs(tags[0].BestPost.User == nil, false)) - throwFailNow(t, AssertIs(tags[0].BestPost.User.UserName, "astaxie")) -} - -func TestLoadRelated(t *testing.T) { - // load reverse foreign key - user := User{ID: 3} - - err := dORM.Read(&user) - throwFailNow(t, err) - - num, err := dORM.LoadRelated(&user, "Posts") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(user.Posts), 2)) - throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) - - num, err = dORM.LoadRelated(&user, "Posts", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(user.Posts), 2)) - throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) - - num, err = dORM.LoadRelated(&user, "Posts", true, 1) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(user.Posts), 1)) - - num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(user.Posts), 2)) - throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) - - num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(user.Posts), 1)) - throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) - - // load reverse one to one - profile := Profile{ID: 3} - profile.BestPost = &Post{ID: 2} - num, err = dORM.Update(&profile, "BestPost") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - err = dORM.Read(&profile) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&profile, "User") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(profile.User == nil, false)) - throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) - - num, err = dORM.LoadRelated(&profile, "User", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(profile.User == nil, false)) - throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) - throwFailNow(t, AssertIs(profile.User.Profile.Age, profile.Age)) - - // load rel one to one - err = dORM.Read(&user) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&user, "Profile") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(user.Profile == nil, false)) - throwFailNow(t, AssertIs(user.Profile.Age, 30)) - - num, err = dORM.LoadRelated(&user, "Profile", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(user.Profile == nil, false)) - throwFailNow(t, AssertIs(user.Profile.Age, 30)) - throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false)) - throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples")) - - post := Post{ID: 2} - - // load rel foreign key - err = dORM.Read(&post) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&post, "User") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(post.User == nil, false)) - throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) - - num, err = dORM.LoadRelated(&post, "User", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(post.User == nil, false)) - throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) - throwFailNow(t, AssertIs(post.User.Profile == nil, false)) - throwFailNow(t, AssertIs(post.User.Profile.Age, 30)) - - // load rel m2m - post = Post{ID: 2} - - err = dORM.Read(&post) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&post, "Tags") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(post.Tags), 2)) - throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) - - num, err = dORM.LoadRelated(&post, "Tags", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(post.Tags), 2)) - throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) - throwFailNow(t, AssertIs(post.Tags[0].BestPost == nil, false)) - throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie")) - - // load reverse m2m - tag := Tag{ID: 1} - - err = dORM.Read(&tag) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&tag, "Posts") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) - throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) - throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) - - num, err = dORM.LoadRelated(&tag, "Posts", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) - throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) - throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene")) -} - -func TestQueryM2M(t *testing.T) { - post := Post{ID: 4} - m2m := dORM.QueryM2M(&post, "Tags") - - tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}} - tag2 := &Tag{Name: "TestTag3"} - tag3 := []interface{}{&Tag{Name: "TestTag4"}} - - tags := []interface{}{tag1[0], tag1[1], tag2, tag3[0]} - - for _, tag := range tags { - _, err := dORM.Insert(tag) - throwFailNow(t, err) - } - - num, err := m2m.Add(tag1) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Add(tag2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Add(tag3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 5)) - - num, err = m2m.Remove(tag3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 4)) - - exist := m2m.Exist(tag2) - throwFailNow(t, AssertIs(exist, true)) - - num, err = m2m.Remove(tag2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - exist = m2m.Exist(tag2) - throwFailNow(t, AssertIs(exist, false)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - - num, err = m2m.Clear() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - - tag := Tag{Name: "test"} - _, err = dORM.Insert(&tag) - throwFailNow(t, err) - - m2m = dORM.QueryM2M(&tag, "Posts") - - post1 := []*Post{{Title: "TestPost1"}, {Title: "TestPost2"}} - post2 := &Post{Title: "TestPost3"} - post3 := []interface{}{&Post{Title: "TestPost4"}} - - posts := []interface{}{post1[0], post1[1], post2, post3[0]} - - for _, post := range posts { - p := post.(*Post) - p.User = &User{ID: 1} - _, err := dORM.Insert(post) - throwFailNow(t, err) - } - - num, err = m2m.Add(post1) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Add(post2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Add(post3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 4)) - - num, err = m2m.Remove(post3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - - exist = m2m.Exist(post2) - throwFailNow(t, AssertIs(exist, true)) - - num, err = m2m.Remove(post2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - exist = m2m.Exist(post2) - throwFailNow(t, AssertIs(exist, false)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Clear() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - - num, err = dORM.Delete(&tag) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) -} - -func TestQueryRelate(t *testing.T) { - // post := &Post{Id: 2} - - // qs := dORM.QueryRelate(post, "Tags") - // num, err := qs.Count() - // throwFailNow(t, err) - // throwFailNow(t, AssertIs(num, 2)) - - // var tags []*Tag - // num, err = qs.All(&tags) - // throwFailNow(t, err) - // throwFailNow(t, AssertIs(num, 2)) - // throwFailNow(t, AssertIs(tags[0].Name, "golang")) - - // num, err = dORM.QueryTable("Tag").Filter("Posts__Post", 2).Count() - // throwFailNow(t, err) - // throwFailNow(t, AssertIs(num, 2)) -} - -func TestPkManyRelated(t *testing.T) { - permission := &Permission{Name: "readPosts"} - err := dORM.Read(permission, "Name") - throwFailNow(t, err) - - var groups []*Group - qs := dORM.QueryTable("Group") - num, err := qs.Filter("Permissions__Permission", permission.ID).All(&groups) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) -} - -func TestPrepareInsert(t *testing.T) { - qs := dORM.QueryTable("user") - i, err := qs.PrepareInsert() - throwFailNow(t, err) - - var user User - user.UserName = "testing1" - num, err := i.Insert(&user) - throwFail(t, err) - throwFail(t, AssertIs(num > 0, true)) - - user.UserName = "testing2" - num, err = i.Insert(&user) - throwFail(t, err) - throwFail(t, AssertIs(num > 0, true)) - - num, err = qs.Filter("user_name__in", "testing1", "testing2").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - err = i.Close() - throwFail(t, err) - err = i.Close() - throwFail(t, AssertIs(err, ErrStmtClosed)) -} - -func TestRawExec(t *testing.T) { - Q := dDbBaser.TableQuote() - - query := fmt.Sprintf("UPDATE %suser%s SET %suser_name%s = ? WHERE %suser_name%s = ?", Q, Q, Q, Q, Q, Q) - res, err := dORM.Raw(query, "testing", "slene").Exec() - throwFail(t, err) - num, err := res.RowsAffected() - throwFail(t, AssertIs(num, 1), err) - - res, err = dORM.Raw(query, "slene", "testing").Exec() - throwFail(t, err) - num, err = res.RowsAffected() - throwFail(t, AssertIs(num, 1), err) -} - -func TestRawQueryRow(t *testing.T) { - var ( - Boolean bool - Char string - Text string - Time time.Time - Date time.Time - DateTime time.Time - Byte byte - Rune rune - Int int - Int8 int - Int16 int16 - Int32 int32 - Int64 int64 - Uint uint - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - Float32 float32 - Float64 float64 - Decimal float64 - ) - - dataValues := make(map[string]interface{}, len(DataValues)) - - for k, v := range DataValues { - dataValues[strings.ToLower(k)] = v - } - - Q := dDbBaser.TableQuote() - - cols := []string{ - "id", "boolean", "char", "text", "time", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", - "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", - } - sep := fmt.Sprintf("%s, %s", Q, Q) - query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q) - var id int - values := []interface{}{ - &id, &Boolean, &Char, &Text, &Time, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32, - &Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal, - } - err := dORM.Raw(query, 1).QueryRow(values...) - throwFailNow(t, err) - for i, col := range cols { - vu := values[i] - v := reflect.ValueOf(vu).Elem().Interface() - switch col { - case "id": - throwFail(t, AssertIs(id, 1)) - case "time": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testTime)) - case "date": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testDate)) - case "datetime": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testDateTime)) - default: - throwFail(t, AssertIs(v, dataValues[col])) - } - } - - var ( - uid int - status *int - pid *int - ) - - cols = []string{ - "id", "Status", "profile_id", - } - query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q) - err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid) - throwFail(t, err) - throwFail(t, AssertIs(uid, 4)) - throwFail(t, AssertIs(*status, 3)) - throwFail(t, AssertIs(pid, nil)) - - // test for sql.Null* fields - nData := &DataNull{ - NullString: sql.NullString{String: "test sql.null", Valid: true}, - NullBool: sql.NullBool{Bool: true, Valid: true}, - NullInt64: sql.NullInt64{Int64: 42, Valid: true}, - NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, - } - newId, err := dORM.Insert(nData) - throwFailNow(t, err) - - var nd *DataNull - query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) - err = dORM.Raw(query, newId).QueryRow(&nd) - throwFailNow(t, err) - - throwFailNow(t, AssertNot(nd, nil)) - throwFail(t, AssertIs(nd.NullBool.Valid, true)) - throwFail(t, AssertIs(nd.NullBool.Bool, true)) - throwFail(t, AssertIs(nd.NullString.Valid, true)) - throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) - throwFail(t, AssertIs(nd.NullInt64.Valid, true)) - throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) - throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) - throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) -} - -// user_profile table -type userProfile struct { - User - Age int - Money float64 -} - -func TestQueryRows(t *testing.T) { - Q := dDbBaser.TableQuote() - - var datas []*Data - - query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) - num, err := dORM.Raw(query).QueryRows(&datas) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(datas), 1)) - - ind := reflect.Indirect(reflect.ValueOf(datas[0])) - - for name, value := range DataValues { - e := ind.FieldByName(name) - vu := e.Interface() - switch name { - case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) - case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) - case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - } - throwFail(t, AssertIs(vu == value, true), value, vu) - } - - var datas2 []Data - - query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) - num, err = dORM.Raw(query).QueryRows(&datas2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(datas2), 1)) - - ind = reflect.Indirect(reflect.ValueOf(datas2[0])) - - for name, value := range DataValues { - e := ind.FieldByName(name) - vu := e.Interface() - switch name { - case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) - case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) - case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - } - throwFail(t, AssertIs(vu == value, true), value, vu) - } - - var ids []int - var usernames []string - query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q) - num, err = dORM.Raw(query).QueryRows(&ids, &usernames) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(len(ids), 3)) - throwFailNow(t, AssertIs(ids[0], 2)) - throwFailNow(t, AssertIs(usernames[0], "slene")) - throwFailNow(t, AssertIs(ids[1], 3)) - throwFailNow(t, AssertIs(usernames[1], "astaxie")) - throwFailNow(t, AssertIs(ids[2], 4)) - throwFailNow(t, AssertIs(usernames[2], "nobody")) - - //test query rows by nested struct - var l []userProfile - query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q) - num, err = dORM.Raw(query).QueryRows(&l) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(l), 2)) - throwFailNow(t, AssertIs(l[0].UserName, "slene")) - throwFailNow(t, AssertIs(l[0].Age, 28)) - throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) - throwFailNow(t, AssertIs(l[1].Age, 30)) - - // test for sql.Null* fields - nData := &DataNull{ - NullString: sql.NullString{String: "test sql.null", Valid: true}, - NullBool: sql.NullBool{Bool: true, Valid: true}, - NullInt64: sql.NullInt64{Int64: 42, Valid: true}, - NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, - } - newId, err := dORM.Insert(nData) - throwFailNow(t, err) - - var nDataList []*DataNull - query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) - num, err = dORM.Raw(query, newId).QueryRows(&nDataList) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - nd := nDataList[0] - throwFailNow(t, AssertNot(nd, nil)) - throwFail(t, AssertIs(nd.NullBool.Valid, true)) - throwFail(t, AssertIs(nd.NullBool.Bool, true)) - throwFail(t, AssertIs(nd.NullString.Valid, true)) - throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) - throwFail(t, AssertIs(nd.NullInt64.Valid, true)) - throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) - throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) - throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) -} - -func TestRawValues(t *testing.T) { - Q := dDbBaser.TableQuote() - - var maps []Params - query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sStatus%s = ?", Q, Q, Q, Q, Q, Q) - num, err := dORM.Raw(query, 1).Values(&maps) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - if num == 1 { - throwFail(t, AssertIs(maps[0]["user_name"], "slene")) - } - - var lists []ParamsList - num, err = dORM.Raw(query, 1).ValuesList(&lists) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - if num == 1 { - throwFail(t, AssertIs(lists[0][0], "slene")) - } - - query = fmt.Sprintf("SELECT %sprofile_id%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q) - var list ParamsList - num, err = dORM.Raw(query).ValuesFlat(&list) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0], "2")) - throwFail(t, AssertIs(list[1], "3")) - throwFail(t, AssertIs(list[2], nil)) - } -} - -func TestRawPrepare(t *testing.T) { - switch { - case IsMysql || IsSqlite: - - pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() - throwFail(t, err) - if pre != nil { - r, err := pre.Exec("name1") - throwFail(t, err) - - tid, err := r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(tid > 0, true)) - - r, err = pre.Exec("name2") - throwFail(t, err) - - id, err := r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(id, tid+1)) - - r, err = pre.Exec("name3") - throwFail(t, err) - - id, err = r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(id, tid+2)) - - err = pre.Close() - throwFail(t, err) - - res, err := dORM.Raw("DELETE FROM tag WHERE name IN (?, ?, ?)", []string{"name1", "name2", "name3"}).Exec() - throwFail(t, err) - - num, err := res.RowsAffected() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - } - - case IsPostgres: - - pre, err := dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() - throwFail(t, err) - if pre != nil { - _, err := pre.Exec("name1") - throwFail(t, err) - - _, err = pre.Exec("name2") - throwFail(t, err) - - _, err = pre.Exec("name3") - throwFail(t, err) - - err = pre.Close() - throwFail(t, err) - - res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec() - throwFail(t, err) - - if err == nil { - num, err := res.RowsAffected() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - } - } - } -} - -func TestUpdate(t *testing.T) { - qs := dORM.QueryTable("user") - num, err := qs.Filter("user_name", "slene").Filter("is_staff", false).Update(Params{ - "is_staff": true, - "is_active": true, - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - // with join - num, err = qs.Filter("user_name", "slene").Filter("profile__age", 28).Filter("is_staff", true).Update(Params{ - "is_staff": false, - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColAdd, 100), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColMinus, 50), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColMultiply, 3), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColExcept, 5), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - user := User{UserName: "slene"} - err = dORM.Read(&user, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(user.Nums, 30)) -} - -func TestDelete(t *testing.T) { - qs := dORM.QueryTable("user_profile") - num, err := qs.Filter("user__user_name", "slene").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - qs = dORM.QueryTable("user") - num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - qs = dORM.QueryTable("comment") - num, err = qs.Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 6)) - - qs = dORM.QueryTable("post") - num, err = qs.Filter("Id", 3).Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - qs = dORM.QueryTable("comment") - num, err = qs.Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 4)) - - qs = dORM.QueryTable("comment") - num, err = qs.Filter("Post__User", 3).Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - qs = dORM.QueryTable("comment") - num, err = qs.Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestTransaction(t *testing.T) { - // this test worked when database support transaction - - o := NewOrm() - err := o.Begin() - throwFail(t, err) - - var names = []string{"1", "2", "3"} - - var tag Tag - tag.Name = names[0] - id, err := o.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - switch { - case IsMysql || IsSqlite: - res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() - throwFail(t, err) - if err == nil { - id, err = res.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - } - - err = o.Rollback() - throwFail(t, err) - - num, err = o.QueryTable("tag").Filter("name__in", names).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) - - err = o.Begin() - throwFail(t, err) - - tag.Name = "commit" - id, err = o.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - o.Commit() - throwFail(t, err) - - num, err = o.QueryTable("tag").Filter("name", "commit").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - -} - -func TestTransactionIsolationLevel(t *testing.T) { - // this test worked when database support transaction isolation level - if IsSqlite { - return - } - - o1 := NewOrm() - o2 := NewOrm() - - // start two transaction with isolation level repeatable read - err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - throwFail(t, err) - err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - throwFail(t, err) - - // o1 insert tag - var tag Tag - tag.Name = "test-transaction" - id, err := o1.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - // o2 query tag table, no result - num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) - - // o1 commit - o1.Commit() - - // o2 query tag table, still no result - num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) - - // o2 commit and query tag table, get the result - o2.Commit() - num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = o1.QueryTable("tag").Filter("name", "test-transaction").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestBeginTxWithContextCanceled(t *testing.T) { - o := NewOrm() - ctx, cancel := context.WithCancel(context.Background()) - o.BeginTx(ctx, nil) - id, err := o.Insert(&Tag{Name: "test-context"}) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - // cancel the context before commit to make it error - cancel() - err = o.Commit() - throwFail(t, AssertIs(err, context.Canceled)) -} - -func TestReadOrCreate(t *testing.T) { - u := &User{ - UserName: "Kyle", - Email: "kylemcc@gmail.com", - Password: "other_pass", - Status: 7, - IsStaff: false, - IsActive: true, - } - - created, pk, err := dORM.ReadOrCreate(u, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(created, true)) - throwFail(t, AssertIs(u.ID, pk)) - throwFail(t, AssertIs(u.UserName, "Kyle")) - throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com")) - throwFail(t, AssertIs(u.Password, "other_pass")) - throwFail(t, AssertIs(u.Status, 7)) - throwFail(t, AssertIs(u.IsStaff, false)) - throwFail(t, AssertIs(u.IsActive, true)) - throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime)) - - nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"} - created, pk, err = dORM.ReadOrCreate(nu, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(nu.ID, u.ID)) - throwFail(t, AssertIs(pk, u.ID)) - throwFail(t, AssertIs(nu.UserName, u.UserName)) - throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above - throwFail(t, AssertIs(nu.Password, u.Password)) - throwFail(t, AssertIs(nu.Status, u.Status)) - throwFail(t, AssertIs(nu.IsStaff, u.IsStaff)) - throwFail(t, AssertIs(nu.IsActive, u.IsActive)) - - dORM.Delete(u) -} - -func TestInLine(t *testing.T) { - name := "inline" - email := "hello@go.com" - inline := NewInLine() - inline.Name = name - inline.Email = email - - id, err := dORM.Insert(inline) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - il := NewInLine() - il.ID = 1 - err = dORM.Read(il) - throwFail(t, err) - - throwFail(t, AssertIs(il.Name, name)) - throwFail(t, AssertIs(il.Email, email)) - throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) -} - -func TestInLineOneToOne(t *testing.T) { - name := "121" - email := "121@go.com" - inline := NewInLine() - inline.Name = name - inline.Email = email - - id, err := dORM.Insert(inline) - throwFail(t, err) - throwFail(t, AssertIs(id, 2)) - - note := "one2one" - il121 := NewInLineOneToOne() - il121.Note = note - il121.InLine = inline - _, err = dORM.Insert(il121) - throwFail(t, err) - throwFail(t, AssertIs(il121.ID, 1)) - - il := NewInLineOneToOne() - err = dORM.QueryTable(il).Filter("Id", 1).RelatedSel().One(il) - - throwFail(t, err) - throwFail(t, AssertIs(il.Note, note)) - throwFail(t, AssertIs(il.InLine.ID, id)) - throwFail(t, AssertIs(il.InLine.Name, name)) - throwFail(t, AssertIs(il.InLine.Email, email)) - - rinline := NewInLine() - err = dORM.QueryTable(rinline).Filter("InLineOneToOne__Id", 1).One(rinline) - - throwFail(t, err) - throwFail(t, AssertIs(rinline.ID, id)) - throwFail(t, AssertIs(rinline.Name, name)) - throwFail(t, AssertIs(rinline.Email, email)) -} - -func TestIntegerPk(t *testing.T) { - its := []IntegerPk{ - {ID: math.MinInt64, Value: "-"}, - {ID: 0, Value: "0"}, - {ID: math.MaxInt64, Value: "+"}, - } - - num, err := dORM.InsertMulti(len(its), its) - throwFail(t, err) - throwFail(t, AssertIs(num, len(its))) - - for _, intPk := range its { - out := IntegerPk{ID: intPk.ID} - err = dORM.Read(&out) - throwFail(t, err) - throwFail(t, AssertIs(out.Value, intPk.Value)) - } - - num, err = dORM.InsertMulti(1, []*IntegerPk{{ - ID: 1, Value: "ok", - }}) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestInsertAuto(t *testing.T) { - u := &User{ - UserName: "autoPre", - Email: "autoPre@gmail.com", - } - - id, err := dORM.Insert(u) - throwFail(t, err) - - id += 100 - su := &User{ - ID: int(id), - UserName: "auto", - Email: "auto@gmail.com", - } - - nid, err := dORM.Insert(su) - throwFail(t, err) - throwFail(t, AssertIs(nid, id)) - - users := []User{ - {ID: int(id + 100), UserName: "auto_100"}, - {ID: int(id + 110), UserName: "auto_110"}, - {ID: int(id + 120), UserName: "auto_120"}, - } - num, err := dORM.InsertMulti(100, users) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - u = &User{ - UserName: "auto_121", - } - - nid, err = dORM.Insert(u) - throwFail(t, err) - throwFail(t, AssertIs(nid, id+120+1)) -} - -func TestUintPk(t *testing.T) { - name := "go" - u := &UintPk{ - ID: 8, - Name: name, - } - - created, _, err := dORM.ReadOrCreate(u, "ID") - throwFail(t, err) - throwFail(t, AssertIs(created, true)) - throwFail(t, AssertIs(u.Name, name)) - - nu := &UintPk{ID: 8} - created, pk, err := dORM.ReadOrCreate(nu, "ID") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(nu.ID, u.ID)) - throwFail(t, AssertIs(pk, u.ID)) - throwFail(t, AssertIs(nu.Name, name)) - - dORM.Delete(u) -} - -func TestPtrPk(t *testing.T) { - parent := &IntegerPk{ID: 10, Value: "10"} - - id, _ := dORM.Insert(parent) - if !IsMysql { - // MySql does not support last_insert_id in this case: see #2382 - throwFail(t, AssertIs(id, 10)) - } - - ptr := PtrPk{ID: parent, Positive: true} - num, err := dORM.InsertMulti(2, []PtrPk{ptr}) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertIs(ptr.ID, parent)) - - nptr := &PtrPk{ID: parent} - created, pk, err := dORM.ReadOrCreate(nptr, "ID") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(pk, 10)) - throwFail(t, AssertIs(nptr.ID, parent)) - throwFail(t, AssertIs(nptr.Positive, true)) - - nptr = &PtrPk{Positive: true} - created, pk, err = dORM.ReadOrCreate(nptr, "Positive") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(pk, 10)) - throwFail(t, AssertIs(nptr.ID, parent)) - - nptr.Positive = false - num, err = dORM.Update(nptr) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertIs(nptr.ID, parent)) - throwFail(t, AssertIs(nptr.Positive, false)) - - num, err = dORM.Delete(nptr) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestSnake(t *testing.T) { - cases := map[string]string{ - "i": "i", - "I": "i", - "iD": "i_d", - "ID": "i_d", - "NO": "n_o", - "NOO": "n_o_o", - "NOOooOOoo": "n_o_ooo_o_ooo", - "OrderNO": "order_n_o", - "tagName": "tag_name", - "tag_Name": "tag__name", - "tag_name": "tag_name", - "_tag_name": "_tag_name", - "tag_666name": "tag_666name", - "tag_666Name": "tag_666_name", - } - for name, want := range cases { - got := snakeString(name) - throwFail(t, AssertIs(got, want)) - } -} - -func TestIgnoreCaseTag(t *testing.T) { - type testTagModel struct { - ID int `orm:"pk"` - NOO string `orm:"column(n)"` - Name01 string `orm:"NULL"` - Name02 string `orm:"COLUMN(Name)"` - Name03 string `orm:"Column(name)"` - } - modelCache.clean() - RegisterModel(&testTagModel{}) - info, ok := modelCache.get("test_tag_model") - throwFail(t, AssertIs(ok, true)) - throwFail(t, AssertNot(info, nil)) - if t == nil { - return - } - throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n")) - throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true)) - throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) - throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) -} - -func TestInsertOrUpdate(t *testing.T) { - RegisterModel(new(User)) - user := User{UserName: "unique_username133", Status: 1, Password: "o"} - user1 := User{UserName: "unique_username133", Status: 2, Password: "o"} - user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"} - dORM.Insert(&user) - test := User{UserName: "unique_username133"} - fmt.Println(dORM.Driver().Name()) - if dORM.Driver().Name() == "sqlite3" { - fmt.Println("sqlite3 is nonsupport") - return - } - //test1 - _, err := dORM.InsertOrUpdate(&user1, "user_name") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user1.Status, test.Status)) - } - //test2 - _, err = dORM.InsertOrUpdate(&user2, "user_name") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user2.Status, test.Status)) - throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) - } - - //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values - if IsPostgres { - return - } - //test3 + - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user2.Status+1, test.Status)) - } - //test4 - - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status)) - } - //test5 * - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status)) - } - //test6 / - _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) - } -} diff --git a/orm/utils_test.go b/orm/utils_test.go deleted file mode 100644 index 7d94cada..00000000 --- a/orm/utils_test.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "testing" -) - -func TestCamelString(t *testing.T) { - snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} - camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} - - answer := make(map[string]string) - for i, v := range snake { - answer[v] = camel[i] - } - - for _, v := range snake { - res := camelString(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } - } -} - -func TestSnakeString(t *testing.T) { - camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} - snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} - - answer := make(map[string]string) - for i, v := range camel { - answer[v] = snake[i] - } - - for _, v := range camel { - res := snakeString(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } - } -} - -func TestSnakeStringWithAcronym(t *testing.T) { - camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} - snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} - - answer := make(map[string]string) - for i, v := range camel { - answer[v] = snake[i] - } - - for _, v := range camel { - res := snakeStringWithAcronym(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } - } -} diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index e3a635f2..79e926d3 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -293,7 +293,7 @@ type Post struct { Content string `orm:"type(text)"` Created time.Time `orm:"auto_now_add"` Updated time.Time `orm:"auto_now"` - Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.PostTags)"` + Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/pkg/orm.PostTags)"` } func (u *Post) TableIndex() [][]string { @@ -351,7 +351,7 @@ type Group struct { type Permission struct { ID int `orm:"column(id)"` Name string - Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"` + Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/pkg/orm.GroupPermissions)"` } type GroupPermissions struct { @@ -446,7 +446,7 @@ var ( usage: - go get -u github.com/astaxie/beego/orm + go get -u github.com/astaxie/beego/pkg/orm go get -u github.com/go-sql-driver/mysql go get -u github.com/mattn/go-sqlite3 go get -u github.com/lib/pq @@ -456,25 +456,25 @@ var ( mysql -u root -e 'create database orm_test;' export ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" - go test -v github.com/astaxie/beego/orm + go test -v github.com/astaxie/beego/pkg/orm #### Sqlite3 export ORM_DRIVER=sqlite3 export ORM_SOURCE='file:memory_test?mode=memory' - go test -v github.com/astaxie/beego/orm + go test -v github.com/astaxie/beego/pkg/orm #### PostgreSQL psql -c 'create database orm_test;' -U postgres export ORM_DRIVER=postgres export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - go test -v github.com/astaxie/beego/orm + go test -v github.com/astaxie/beego/pkg/orm #### TiDB export ORM_DRIVER=tidb export ORM_SOURCE='memory://test/test' - go test -v github.com/astaxie/beego/orm + go test -v github.com/astaxie/beego/pgk/orm ` ) @@ -487,7 +487,11 @@ func init() { os.Exit(2) } - RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) + + if err != nil{ + panic(fmt.Sprintf("can not register database: %v", err)) + } alias := getDbAlias("default") if alias.Driver == DRMySQL { diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 0551b1cd..bd2cb783 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -21,7 +21,7 @@ // // import ( // "fmt" -// "github.com/astaxie/beego/orm" +// "github.com/astaxie/beego/pkg/orm" // _ "github.com/go-sql-driver/mysql" // import your used driver // ) // diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index bdb430b6..eac7b33a 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -30,6 +30,8 @@ import ( "strings" "testing" "time" + + "github.com/stretchr/testify/assert" ) var _ = os.PathSeparator @@ -141,6 +143,7 @@ func getCaller(skip int) string { return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) } +// Deprecated: Using stretchr/testify/assert func throwFail(t *testing.T, err error, args ...interface{}) { if err != nil { con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) @@ -455,9 +458,11 @@ func TestNullDataTypes(t *testing.T) { throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) - throwFail(t, AssertIs((*d.TimePtr).UTC().Format(testTime), timePtr.UTC().Format(testTime))) - throwFail(t, AssertIs((*d.DatePtr).UTC().Format(testDate), datePtr.UTC().Format(testDate))) - throwFail(t, AssertIs((*d.DateTimePtr).UTC().Format(testDateTime), dateTimePtr.UTC().Format(testDateTime))) + + // in mysql, there are some precision problem, (*d.TimePtr).UTC() != timePtr.UTC() + assert.True(t, (*d.TimePtr).UTC().Sub(timePtr.UTC()) <= time.Second) + assert.True(t, (*d.DatePtr).UTC().Sub(datePtr.UTC()) <= time.Second) + assert.True(t, (*d.DateTimePtr).UTC().Sub(dateTimePtr.UTC()) <= time.Second) // test support for pointer fields using RawSeter.QueryRows() var dnList []*DataNull @@ -532,8 +537,9 @@ func TestCRUD(t *testing.T) { throwFail(t, AssertIs(u.Status, 3)) throwFail(t, AssertIs(u.IsStaff, true)) throwFail(t, AssertIs(u.IsActive, true)) - throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime)) + + assert.True(t, u.Created.In(DefaultTimeLoc).Sub(user.Created.In(DefaultTimeLoc)) <= time.Second) + assert.True(t, u.Updated.In(DefaultTimeLoc).Sub(user.Updated.In(DefaultTimeLoc)) <= time.Second) user.UserName = "astaxie" user.Profile = profile @@ -1793,7 +1799,7 @@ func TestQueryRows(t *testing.T) { throwFailNow(t, AssertIs(ids[2], 4)) throwFailNow(t, AssertIs(usernames[2], "nobody")) - //test query rows by nested struct + // test query rows by nested struct var l []userProfile query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q) num, err = dORM.Raw(query).QueryRows(&l) @@ -2413,7 +2419,7 @@ func TestInsertOrUpdate(t *testing.T) { fmt.Println("sqlite3 is nonsupport") return } - //test1 + // test1 _, err := dORM.InsertOrUpdate(&user1, "user_name") if err != nil { fmt.Println(err) @@ -2425,7 +2431,7 @@ func TestInsertOrUpdate(t *testing.T) { dORM.Read(&test, "user_name") throwFailNow(t, AssertIs(user1.Status, test.Status)) } - //test2 + // test2 _, err = dORM.InsertOrUpdate(&user2, "user_name") if err != nil { fmt.Println(err) @@ -2439,11 +2445,11 @@ func TestInsertOrUpdate(t *testing.T) { throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) } - //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values if IsPostgres { return } - //test3 + + // test3 + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") if err != nil { fmt.Println(err) @@ -2455,7 +2461,7 @@ func TestInsertOrUpdate(t *testing.T) { dORM.Read(&test, "user_name") throwFailNow(t, AssertIs(user2.Status+1, test.Status)) } - //test4 - + // test4 - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1") if err != nil { fmt.Println(err) @@ -2467,7 +2473,7 @@ func TestInsertOrUpdate(t *testing.T) { dORM.Read(&test, "user_name") throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status)) } - //test5 * + // test5 * _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3") if err != nil { fmt.Println(err) @@ -2479,7 +2485,7 @@ func TestInsertOrUpdate(t *testing.T) { dORM.Read(&test, "user_name") throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status)) } - //test6 / + // test6 / _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") if err != nil { fmt.Println(err) From 7258ef113a0daa8478dc86afd41859dc8e659239 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 19 Jul 2020 14:34:57 +0000 Subject: [PATCH 029/207] Store nearest error info --- toolbox/task.go | 12 +++++++++--- toolbox/task_test.go | 22 ++++++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/toolbox/task.go b/toolbox/task.go index d2a94ba9..c902fdfc 100644 --- a/toolbox/task.go +++ b/toolbox/task.go @@ -102,6 +102,8 @@ type taskerr struct { } // Task task struct +// It's not a thread-safe structure. +// Only nearest errors will be saved in ErrList type Task struct { Taskname string Spec *Schedule @@ -111,6 +113,7 @@ type Task struct { Next time.Time Errlist []*taskerr // like errtime:errinfo ErrLimit int // max length for the errlist, 0 stand for no limit + errCnt int // records the error count during the execution } // NewTask add new task with name, time and func @@ -119,8 +122,11 @@ func NewTask(tname string, spec string, f TaskFunc) *Task { task := &Task{ Taskname: tname, DoFunc: f, + // Make configurable ErrLimit: 100, SpecStr: spec, + // we only store the pointer, so it won't use too many space + Errlist: make([]*taskerr, 100, 100), } task.SetCron(spec) return task @@ -144,9 +150,9 @@ func (t *Task) GetStatus() string { func (t *Task) Run() error { err := t.DoFunc() if err != nil { - if t.ErrLimit > 0 && t.ErrLimit > len(t.Errlist) { - t.Errlist = append(t.Errlist, &taskerr{t: t.Next, errinfo: err.Error()}) - } + index := t.errCnt % t.ErrLimit + t.Errlist[index] = &taskerr{t: t.Next, errinfo: err.Error()} + t.errCnt++ } return err } diff --git a/toolbox/task_test.go b/toolbox/task_test.go index 596bc9c5..3a4cce2f 100644 --- a/toolbox/task_test.go +++ b/toolbox/task_test.go @@ -15,10 +15,13 @@ package toolbox import ( + "errors" "fmt" "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestParse(t *testing.T) { @@ -53,6 +56,25 @@ func TestSpec(t *testing.T) { } } +func TestTask_Run(t *testing.T) { + cnt := -1 + task := func() error { + cnt ++ + fmt.Printf("Hello, world! %d \n", cnt) + return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) + } + tk := NewTask("taska", "0/30 * * * * *", task) + for i := 0; i < 200 ; i ++ { + e := tk.Run() + assert.NotNil(t, e) + } + + l := tk.Errlist + assert.Equal(t, 100, len(l)) + assert.Equal(t, "Hello, world! 100", l[0].errinfo) + assert.Equal(t, "Hello, world! 101", l[1].errinfo) +} + func wait(wg *sync.WaitGroup) chan bool { ch := make(chan bool) go func() { From 32da446eb1d8785b70037d648ca74261cb20fd2f Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Sun, 19 Jul 2020 23:46:42 +0800 Subject: [PATCH 030/207] refactor orm --- pkg/orm/db_alias.go | 53 +++++++ pkg/orm/orm.go | 307 +++++++++++++++++++++++++--------------- pkg/orm/orm_object.go | 4 +- pkg/orm/orm_querym2m.go | 2 +- pkg/orm/orm_queryset.go | 4 +- pkg/orm/orm_raw.go | 4 +- pkg/orm/orm_test.go | 34 ++--- pkg/orm/types.go | 162 +++++++++++++-------- 8 files changed, 370 insertions(+), 200 deletions(-) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index bf6c350c..b2a72f56 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -111,6 +111,9 @@ type DB struct { stmtDecorators *lru.Cache } +var _ dbQuerier = new(DB) +var _ txer = new(DB) + func (d *DB) Begin() (*sql.Tx, error) { return d.DB.Begin() } @@ -220,6 +223,56 @@ func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interfac return stmt.QueryRowContext(ctx, args) } +type TxDB struct { + tx *sql.Tx +} + +var _ dbQuerier = new(TxDB) +var _ txEnder = new(TxDB) + +func (t *TxDB) Commit() error { + return t.tx.Commit() +} + +func (t *TxDB) Rollback() error { + return t.tx.Rollback() +} + +var _ dbQuerier = new(TxDB) +var _ txEnder = new(TxDB) + +func (t *TxDB) Prepare(query string) (*sql.Stmt, error) { + return t.PrepareContext(context.Background(),query) +} + +func (t *TxDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return t.tx.PrepareContext(ctx, query) +} + +func (t *TxDB) Exec(query string, args ...interface{}) (sql.Result, error) { + return t.ExecContext(context.Background(), query, args...) +} + +func (t *TxDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return t.tx.ExecContext(ctx, query, args...) +} + +func (t *TxDB) Query(query string, args ...interface{}) (*sql.Rows, error) { + return t.QueryContext(context.Background(),query,args...) +} + +func (t *TxDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return t.tx.QueryContext(ctx, query, args...) +} + +func (t *TxDB) QueryRow(query string, args ...interface{}) *sql.Row { + return t.QueryRowContext(context.Background(),query,args...) +} + +func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + return t.tx.QueryRowContext(ctx, query, args...) +} + type alias struct { Name string Driver DriverType diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index bd2cb783..3db75751 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -62,6 +62,8 @@ import ( "reflect" "sync" "time" + + "github.com/astaxie/beego/logs" ) // DebugQueries define the debug @@ -76,8 +78,7 @@ var ( DefaultRowsLimit = -1 DefaultRelsDepth = 2 DefaultTimeLoc = time.Local - ErrTxHasBegan = errors.New(" transaction already begin") - ErrTxDone = errors.New(" transaction not begin") + ErrTxDone = errors.New(" transaction already done") ErrMultiRows = errors.New(" return multi rows") ErrNoRows = errors.New(" no row found") ErrStmtClosed = errors.New(" stmt already closed") @@ -91,16 +92,16 @@ type Params map[string]interface{} // ParamsList stores paramslist type ParamsList []interface{} -type orm struct { +type ormBase struct { alias *alias db dbQuerier - isTx bool } -var _ Ormer = new(orm) +var _ DQL = new(ormBase) +var _ DML = new(ormBase) // get model info and model reflect value -func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { +func (o *ormBase) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { val := reflect.ValueOf(md) ind = reflect.Indirect(val) typ := ind.Type() @@ -115,7 +116,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect } // get field info from model info by given field name -func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { +func (o *ormBase) getFieldInfo(mi *modelInfo, name string) *fieldInfo { fi, ok := mi.fields.GetByAny(name) if !ok { panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.fullName)) @@ -124,33 +125,42 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { } // read data to model -func (o *orm) Read(md interface{}, cols ...string) error { +func (o *ormBase) Read(md interface{}, cols ...string) error { + return o.ReadWithCtx(context.Background(), md, cols...) +} +func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) } // read data to model, like Read(), but use "SELECT FOR UPDATE" form -func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { +func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error { + return o.ReadForUpdateWithCtx(context.Background(), md, cols...) +} +func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) } // Try to read a row from the database, or insert one if it doesn't exist -func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { +func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + return o.ReadOrCreateWithCtx(context.Background(), md, col1, cols...) +} +func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { cols = append([]string{col1}, cols...) mi, ind := o.getMiInd(md, true) err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) if err == ErrNoRows { // Create - id, err := o.Insert(md) - return (err == nil), id, err + id, err := o.InsertWithCtx(ctx, md) + return err == nil, id, err } id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { id = int64(vid.Uint()) } else if mi.fields.pk.rel { - return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) + return o.ReadOrCreateWithCtx(ctx, vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) } else { id = vid.Int() } @@ -159,7 +169,10 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i } // insert model data to database -func (o *orm) Insert(md interface{}) (int64, error) { +func (o *ormBase) Insert(md interface{}) (int64, error) { + return o.InsertWithCtx(context.Background(), md) +} +func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { mi, ind := o.getMiInd(md, true) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) if err != nil { @@ -172,7 +185,7 @@ func (o *orm) Insert(md interface{}) (int64, error) { } // set auto pk field -func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { +func (o *ormBase) setPk(mi *modelInfo, ind reflect.Value, id int64) { if mi.fields.pk.auto { if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) @@ -183,7 +196,10 @@ func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { } // insert some models to database -func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { +func (o *ormBase) InsertMulti(bulk int, mds interface{}) (int64, error) { + return o.InsertMultiWithCtx(context.Background(), bulk, mds) +} +func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { var cnt int64 sind := reflect.Indirect(reflect.ValueOf(mds)) @@ -218,7 +234,10 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { } // InsertOrUpdate data to database -func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { +func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) (int64, error) { + return o.InsertOrUpdateWithCtx(context.Background(), md, colConflictAndArgs...) +} +func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { mi, ind := o.getMiInd(md, true) id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) if err != nil { @@ -232,14 +251,20 @@ func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64 // update model to database. // cols set the columns those want to update. -func (o *orm) Update(md interface{}, cols ...string) (int64, error) { +func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { + return o.UpdateWithCtx(context.Background(), md, cols...) +} +func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) } // delete model in database // cols shows the delete conditions values read from. default is pk -func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { +func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) { + return o.DeleteWithCtx(context.Background(), md, cols...) +} +func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) if err != nil { @@ -252,7 +277,10 @@ func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { } // create a models to models queryer -func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { +func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { + return o.QueryM2MWithCtx(context.Background(), md, name) +} +func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -274,7 +302,10 @@ func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { // for _,tag := range post.Tags{...} // // make sure the relation is defined in model struct tags. -func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { +func (o *ormBase) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { + return o.LoadRelatedWithCtx(context.Background(), md, name, args...) +} +func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { _, fi, ind, qseter := o.queryRelated(md, name) qs := qseter.(*querySet) @@ -341,14 +372,17 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int // qs := orm.QueryRelated(post,"Tag") // qs.All(&[]*Tag{}) // -func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { +func (o *ormBase) QueryRelated(md interface{}, name string) QuerySeter { + return o.QueryRelatedWithCtx(context.Background(), md, name) +} +func (o *ormBase) QueryRelatedWithCtx(ctx context.Context, md interface{}, name string) QuerySeter { // is this api needed ? _, _, _, qs := o.queryRelated(md, name) return qs } // get QuerySeter for related models to md model -func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { +func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -380,7 +414,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, } // get reverse relation QuerySeter -func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { +func (o *ormBase) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { switch fi.fieldType { case RelReverseOne, RelReverseMany: default: @@ -401,7 +435,7 @@ func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS } // get relation QuerySeter -func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { +func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { switch fi.fieldType { case RelOneToOne, RelForeignKey, RelManyToMany: default: @@ -423,7 +457,10 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { // return a QuerySeter for table operations. // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), -func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { +func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { + return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName) +} +func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { var name string if table, ok := ptrStructOrTableName.(string); ok { name = nameStrategyMap[defaultNameStrategy](table) @@ -442,11 +479,136 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { return } -// switch to another registered database driver by given name. -func (o *orm) Using(name string) error { - if o.isTx { - panic(fmt.Errorf(" transaction has been start, cannot change db")) +// return a raw query seter for raw sql string. +func (o *ormBase) Raw(query string, args ...interface{}) RawSeter { + return o.RawWithCtx(context.Background(), query, args...) +} +func (o *ormBase) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter { + return newRawSet(o, query, args) +} + +// return current using database Driver +func (o *ormBase) Driver() Driver { + return driver(o.alias.Name) +} + +// return sql.DBStats for current database +func (o *ormBase) DBStats() *sql.DBStats { + if o.alias != nil && o.alias.DB != nil { + stats := o.alias.DB.DB.Stats() + return &stats } + return nil +} + +type orm struct { + ormBase +} + +var _ Ormer = new(orm) + +func (o *orm) Begin() (TxOrmer, error) { + return o.BeginWithCtx(context.Background()) +} + +func (o *orm) BeginWithCtx(ctx context.Context) (TxOrmer, error) { + return o.BeginWithCtxAndOpts(ctx, nil) +} + +func (o *orm) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) { + return o.BeginWithCtxAndOpts(context.Background(), opts) +} + +func (o *orm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) { + tx, err := o.db.(txer).BeginTx(ctx, opts) + if err != nil { + return nil, err + } + + _txOrm := &txOrm{ + ormBase: ormBase{ + alias: o.alias, + db: &TxDB{tx: tx}, + }, + isClosed: false, + } + + var taskTxOrm TxOrmer = _txOrm + return taskTxOrm, nil +} + +func (o *orm) DoTx(task func(txOrm TxOrmer) error) error { + return o.DoTxWithCtx(context.Background(), task) +} + +func (o *orm) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error { + return o.DoTxWithCtxAndOpts(ctx, nil, task) +} + +func (o *orm) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { + return o.DoTxWithCtxAndOpts(context.Background(), opts, task) +} + +func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { + _txOrm, err := o.BeginWithCtxAndOpts(ctx, opts) + if err != nil { + return err + } + panicked := true + defer func() { + if panicked || err != nil { + e := _txOrm.Rollback() + logs.Error("rollback transaction failed: %v", e) + } else { + e := _txOrm.Commit() + logs.Error("commit transaction failed: %v", e) + } + }() + + var taskTxOrm = _txOrm + err = task(taskTxOrm) + panicked = false + return err +} + +type txOrm struct { + ormBase + isClosed bool + closeMutex sync.Mutex +} + +var _ TxOrmer = new(txOrm) + +func (t *txOrm) Commit() error { + t.closeMutex.Lock() + defer t.closeMutex.Unlock() + + if t.isClosed { + return ErrTxDone + } + t.isClosed = true + + return t.db.(txEnder).Commit() +} + +func (t *txOrm) Rollback() error { + t.closeMutex.Lock() + defer t.closeMutex.Unlock() + + if t.isClosed { + return ErrTxDone + } + t.isClosed = true + + return t.db.(txEnder).Rollback() +} + +// NewOrm create new orm +func NewOrm() Ormer { + BootStrap() // execute only once + + o := new(orm) + name := `default` if al, ok := dataBaseCache.get(name); ok { o.alias = al if Debug { @@ -455,92 +617,9 @@ func (o *orm) Using(name string) error { o.db = al.DB } } else { - return fmt.Errorf(" unknown db alias name `%s`", name) + panic(fmt.Errorf(" unknown db alias name `%s`", name)) } - return nil -} -// begin transaction -func (o *orm) Begin() error { - return o.BeginTx(context.Background(), nil) -} - -func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error { - if o.isTx { - return ErrTxHasBegan - } - var tx *sql.Tx - tx, err := o.db.(txer).BeginTx(ctx, opts) - if err != nil { - return err - } - o.isTx = true - if Debug { - o.db.(*dbQueryLog).SetDB(tx) - } else { - o.db = tx - } - return nil -} - -// commit transaction -func (o *orm) Commit() error { - if !o.isTx { - return ErrTxDone - } - err := o.db.(txEnder).Commit() - if err == nil { - o.isTx = false - o.Using(o.alias.Name) - } else if err == sql.ErrTxDone { - return ErrTxDone - } - return err -} - -// rollback transaction -func (o *orm) Rollback() error { - if !o.isTx { - return ErrTxDone - } - err := o.db.(txEnder).Rollback() - if err == nil { - o.isTx = false - o.Using(o.alias.Name) - } else if err == sql.ErrTxDone { - return ErrTxDone - } - return err -} - -// return a raw query seter for raw sql string. -func (o *orm) Raw(query string, args ...interface{}) RawSeter { - return newRawSet(o, query, args) -} - -// return current using database Driver -func (o *orm) Driver() Driver { - return driver(o.alias.Name) -} - -// return sql.DBStats for current database -func (o *orm) DBStats() *sql.DBStats { - if o.alias != nil && o.alias.DB != nil { - stats := o.alias.DB.DB.Stats() - return &stats - } - return nil -} - -// NewOrm create new orm -func NewOrm() Ormer { - BootStrap() // execute only once - - o := new(orm) - err := o.Using("default") - if err != nil { - panic(err) - } return o } diff --git a/pkg/orm/orm_object.go b/pkg/orm/orm_object.go index de3181ce..6f9798d3 100644 --- a/pkg/orm/orm_object.go +++ b/pkg/orm/orm_object.go @@ -22,7 +22,7 @@ import ( // an insert queryer struct type insertSet struct { mi *modelInfo - orm *orm + orm *ormBase stmt stmtQuerier closed bool } @@ -70,7 +70,7 @@ func (o *insertSet) Close() error { } // create new insert queryer. -func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { +func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) { bi := new(insertSet) bi.orm = orm bi.mi = mi diff --git a/pkg/orm/orm_querym2m.go b/pkg/orm/orm_querym2m.go index 6a270a0d..17e1b5d1 100644 --- a/pkg/orm/orm_querym2m.go +++ b/pkg/orm/orm_querym2m.go @@ -129,7 +129,7 @@ func (o *queryM2M) Count() (int64, error) { var _ QueryM2Mer = new(queryM2M) // create new M2M queryer. -func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { +func newQueryM2M(md interface{}, o *ormBase, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { qm2m := new(queryM2M) qm2m.md = md qm2m.mi = mi diff --git a/pkg/orm/orm_queryset.go b/pkg/orm/orm_queryset.go index 878b836b..83168de7 100644 --- a/pkg/orm/orm_queryset.go +++ b/pkg/orm/orm_queryset.go @@ -72,7 +72,7 @@ type querySet struct { orders []string distinct bool forupdate bool - orm *orm + orm *ormBase ctx context.Context forContext bool } @@ -292,7 +292,7 @@ func (o querySet) WithContext(ctx context.Context) QuerySeter { } // create new QuerySeter. -func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { +func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { o := new(querySet) o.mi = mi o.orm = orm diff --git a/pkg/orm/orm_raw.go b/pkg/orm/orm_raw.go index 3325a7ea..5e05eded 100644 --- a/pkg/orm/orm_raw.go +++ b/pkg/orm/orm_raw.go @@ -63,7 +63,7 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) { type rawSet struct { query string args []interface{} - orm *orm + orm *ormBase } var _ RawSeter = new(rawSet) @@ -858,7 +858,7 @@ func (o *rawSet) Prepare() (RawPreparer, error) { return newRawPreparer(o) } -func newRawSet(orm *orm, query string, args []interface{}) RawSeter { +func newRawSet(orm *ormBase, query string, args []interface{}) RawSeter { o := new(rawSet) o.query = query o.args = args diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index eac7b33a..b7b2d9a7 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -2026,24 +2026,24 @@ func TestTransaction(t *testing.T) { // this test worked when database support transaction o := NewOrm() - err := o.Begin() + to, err := o.Begin() throwFail(t, err) var names = []string{"1", "2", "3"} var tag Tag tag.Name = names[0] - id, err := o.Insert(&tag) + id, err := to.Insert(&tag) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) - num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) + num, err := to.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) throwFail(t, err) throwFail(t, AssertIs(num, 1)) switch { case IsMysql || IsSqlite: - res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() + res, err := to.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() throwFail(t, err) if err == nil { id, err = res.LastInsertId() @@ -2052,22 +2052,22 @@ func TestTransaction(t *testing.T) { } } - err = o.Rollback() + err = to.Rollback() throwFail(t, err) num, err = o.QueryTable("tag").Filter("name__in", names).Count() throwFail(t, err) throwFail(t, AssertIs(num, 0)) - err = o.Begin() + to, err = o.Begin() throwFail(t, err) tag.Name = "commit" - id, err = o.Insert(&tag) + id, err = to.Insert(&tag) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) - o.Commit() + to.Commit() throwFail(t, err) num, err = o.QueryTable("tag").Filter("name", "commit").Delete() @@ -2086,15 +2086,15 @@ func TestTransactionIsolationLevel(t *testing.T) { o2 := NewOrm() // start two transaction with isolation level repeatable read - err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + to1, err := o1.BeginWithCtxAndOpts(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) throwFail(t, err) - err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + to2, err := o2.BeginWithCtxAndOpts(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) throwFail(t, err) // o1 insert tag var tag Tag tag.Name = "test-transaction" - id, err := o1.Insert(&tag) + id, err := to1.Insert(&tag) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) @@ -2104,15 +2104,15 @@ func TestTransactionIsolationLevel(t *testing.T) { throwFail(t, AssertIs(num, 0)) // o1 commit - o1.Commit() + to1.Commit() // o2 query tag table, still no result - num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + num, err = to2.QueryTable("tag").Filter("name", "test-transaction").Count() throwFail(t, err) throwFail(t, AssertIs(num, 0)) // o2 commit and query tag table, get the result - o2.Commit() + to2.Commit() num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) @@ -2125,14 +2125,14 @@ func TestTransactionIsolationLevel(t *testing.T) { func TestBeginTxWithContextCanceled(t *testing.T) { o := NewOrm() ctx, cancel := context.WithCancel(context.Background()) - o.BeginTx(ctx, nil) - id, err := o.Insert(&Tag{Name: "test-context"}) + to, _ := o.BeginWithCtx(ctx) + id, err := to.Insert(&Tag{Name: "test-context"}) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) // cancel the context before commit to make it error cancel() - err = o.Commit() + err = to.Commit() throwFail(t, AssertIs(err, context.Canceled)) } diff --git a/pkg/orm/types.go b/pkg/orm/types.go index 2fd10774..b7a38826 100644 --- a/pkg/orm/types.go +++ b/pkg/orm/types.go @@ -35,35 +35,43 @@ type Fielder interface { RawValue() interface{} } -// Ormer define the orm interface -type Ormer interface { - // read data to model - // for example: - // this will find User by Id field - // u = &User{Id: user.Id} - // err = Ormer.Read(u) - // this will find User by UserName field - // u = &User{UserName: "astaxie", Password: "pass"} - // err = Ormer.Read(u, "UserName") - Read(md interface{}, cols ...string) error - // Like Read(), but with "FOR UPDATE" clause, useful in transaction. - // Some databases are not support this feature. - ReadForUpdate(md interface{}, cols ...string) error - // Try to read a row from the database, or insert one if it doesn't exist - ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) +type TxBeginner interface { + //self control transaction + Begin() (TxOrmer, error) + BeginWithCtx(ctx context.Context) (TxOrmer, error) + BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) + BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) + + //closure control transaction + DoTx(task func(txOrm TxOrmer) error) error + DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error + DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error + DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error +} + +type TxCommitter interface { + Commit() error + Rollback() error +} + +//Data Manipulation Language +type DML interface { // insert model data to database // for example: // user := new(User) // id, err = Ormer.Insert(user) // user must be a pointer and Insert will set user's pk field - Insert(interface{}) (int64, error) + Insert(md interface{}) (int64, error) + InsertWithCtx(ctx context.Context, md interface{}) (int64, error) // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") // if colu type is integer : can use(+-*/), string : convert(colu,"value") // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") // if colu type is integer : can use(+-*/), string : colu || "value" InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) + InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) // insert some models to database InsertMulti(bulk int, mds interface{}) (int64, error) + InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) // update model to database. // cols set the columns those want to update. // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns @@ -74,63 +82,93 @@ type Ormer interface { // user.Extra.Data = "orm" // num, err = Ormer.Update(&user, "Langs", "Extra") Update(md interface{}, cols ...string) (int64, error) + UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) // delete model in database Delete(md interface{}, cols ...string) (int64, error) + DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) + + // return a raw query seter for raw sql string. + // for example: + // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() + // // update user testing's name to slene + Raw(query string, args ...interface{}) RawSeter + RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter +} + +// Data Query Language +type DQL interface { + // read data to model + // for example: + // this will find User by Id field + // u = &User{Id: user.Id} + // err = Ormer.Read(u) + // this will find User by UserName field + // u = &User{UserName: "astaxie", Password: "pass"} + // err = Ormer.Read(u, "UserName") + Read(md interface{}, cols ...string) error + ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error + + // Like Read(), but with "FOR UPDATE" clause, useful in transaction. + // Some databases are not support this feature. + ReadForUpdate( md interface{}, cols ...string) error + ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error + + // Try to read a row from the database, or insert one if it doesn't exist + ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) + ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) + // load related models to md model. // args are limit, offset int and order string. // // example: // Ormer.LoadRelated(post,"Tags") // for _,tag := range post.Tags{...} - //args[0] bool true useDefaultRelsDepth ; false depth 0 - //args[0] int loadRelationDepth - //args[1] int limit default limit 1000 - //args[2] int offset default offset 0 - //args[3] string order for example : "-Id" + // args[0] bool true useDefaultRelsDepth ; false depth 0 + // args[0] int loadRelationDepth + // args[1] int limit default limit 1000 + // args[2] int offset default offset 0 + // args[3] string order for example : "-Id" // make sure the relation is defined in model struct tags. - LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) + LoadRelated( md interface{}, name string, args ...interface{}) (int64, error) + LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) + // create a models to models queryer // for example: // post := Post{Id: 4} // m2m := Ormer.QueryM2M(&post, "Tags") - QueryM2M(md interface{}, name string) QueryM2Mer + QueryM2M( md interface{}, name string) QueryM2Mer + QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer + // return a QuerySeter for table operations. // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), QueryTable(ptrStructOrTableName interface{}) QuerySeter + QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter + // switch to another registered database driver by given name. - Using(name string) error - // begin transaction - // for example: - // o := NewOrm() - // err := o.Begin() - // ... - // err = o.Rollback() - Begin() error - // begin transaction with provided context and option - // the provided context is used until the transaction is committed or rolled back. - // if the context is canceled, the transaction will be rolled back. - // the provided TxOptions is optional and may be nil if defaults should be used. - // if a non-default isolation level is used that the driver doesn't support, an error will be returned. - // for example: - // o := NewOrm() - // err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - // ... - // err = o.Rollback() - BeginTx(ctx context.Context, opts *sql.TxOptions) error - // commit transaction - Commit() error - // rollback transaction - Rollback() error - // return a raw query seter for raw sql string. - // for example: - // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() - // // update user testing's name to slene - Raw(query string, args ...interface{}) RawSeter - Driver() Driver + // Using(name string) error + DBStats() *sql.DBStats } +type DriverGetter interface { + Driver() Driver +} + +type Ormer interface { + DQL + DML + DriverGetter + TxBeginner +} + +type TxOrmer interface { + DQL + DML + DriverGetter + TxCommitter +} + // Inserter insert prepared statement type Inserter interface { Insert(interface{}) (int64, error) @@ -229,7 +267,7 @@ type QuerySeter interface { // }) // user slene's name will change to slene2 Update(values Params) (int64, error) // delete from table - //for example: + // for example: // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() // //delete two user who's name is testing1 or testing2 Delete() (int64, error) @@ -314,8 +352,8 @@ type QueryM2Mer interface { // remove models following the origin model relationship // only delete rows from m2m table // for example: - //tag3 := &Tag{Id:5,Name: "TestTag3"} - //num, err = m2m.Remove(tag3) + // tag3 := &Tag{Id:5,Name: "TestTag3"} + // num, err = m2m.Remove(tag3) Remove(...interface{}) (int64, error) // check model is existed in relationship of origin model Exist(interface{}) bool @@ -337,10 +375,10 @@ type RawPreparer interface { // sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) // rs := Ormer.Raw(sql, 1) type RawSeter interface { - //execute sql and get result + // execute sql and get result Exec() (sql.Result, error) - //query data and map to container - //for example: + // query data and map to container + // for example: // var name string // var id int // rs.QueryRow(&id,&name) // id==2 name=="slene" @@ -396,11 +434,11 @@ type RawSeter interface { type stmtQuerier interface { Close() error Exec(args ...interface{}) (sql.Result, error) - //ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) + // ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) Query(args ...interface{}) (*sql.Rows, error) - //QueryContext(args ...interface{}) (*sql.Rows, error) + // QueryContext(args ...interface{}) (*sql.Rows, error) QueryRow(args ...interface{}) *sql.Row - //QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row + // QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row } // db querier From aefe21b63af934b346960b4a99d560936e2b0b49 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 20 Jul 2020 17:25:27 +0800 Subject: [PATCH 031/207] complete error log --- pkg/orm/orm.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 3db75751..24632270 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -558,10 +558,14 @@ func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task defer func() { if panicked || err != nil { e := _txOrm.Rollback() - logs.Error("rollback transaction failed: %v", e) + if e != nil { + logs.Error("rollback transaction failed: %v,%v", e, panicked) + } } else { e := _txOrm.Commit() - logs.Error("commit transaction failed: %v", e) + if e != nil { + logs.Error("commit transaction failed: %v,%v", e, panicked) + } } }() From 4aad313de7fbf4da9fd74e89d1e722f2702a29b2 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 20 Jul 2020 17:34:58 +0800 Subject: [PATCH 032/207] do not judge tx status in txOrm --- pkg/orm/orm.go | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 24632270..8ef761f4 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -530,7 +530,6 @@ func (o *orm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxO alias: o.alias, db: &TxDB{tx: tx}, }, - isClosed: false, } var taskTxOrm TxOrmer = _txOrm @@ -577,33 +576,15 @@ func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task type txOrm struct { ormBase - isClosed bool - closeMutex sync.Mutex } var _ TxOrmer = new(txOrm) func (t *txOrm) Commit() error { - t.closeMutex.Lock() - defer t.closeMutex.Unlock() - - if t.isClosed { - return ErrTxDone - } - t.isClosed = true - return t.db.(txEnder).Commit() } func (t *txOrm) Rollback() error { - t.closeMutex.Lock() - defer t.closeMutex.Unlock() - - if t.isClosed { - return ErrTxDone - } - t.isClosed = true - return t.db.(txEnder).Rollback() } From b6f7d30f9f6192d6046e6b9db0f6a4fd261a829e Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 20 Jul 2020 19:10:57 +0800 Subject: [PATCH 033/207] fix unit test --- pkg/orm/orm_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index b7b2d9a7..54ecc0fd 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -2099,7 +2099,7 @@ func TestTransactionIsolationLevel(t *testing.T) { throwFail(t, AssertIs(id > 0, true)) // o2 query tag table, no result - num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count() + num, err := to2.QueryTable("tag").Filter("name", "test-transaction").Count() throwFail(t, err) throwFail(t, AssertIs(num, 0)) From 44460bc4570b58cefd7b4b1c65e8a1610ceefcbc Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 20 Jul 2020 15:23:17 +0000 Subject: [PATCH 034/207] Refactor RegisterDatabase --- orm/db_alias.go | 87 ++++++++----------------------------- orm/orm_alias_adapt_test.go | 46 ++++++++++++++++++++ pkg/common/kv.go | 69 +++++++++++++++++++++++++++++ pkg/common/kv_test.go | 40 +++++++++++++++++ pkg/orm/constant.go | 21 +++++++++ pkg/orm/db_alias.go | 64 ++++++++++++++------------- pkg/orm/db_alias_test.go | 44 +++++++++++++++++++ pkg/orm/models_test.go | 7 ++- 8 files changed, 279 insertions(+), 99 deletions(-) create mode 100644 orm/orm_alias_adapt_test.go create mode 100644 pkg/common/kv.go create mode 100644 pkg/common/kv_test.go create mode 100644 pkg/orm/constant.go create mode 100644 pkg/orm/db_alias_test.go diff --git a/orm/db_alias.go b/orm/db_alias.go index bf6c350c..a84070b4 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -12,16 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Deprecated: we will remove this package, please using pkg/orm package orm import ( "context" "database/sql" "fmt" - lru "github.com/hashicorp/golang-lru" "reflect" "sync" "time" + + lru "github.com/hashicorp/golang-lru" + + "github.com/astaxie/beego/pkg/common" + orm2 "github.com/astaxie/beego/pkg/orm" ) // DriverType database driver constant int. @@ -63,7 +68,7 @@ var ( "tidb": DRTiDB, "oracle": DROracle, "oci8": DROracle, // github.com/mattn/go-oci8 - "ora": DROracle, //https://github.com/rana/ora + "ora": DROracle, // https://github.com/rana/ora } dbBasers = map[DriverType]dbBaser{ DRMySQL: newdbBaseMysql(), @@ -119,7 +124,7 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) return d.DB.BeginTx(ctx, opts) } -//su must call release to release *sql.Stmt after using +// su must call release to release *sql.Stmt after using func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.RLock() c, ok := d.stmtDecorators.Get(query) @@ -289,82 +294,26 @@ func detectTZ(al *alias) { } } -func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { - al := new(alias) - al.Name = aliasName - al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), - } - - if dr, ok := drivers[driverName]; ok { - al.DbBaser = dbBasers[dr] - al.Driver = dr - } else { - return nil, fmt.Errorf("driver name `%s` have not registered", driverName) - } - - err := db.Ping() - if err != nil { - return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) - } - - if !dataBaseCache.add(aliasName, al) { - return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) - } - - return al, nil -} - // AddAliasWthDB add a aliasName for the drivename +// Deprecated: please using pkg/orm func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { - _, err := addAliasWthDB(aliasName, driverName, db) - return err + return orm2.AddAliasWthDB(aliasName, driverName, db) } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { - var ( - err error - db *sql.DB - al *alias - ) - - db, err = sql.Open(driverName, dataSource) - if err != nil { - err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) - goto end - } - - al, err = addAliasWthDB(aliasName, driverName, db) - if err != nil { - goto end - } - - al.DataSource = dataSource - - detectTZ(al) - + kvs := make([]common.KV, 0, 2) for i, v := range params { switch i { case 0: - SetMaxIdleConns(al.Name, v) + kvs = append(kvs, common.KV{Key: orm2.MaxIdleConnsKey, Value: v}) case 1: - SetMaxOpenConns(al.Name, v) + kvs = append(kvs, common.KV{Key: orm2.MaxOpenConnsKey, Value: v}) + case 2: + kvs = append(kvs, common.KV{Key: orm2.ConnMaxLifetimeKey, Value: time.Duration(v) * time.Millisecond}) } } - -end: - if err != nil { - if db != nil { - db.Close() - } - DebugLog.Println(err.Error()) - } - - return err + return orm2.RegisterDataBase(aliasName, driverName, dataSource, kvs...) } // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. @@ -424,7 +373,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } type stmtDecorator struct { - wg sync.WaitGroup + wg sync.WaitGroup stmt *sql.Stmt } @@ -444,7 +393,7 @@ func (s *stmtDecorator) release() { s.wg.Done() } -//garbage recycle for stmt +// garbage recycle for stmt func (s *stmtDecorator) destroy() { go func() { s.wg.Wait() diff --git a/orm/orm_alias_adapt_test.go b/orm/orm_alias_adapt_test.go new file mode 100644 index 00000000..d7724527 --- /dev/null +++ b/orm/orm_alias_adapt_test.go @@ -0,0 +1,46 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "os" + "testing" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" +) + +var DBARGS = struct { + Driver string + Source string + Debug string +}{ + os.Getenv("ORM_DRIVER"), + os.Getenv("ORM_SOURCE"), + os.Getenv("ORM_DEBUG"), +} + +func TestRegisterDataBase(t *testing.T) { + err := RegisterDataBase("test-adapt1", DBARGS.Driver, DBARGS.Source) + assert.Nil(t, err) + err = RegisterDataBase("test-adapt2", DBARGS.Driver, DBARGS.Source, 20) + assert.Nil(t, err) + err = RegisterDataBase("test-adapt3", DBARGS.Driver, DBARGS.Source, 20, 300) + assert.Nil(t, err) + err = RegisterDataBase("test-adapt4", DBARGS.Driver, DBARGS.Source, 20, 300, 60*1000) + assert.Nil(t, err) +} diff --git a/pkg/common/kv.go b/pkg/common/kv.go new file mode 100644 index 00000000..508e6b5c --- /dev/null +++ b/pkg/common/kv.go @@ -0,0 +1,69 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +// KV is common structure to store key-value data. +// when you need something like Pair, you can use this +type KV struct { + Key interface{} + Value interface{} +} + +// KVs will store KV collection as map +type KVs struct { + kvs map[interface{}]interface{} +} + +// GetValueOr check whether this contains the key, +// if the key not found, the default value will be return +func (kvs *KVs) GetValueOr(key interface{}, defValue interface{}) interface{} { + v, ok := kvs.kvs[key] + if ok { + return v + } + return defValue +} + +// Contains will check whether contains the key +func (kvs *KVs) Contains(key interface{}) bool { + _, ok := kvs.kvs[key] + return ok +} + +// IfContains is a functional API that if the key is in KVs, the action will be invoked +func (kvs *KVs) IfContains(key interface{}, action func(value interface{})) *KVs { + v, ok := kvs.kvs[key] + if ok { + action(v) + } + return kvs +} + +// Put store the value +func (kvs *KVs) Put(key interface{}, value interface{}) *KVs { + kvs.kvs[key] = value + return kvs +} + +// NewKVs will create the *KVs instance +func NewKVs(kvs ...KV) *KVs { + res := &KVs{ + kvs: make(map[interface{}]interface{}, len(kvs)), + } + for _, kv := range kvs { + res.kvs[kv.Key] = kv.Value + } + return res +} diff --git a/pkg/common/kv_test.go b/pkg/common/kv_test.go new file mode 100644 index 00000000..ed7dc7ef --- /dev/null +++ b/pkg/common/kv_test.go @@ -0,0 +1,40 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestKVs(t *testing.T) { + key := "my-key" + kvs := NewKVs(KV{ + Key: key, + Value: 12, + }) + + assert.True(t, kvs.Contains(key)) + + kvs.IfContains(key, func(value interface{}) { + kvs.Put("my-key1", "") + }) + + assert.True(t, kvs.Contains("my-key1")) + + v := kvs.GetValueOr(key, 13) + assert.Equal(t, 12, v) +} diff --git a/pkg/orm/constant.go b/pkg/orm/constant.go new file mode 100644 index 00000000..14f40a7b --- /dev/null +++ b/pkg/orm/constant.go @@ -0,0 +1,21 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +const ( + MaxIdleConnsKey = "MaxIdleConns" + MaxOpenConnsKey = "MaxOpenConns" + ConnMaxLifetimeKey = "ConnMaxLifetime" +) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index b2a72f56..90c5de3c 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -18,10 +18,12 @@ import ( "context" "database/sql" "fmt" - lru "github.com/hashicorp/golang-lru" - "reflect" "sync" "time" + + lru "github.com/hashicorp/golang-lru" + + "github.com/astaxie/beego/pkg/common" ) // DriverType database driver constant int. @@ -63,7 +65,7 @@ var ( "tidb": DRTiDB, "oracle": DROracle, "oci8": DROracle, // github.com/mattn/go-oci8 - "ora": DROracle, //https://github.com/rana/ora + "ora": DROracle, // https://github.com/rana/ora } dbBasers = map[DriverType]dbBaser{ DRMySQL: newdbBaseMysql(), @@ -122,7 +124,7 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) return d.DB.BeginTx(ctx, opts) } -//su must call release to release *sql.Stmt after using +// su must call release to release *sql.Stmt after using func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.RLock() c, ok := d.stmtDecorators.Get(query) @@ -274,16 +276,17 @@ func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interf } type alias struct { - Name string - Driver DriverType - DriverName string - DataSource string - MaxIdleConns int - MaxOpenConns int - DB *DB - DbBaser dbBaser - TZ *time.Location - Engine string + Name string + Driver DriverType + DriverName string + DataSource string + MaxIdleConns int + MaxOpenConns int + ConnMaxLifetime time.Duration + DB *DB + DbBaser dbBaser + TZ *time.Location + Engine string } func detectTZ(al *alias) { @@ -378,13 +381,15 @@ func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { +func RegisterDataBase(aliasName, driverName, dataSource string, params ...common.KV) error { var ( err error db *sql.DB al *alias ) + kvs := common.NewKVs(params...) + db, err = sql.Open(driverName, dataSource) if err != nil { err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) @@ -400,14 +405,13 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) e detectTZ(al) - for i, v := range params { - switch i { - case 0: - SetMaxIdleConns(al.Name, v) - case 1: - SetMaxOpenConns(al.Name, v) - } - } + kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { + SetMaxIdleConns(al.Name, value.(int)) + }).IfContains(MaxOpenConnsKey, func(value interface{}) { + SetMaxOpenConns(al.Name, value.(int)) + }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { + SetConnMaxLifetime(al.Name, value.(time.Duration)) + }) end: if err != nil { @@ -454,10 +458,12 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) { al := getDbAlias(aliasName) al.MaxOpenConns = maxOpenConns al.DB.DB.SetMaxOpenConns(maxOpenConns) - // for tip go 1.2 - if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() { - fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) - } +} + +func SetConnMaxLifetime(aliasName string, lifeTime time.Duration) { + al := getDbAlias(aliasName) + al.ConnMaxLifetime = lifeTime + al.DB.DB.SetConnMaxLifetime(lifeTime) } // GetDB Get *sql.DB from registered database by db alias name. @@ -477,7 +483,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } type stmtDecorator struct { - wg sync.WaitGroup + wg sync.WaitGroup stmt *sql.Stmt } @@ -497,7 +503,7 @@ func (s *stmtDecorator) release() { s.wg.Done() } -//garbage recycle for stmt +// garbage recycle for stmt func (s *stmtDecorator) destroy() { go func() { s.wg.Wait() diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go new file mode 100644 index 00000000..a0cdcd44 --- /dev/null +++ b/pkg/orm/db_alias_test.go @@ -0,0 +1,44 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/common" +) + +func TestRegisterDataBase(t *testing.T) { + err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxIdleConnsKey, + Value: 20, + }, common.KV{ + Key: MaxOpenConnsKey, + Value: 300, + }, common.KV{ + Key: ConnMaxLifetimeKey, + Value: time.Minute, + }) + assert.Nil(t, err) + + al := getDbAlias("test-params") + assert.NotNil(t, al) + assert.Equal(t, al.MaxIdleConns, 20) + assert.Equal(t, al.MaxOpenConns, 300) + assert.Equal(t, al.ConnMaxLifetime, time.Minute) +} diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index 79e926d3..f14ee9cf 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -27,6 +27,8 @@ import ( _ "github.com/mattn/go-sqlite3" // As tidb can't use go get, so disable the tidb testing now // _ "github.com/pingcap/tidb" + + "github.com/astaxie/beego/pkg/common" ) // A slice string field. @@ -487,7 +489,10 @@ func init() { os.Exit(2) } - err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, common.KV{ + Key:MaxIdleConnsKey, + Value:20, + }) if err != nil{ panic(fmt.Sprintf("can not register database: %v", err)) From a66b9950e7e1db8d64140f5fa4a6559b04d3a207 Mon Sep 17 00:00:00 2001 From: IamCathal Date: Mon, 20 Jul 2020 21:21:59 +0100 Subject: [PATCH 035/207] Add Content-length field for logging --- router.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router.go b/router.go index a993a1af..6a8ac6f7 100644 --- a/router.go +++ b/router.go @@ -1046,7 +1046,7 @@ func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { HTTPReferrer: r.Header.Get("Referer"), HTTPUserAgent: r.Header.Get("User-Agent"), RemoteUser: r.Header.Get("Remote-User"), - BodyBytesSent: 0, // @todo this one is missing! + BodyBytesSent: r.ContentLength, } logs.AccessLog(record, BConfig.Log.AccessLogsFormat) } From 9c51952db485cb32a6658df173f622f6930199cd Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 22 Jul 2020 22:50:08 +0800 Subject: [PATCH 036/207] Move package --- orm/cmd_utils.go | 10 +- orm/db_alias.go | 87 +- orm/orm_alias_adapt_test.go | 46 - pkg/admin.go | 458 +++++++ pkg/admin_test.go | 239 ++++ pkg/adminui.go | 356 ++++++ pkg/app.go | 496 ++++++++ pkg/beego.go | 123 ++ pkg/build_info.go | 27 + pkg/cache/README.md | 59 + pkg/cache/cache.go | 103 ++ pkg/cache/cache_test.go | 191 +++ pkg/cache/conv.go | 100 ++ pkg/cache/conv_test.go | 143 +++ pkg/cache/file.go | 258 ++++ pkg/cache/memcache/memcache.go | 188 +++ pkg/cache/memcache/memcache_test.go | 108 ++ pkg/cache/memory.go | 256 ++++ pkg/cache/redis/redis.go | 272 +++++ pkg/cache/redis/redis_test.go | 144 +++ pkg/cache/ssdb/ssdb.go | 231 ++++ pkg/cache/ssdb/ssdb_test.go | 104 ++ pkg/config.go | 524 ++++++++ pkg/config/config.go | 242 ++++ pkg/config/config_test.go | 55 + pkg/config/env/env.go | 87 ++ pkg/config/env/env_test.go | 75 ++ pkg/config/fake.go | 134 +++ pkg/config/ini.go | 504 ++++++++ pkg/config/ini_test.go | 190 +++ pkg/config/json.go | 269 +++++ pkg/config/json_test.go | 222 ++++ pkg/config/xml/xml.go | 228 ++++ pkg/config/xml/xml_test.go | 125 ++ pkg/config/yaml/yaml.go | 316 +++++ pkg/config/yaml/yaml_test.go | 115 ++ pkg/config_test.go | 146 +++ pkg/context/acceptencoder.go | 232 ++++ pkg/context/acceptencoder_test.go | 59 + pkg/context/context.go | 263 +++++ pkg/context/context_test.go | 47 + pkg/context/input.go | 689 +++++++++++ pkg/context/input_test.go | 217 ++++ pkg/context/output.go | 408 +++++++ pkg/context/param/conv.go | 78 ++ pkg/context/param/methodparams.go | 69 ++ pkg/context/param/options.go | 37 + pkg/context/param/parsers.go | 149 +++ pkg/context/param/parsers_test.go | 84 ++ pkg/context/renderer.go | 12 + pkg/context/response.go | 27 + pkg/controller.go | 706 +++++++++++ pkg/controller_test.go | 181 +++ pkg/doc.go | 17 + pkg/error.go | 488 ++++++++ pkg/error_test.go | 88 ++ pkg/filter.go | 44 + pkg/filter_test.go | 68 ++ pkg/flash.go | 110 ++ pkg/flash_test.go | 54 + pkg/fs.go | 74 ++ pkg/grace/grace.go | 166 +++ pkg/grace/server.go | 356 ++++++ pkg/hooks.go | 104 ++ pkg/httplib/README.md | 97 ++ pkg/httplib/httplib.go | 654 ++++++++++ pkg/httplib/httplib_test.go | 286 +++++ pkg/log.go | 127 ++ pkg/logs/README.md | 72 ++ pkg/logs/accesslog.go | 83 ++ pkg/logs/alils/alils.go | 186 +++ pkg/logs/alils/config.go | 13 + pkg/logs/alils/log.pb.go | 1038 ++++++++++++++++ pkg/logs/alils/log_config.go | 42 + pkg/logs/alils/log_project.go | 819 +++++++++++++ pkg/logs/alils/log_store.go | 271 +++++ pkg/logs/alils/machine_group.go | 91 ++ pkg/logs/alils/request.go | 62 + pkg/logs/alils/signature.go | 111 ++ pkg/logs/conn.go | 119 ++ pkg/logs/conn_test.go | 79 ++ pkg/logs/console.go | 99 ++ pkg/logs/console_test.go | 64 + pkg/logs/es/es.go | 102 ++ pkg/logs/file.go | 409 +++++++ pkg/logs/file_test.go | 420 +++++++ pkg/logs/jianliao.go | 72 ++ pkg/logs/log.go | 669 +++++++++++ pkg/logs/logger.go | 176 +++ pkg/logs/logger_test.go | 57 + pkg/logs/multifile.go | 119 ++ pkg/logs/multifile_test.go | 78 ++ pkg/logs/slack.go | 60 + pkg/logs/smtp.go | 149 +++ pkg/logs/smtp_test.go | 27 + pkg/metric/prometheus.go | 99 ++ pkg/metric/prometheus_test.go | 42 + pkg/migration/ddl.go | 395 +++++++ pkg/migration/doc.go | 32 + pkg/migration/migration.go | 330 ++++++ pkg/mime.go | 556 +++++++++ pkg/namespace.go | 396 +++++++ pkg/namespace_test.go | 168 +++ pkg/parser.go | 591 +++++++++ pkg/plugins/apiauth/apiauth.go | 165 +++ pkg/plugins/apiauth/apiauth_test.go | 20 + pkg/plugins/auth/basic.go | 107 ++ pkg/plugins/authz/authz.go | 86 ++ pkg/plugins/authz/authz_model.conf | 14 + pkg/plugins/authz/authz_policy.csv | 7 + pkg/plugins/authz/authz_test.go | 107 ++ pkg/plugins/cors/cors.go | 228 ++++ pkg/plugins/cors/cors_test.go | 253 ++++ pkg/policy.go | 97 ++ pkg/router.go | 1052 +++++++++++++++++ pkg/router_test.go | 732 ++++++++++++ pkg/session/README.md | 114 ++ pkg/session/couchbase/sess_couchbase.go | 247 ++++ pkg/session/ledis/ledis_session.go | 173 +++ pkg/session/memcache/sess_memcache.go | 230 ++++ pkg/session/mysql/sess_mysql.go | 228 ++++ pkg/session/postgres/sess_postgresql.go | 243 ++++ pkg/session/redis/sess_redis.go | 261 ++++ pkg/session/redis_cluster/redis_cluster.go | 220 ++++ .../redis_sentinel/sess_redis_sentinel.go | 234 ++++ .../sess_redis_sentinel_test.go | 90 ++ pkg/session/sess_cookie.go | 180 +++ pkg/session/sess_cookie_test.go | 105 ++ pkg/session/sess_file.go | 315 +++++ pkg/session/sess_file_test.go | 387 ++++++ pkg/session/sess_mem.go | 196 +++ pkg/session/sess_mem_test.go | 58 + pkg/session/sess_test.go | 131 ++ pkg/session/sess_utils.go | 207 ++++ pkg/session/session.go | 377 ++++++ pkg/session/ssdb/sess_ssdb.go | 199 ++++ pkg/staticfile.go | 234 ++++ pkg/staticfile_test.go | 99 ++ pkg/swagger/swagger.go | 174 +++ pkg/template.go | 406 +++++++ pkg/template_test.go | 316 +++++ pkg/templatefunc.go | 780 ++++++++++++ pkg/templatefunc_test.go | 380 ++++++ pkg/testdata/Makefile | 2 + pkg/testdata/bindata.go | 296 +++++ pkg/testdata/views/blocks/block.tpl | 3 + pkg/testdata/views/header.tpl | 3 + pkg/testdata/views/index.tpl | 15 + pkg/testing/assertions.go | 15 + pkg/testing/client.go | 65 + pkg/toolbox/healthcheck.go | 48 + pkg/toolbox/profile.go | 184 +++ pkg/toolbox/profile_test.go | 28 + pkg/toolbox/statistics.go | 149 +++ pkg/toolbox/statistics_test.go | 40 + pkg/toolbox/task.go | 640 ++++++++++ pkg/toolbox/task_test.go | 85 ++ pkg/tree.go | 585 +++++++++ pkg/tree_test.go | 306 +++++ pkg/unregroute_test.go | 226 ++++ pkg/utils/caller.go | 25 + pkg/utils/caller_test.go | 28 + pkg/utils/captcha/LICENSE | 19 + pkg/utils/captcha/README.md | 45 + pkg/utils/captcha/captcha.go | 270 +++++ pkg/utils/captcha/image.go | 501 ++++++++ pkg/utils/captcha/image_test.go | 52 + pkg/utils/captcha/siprng.go | 277 +++++ pkg/utils/captcha/siprng_test.go | 33 + pkg/utils/debug.go | 478 ++++++++ pkg/utils/debug_test.go | 46 + pkg/utils/file.go | 101 ++ pkg/utils/file_test.go | 75 ++ pkg/utils/mail.go | 424 +++++++ pkg/utils/mail_test.go | 41 + pkg/utils/pagination/controller.go | 26 + pkg/utils/pagination/doc.go | 58 + pkg/utils/pagination/paginator.go | 189 +++ pkg/utils/pagination/utils.go | 34 + pkg/utils/rand.go | 44 + pkg/utils/rand_test.go | 33 + pkg/utils/safemap.go | 91 ++ pkg/utils/safemap_test.go | 89 ++ pkg/utils/slice.go | 170 +++ pkg/utils/slice_test.go | 29 + pkg/utils/testdata/grepe.test | 7 + pkg/utils/utils.go | 89 ++ pkg/utils/utils_test.go | 36 + pkg/validation/README.md | 147 +++ pkg/validation/util.go | 298 +++++ pkg/validation/util_test.go | 128 ++ pkg/validation/validation.go | 456 +++++++ pkg/validation/validation_test.go | 609 ++++++++++ pkg/validation/validators.go | 738 ++++++++++++ 194 files changed, 39077 insertions(+), 69 deletions(-) delete mode 100644 orm/orm_alias_adapt_test.go create mode 100644 pkg/admin.go create mode 100644 pkg/admin_test.go create mode 100644 pkg/adminui.go create mode 100644 pkg/app.go create mode 100644 pkg/beego.go create mode 100644 pkg/build_info.go create mode 100644 pkg/cache/README.md create mode 100644 pkg/cache/cache.go create mode 100644 pkg/cache/cache_test.go create mode 100644 pkg/cache/conv.go create mode 100644 pkg/cache/conv_test.go create mode 100644 pkg/cache/file.go create mode 100644 pkg/cache/memcache/memcache.go create mode 100644 pkg/cache/memcache/memcache_test.go create mode 100644 pkg/cache/memory.go create mode 100644 pkg/cache/redis/redis.go create mode 100644 pkg/cache/redis/redis_test.go create mode 100644 pkg/cache/ssdb/ssdb.go create mode 100644 pkg/cache/ssdb/ssdb_test.go create mode 100644 pkg/config.go create mode 100644 pkg/config/config.go create mode 100644 pkg/config/config_test.go create mode 100644 pkg/config/env/env.go create mode 100644 pkg/config/env/env_test.go create mode 100644 pkg/config/fake.go create mode 100644 pkg/config/ini.go create mode 100644 pkg/config/ini_test.go create mode 100644 pkg/config/json.go create mode 100644 pkg/config/json_test.go create mode 100644 pkg/config/xml/xml.go create mode 100644 pkg/config/xml/xml_test.go create mode 100644 pkg/config/yaml/yaml.go create mode 100644 pkg/config/yaml/yaml_test.go create mode 100644 pkg/config_test.go create mode 100644 pkg/context/acceptencoder.go create mode 100644 pkg/context/acceptencoder_test.go create mode 100644 pkg/context/context.go create mode 100644 pkg/context/context_test.go create mode 100644 pkg/context/input.go create mode 100644 pkg/context/input_test.go create mode 100644 pkg/context/output.go create mode 100644 pkg/context/param/conv.go create mode 100644 pkg/context/param/methodparams.go create mode 100644 pkg/context/param/options.go create mode 100644 pkg/context/param/parsers.go create mode 100644 pkg/context/param/parsers_test.go create mode 100644 pkg/context/renderer.go create mode 100644 pkg/context/response.go create mode 100644 pkg/controller.go create mode 100644 pkg/controller_test.go create mode 100644 pkg/doc.go create mode 100644 pkg/error.go create mode 100644 pkg/error_test.go create mode 100644 pkg/filter.go create mode 100644 pkg/filter_test.go create mode 100644 pkg/flash.go create mode 100644 pkg/flash_test.go create mode 100644 pkg/fs.go create mode 100644 pkg/grace/grace.go create mode 100644 pkg/grace/server.go create mode 100644 pkg/hooks.go create mode 100644 pkg/httplib/README.md create mode 100644 pkg/httplib/httplib.go create mode 100644 pkg/httplib/httplib_test.go create mode 100644 pkg/log.go create mode 100644 pkg/logs/README.md create mode 100644 pkg/logs/accesslog.go create mode 100644 pkg/logs/alils/alils.go create mode 100755 pkg/logs/alils/config.go create mode 100755 pkg/logs/alils/log.pb.go create mode 100755 pkg/logs/alils/log_config.go create mode 100755 pkg/logs/alils/log_project.go create mode 100755 pkg/logs/alils/log_store.go create mode 100755 pkg/logs/alils/machine_group.go create mode 100755 pkg/logs/alils/request.go create mode 100755 pkg/logs/alils/signature.go create mode 100644 pkg/logs/conn.go create mode 100644 pkg/logs/conn_test.go create mode 100644 pkg/logs/console.go create mode 100644 pkg/logs/console_test.go create mode 100644 pkg/logs/es/es.go create mode 100644 pkg/logs/file.go create mode 100644 pkg/logs/file_test.go create mode 100644 pkg/logs/jianliao.go create mode 100644 pkg/logs/log.go create mode 100644 pkg/logs/logger.go create mode 100644 pkg/logs/logger_test.go create mode 100644 pkg/logs/multifile.go create mode 100644 pkg/logs/multifile_test.go create mode 100644 pkg/logs/slack.go create mode 100644 pkg/logs/smtp.go create mode 100644 pkg/logs/smtp_test.go create mode 100644 pkg/metric/prometheus.go create mode 100644 pkg/metric/prometheus_test.go create mode 100644 pkg/migration/ddl.go create mode 100644 pkg/migration/doc.go create mode 100644 pkg/migration/migration.go create mode 100644 pkg/mime.go create mode 100644 pkg/namespace.go create mode 100644 pkg/namespace_test.go create mode 100644 pkg/parser.go create mode 100644 pkg/plugins/apiauth/apiauth.go create mode 100644 pkg/plugins/apiauth/apiauth_test.go create mode 100644 pkg/plugins/auth/basic.go create mode 100644 pkg/plugins/authz/authz.go create mode 100644 pkg/plugins/authz/authz_model.conf create mode 100644 pkg/plugins/authz/authz_policy.csv create mode 100644 pkg/plugins/authz/authz_test.go create mode 100644 pkg/plugins/cors/cors.go create mode 100644 pkg/plugins/cors/cors_test.go create mode 100644 pkg/policy.go create mode 100644 pkg/router.go create mode 100644 pkg/router_test.go create mode 100644 pkg/session/README.md create mode 100644 pkg/session/couchbase/sess_couchbase.go create mode 100644 pkg/session/ledis/ledis_session.go create mode 100644 pkg/session/memcache/sess_memcache.go create mode 100644 pkg/session/mysql/sess_mysql.go create mode 100644 pkg/session/postgres/sess_postgresql.go create mode 100644 pkg/session/redis/sess_redis.go create mode 100644 pkg/session/redis_cluster/redis_cluster.go create mode 100644 pkg/session/redis_sentinel/sess_redis_sentinel.go create mode 100644 pkg/session/redis_sentinel/sess_redis_sentinel_test.go create mode 100644 pkg/session/sess_cookie.go create mode 100644 pkg/session/sess_cookie_test.go create mode 100644 pkg/session/sess_file.go create mode 100644 pkg/session/sess_file_test.go create mode 100644 pkg/session/sess_mem.go create mode 100644 pkg/session/sess_mem_test.go create mode 100644 pkg/session/sess_test.go create mode 100644 pkg/session/sess_utils.go create mode 100644 pkg/session/session.go create mode 100644 pkg/session/ssdb/sess_ssdb.go create mode 100644 pkg/staticfile.go create mode 100644 pkg/staticfile_test.go create mode 100644 pkg/swagger/swagger.go create mode 100644 pkg/template.go create mode 100644 pkg/template_test.go create mode 100644 pkg/templatefunc.go create mode 100644 pkg/templatefunc_test.go create mode 100644 pkg/testdata/Makefile create mode 100644 pkg/testdata/bindata.go create mode 100644 pkg/testdata/views/blocks/block.tpl create mode 100644 pkg/testdata/views/header.tpl create mode 100644 pkg/testdata/views/index.tpl create mode 100644 pkg/testing/assertions.go create mode 100644 pkg/testing/client.go create mode 100644 pkg/toolbox/healthcheck.go create mode 100644 pkg/toolbox/profile.go create mode 100644 pkg/toolbox/profile_test.go create mode 100644 pkg/toolbox/statistics.go create mode 100644 pkg/toolbox/statistics_test.go create mode 100644 pkg/toolbox/task.go create mode 100644 pkg/toolbox/task_test.go create mode 100644 pkg/tree.go create mode 100644 pkg/tree_test.go create mode 100644 pkg/unregroute_test.go create mode 100644 pkg/utils/caller.go create mode 100644 pkg/utils/caller_test.go create mode 100644 pkg/utils/captcha/LICENSE create mode 100644 pkg/utils/captcha/README.md create mode 100644 pkg/utils/captcha/captcha.go create mode 100644 pkg/utils/captcha/image.go create mode 100644 pkg/utils/captcha/image_test.go create mode 100644 pkg/utils/captcha/siprng.go create mode 100644 pkg/utils/captcha/siprng_test.go create mode 100644 pkg/utils/debug.go create mode 100644 pkg/utils/debug_test.go create mode 100644 pkg/utils/file.go create mode 100644 pkg/utils/file_test.go create mode 100644 pkg/utils/mail.go create mode 100644 pkg/utils/mail_test.go create mode 100644 pkg/utils/pagination/controller.go create mode 100644 pkg/utils/pagination/doc.go create mode 100644 pkg/utils/pagination/paginator.go create mode 100644 pkg/utils/pagination/utils.go create mode 100644 pkg/utils/rand.go create mode 100644 pkg/utils/rand_test.go create mode 100644 pkg/utils/safemap.go create mode 100644 pkg/utils/safemap_test.go create mode 100644 pkg/utils/slice.go create mode 100644 pkg/utils/slice_test.go create mode 100644 pkg/utils/testdata/grepe.test create mode 100644 pkg/utils/utils.go create mode 100644 pkg/utils/utils_test.go create mode 100644 pkg/validation/README.md create mode 100644 pkg/validation/util.go create mode 100644 pkg/validation/util_test.go create mode 100644 pkg/validation/validation.go create mode 100644 pkg/validation/validation_test.go create mode 100644 pkg/validation/validators.go diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index eac85091..61f17346 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -178,9 +178,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex column += " " + "NOT NULL" } - // if fi.initial.String() != "" { + //if fi.initial.String() != "" { // column += " DEFAULT " + fi.initial.String() - // } + //} // Append attribute DEFAULT column += getColumnDefault(fi) @@ -197,9 +197,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex if strings.Contains(column, "%COL%") { column = strings.Replace(column, "%COL%", fi.column, -1) } - - if fi.description != "" && al.Driver != DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) + + if fi.description != "" && al.Driver!=DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) } columns = append(columns, column) diff --git a/orm/db_alias.go b/orm/db_alias.go index a84070b4..bf6c350c 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -12,21 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Deprecated: we will remove this package, please using pkg/orm package orm import ( "context" "database/sql" "fmt" + lru "github.com/hashicorp/golang-lru" "reflect" "sync" "time" - - lru "github.com/hashicorp/golang-lru" - - "github.com/astaxie/beego/pkg/common" - orm2 "github.com/astaxie/beego/pkg/orm" ) // DriverType database driver constant int. @@ -68,7 +63,7 @@ var ( "tidb": DRTiDB, "oracle": DROracle, "oci8": DROracle, // github.com/mattn/go-oci8 - "ora": DROracle, // https://github.com/rana/ora + "ora": DROracle, //https://github.com/rana/ora } dbBasers = map[DriverType]dbBaser{ DRMySQL: newdbBaseMysql(), @@ -124,7 +119,7 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) return d.DB.BeginTx(ctx, opts) } -// su must call release to release *sql.Stmt after using +//su must call release to release *sql.Stmt after using func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.RLock() c, ok := d.stmtDecorators.Get(query) @@ -294,26 +289,82 @@ func detectTZ(al *alias) { } } +func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { + al := new(alias) + al.Name = aliasName + al.DriverName = driverName + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: newStmtDecoratorLruWithEvict(), + } + + if dr, ok := drivers[driverName]; ok { + al.DbBaser = dbBasers[dr] + al.Driver = dr + } else { + return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + } + + err := db.Ping() + if err != nil { + return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) + } + + if !dataBaseCache.add(aliasName, al) { + return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) + } + + return al, nil +} + // AddAliasWthDB add a aliasName for the drivename -// Deprecated: please using pkg/orm func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { - return orm2.AddAliasWthDB(aliasName, driverName, db) + _, err := addAliasWthDB(aliasName, driverName, db) + return err } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { - kvs := make([]common.KV, 0, 2) + var ( + err error + db *sql.DB + al *alias + ) + + db, err = sql.Open(driverName, dataSource) + if err != nil { + err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) + goto end + } + + al, err = addAliasWthDB(aliasName, driverName, db) + if err != nil { + goto end + } + + al.DataSource = dataSource + + detectTZ(al) + for i, v := range params { switch i { case 0: - kvs = append(kvs, common.KV{Key: orm2.MaxIdleConnsKey, Value: v}) + SetMaxIdleConns(al.Name, v) case 1: - kvs = append(kvs, common.KV{Key: orm2.MaxOpenConnsKey, Value: v}) - case 2: - kvs = append(kvs, common.KV{Key: orm2.ConnMaxLifetimeKey, Value: time.Duration(v) * time.Millisecond}) + SetMaxOpenConns(al.Name, v) } } - return orm2.RegisterDataBase(aliasName, driverName, dataSource, kvs...) + +end: + if err != nil { + if db != nil { + db.Close() + } + DebugLog.Println(err.Error()) + } + + return err } // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. @@ -373,7 +424,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } type stmtDecorator struct { - wg sync.WaitGroup + wg sync.WaitGroup stmt *sql.Stmt } @@ -393,7 +444,7 @@ func (s *stmtDecorator) release() { s.wg.Done() } -// garbage recycle for stmt +//garbage recycle for stmt func (s *stmtDecorator) destroy() { go func() { s.wg.Wait() diff --git a/orm/orm_alias_adapt_test.go b/orm/orm_alias_adapt_test.go deleted file mode 100644 index d7724527..00000000 --- a/orm/orm_alias_adapt_test.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2020 beego-dev -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "os" - "testing" - - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - "github.com/stretchr/testify/assert" -) - -var DBARGS = struct { - Driver string - Source string - Debug string -}{ - os.Getenv("ORM_DRIVER"), - os.Getenv("ORM_SOURCE"), - os.Getenv("ORM_DEBUG"), -} - -func TestRegisterDataBase(t *testing.T) { - err := RegisterDataBase("test-adapt1", DBARGS.Driver, DBARGS.Source) - assert.Nil(t, err) - err = RegisterDataBase("test-adapt2", DBARGS.Driver, DBARGS.Source, 20) - assert.Nil(t, err) - err = RegisterDataBase("test-adapt3", DBARGS.Driver, DBARGS.Source, 20, 300) - assert.Nil(t, err) - err = RegisterDataBase("test-adapt4", DBARGS.Driver, DBARGS.Source, 20, 300, 60*1000) - assert.Nil(t, err) -} diff --git a/pkg/admin.go b/pkg/admin.go new file mode 100644 index 00000000..db52647e --- /dev/null +++ b/pkg/admin.go @@ -0,0 +1,458 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "os" + "reflect" + "strconv" + "text/template" + "time" + + "github.com/prometheus/client_golang/prometheus/promhttp" + + "github.com/astaxie/beego/grace" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/toolbox" + "github.com/astaxie/beego/utils" +) + +// BeeAdminApp is the default adminApp used by admin module. +var beeAdminApp *adminApp + +// FilterMonitorFunc is default monitor filter when admin module is enable. +// if this func returns, admin module records qps for this request by condition of this function logic. +// usage: +// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool { +// if method == "POST" { +// return false +// } +// if t.Nanoseconds() < 100 { +// return false +// } +// if strings.HasPrefix(requestPath, "/astaxie") { +// return false +// } +// return true +// } +// beego.FilterMonitorFunc = MyFilterMonitor. +var FilterMonitorFunc func(string, string, time.Duration, string, int) bool + +func init() { + beeAdminApp = &adminApp{ + routers: make(map[string]http.HandlerFunc), + } + // keep in mind that all data should be html escaped to avoid XSS attack + beeAdminApp.Route("/", adminIndex) + beeAdminApp.Route("/qps", qpsIndex) + beeAdminApp.Route("/prof", profIndex) + beeAdminApp.Route("/healthcheck", healthcheck) + beeAdminApp.Route("/task", taskStatus) + beeAdminApp.Route("/listconf", listConf) + beeAdminApp.Route("/metrics", promhttp.Handler().ServeHTTP) + FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true } +} + +// AdminIndex is the default http.Handler for admin module. +// it matches url pattern "/". +func adminIndex(rw http.ResponseWriter, _ *http.Request) { + writeTemplate(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) +} + +// QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter. +// it's registered with url pattern "/qps" in admin module. +func qpsIndex(rw http.ResponseWriter, _ *http.Request) { + data := make(map[interface{}]interface{}) + data["Content"] = toolbox.StatisticsMap.GetMap() + + // do html escape before display path, avoid xss + if content, ok := (data["Content"]).(M); ok { + if resultLists, ok := (content["Data"]).([][]string); ok { + for i := range resultLists { + if len(resultLists[i]) > 0 { + resultLists[i][0] = template.HTMLEscapeString(resultLists[i][0]) + } + } + } + } + + writeTemplate(rw, data, qpsTpl, defaultScriptsTpl) +} + +// ListConf is the http.Handler of displaying all beego configuration values as key/value pair. +// it's registered with url pattern "/listconf" in admin module. +func listConf(rw http.ResponseWriter, r *http.Request) { + r.ParseForm() + command := r.Form.Get("command") + if command == "" { + rw.Write([]byte("command not support")) + return + } + + data := make(map[interface{}]interface{}) + switch command { + case "conf": + m := make(M) + list("BConfig", BConfig, m) + m["AppConfigPath"] = template.HTMLEscapeString(appConfigPath) + m["AppConfigProvider"] = template.HTMLEscapeString(appConfigProvider) + tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) + tmpl = template.Must(tmpl.Parse(configTpl)) + tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) + + data["Content"] = m + + tmpl.Execute(rw, data) + + case "router": + content := PrintTree() + content["Fields"] = []string{ + "Router Pattern", + "Methods", + "Controller", + } + data["Content"] = content + data["Title"] = "Routers" + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) + case "filter": + var ( + content = M{ + "Fields": []string{ + "Router Pattern", + "Filter Function", + }, + } + filterTypes = []string{} + filterTypeData = make(M) + ) + + if BeeApp.Handlers.enableFilter { + var filterType string + for k, fr := range map[int]string{ + BeforeStatic: "Before Static", + BeforeRouter: "Before Router", + BeforeExec: "Before Exec", + AfterExec: "After Exec", + FinishRouter: "Finish Router"} { + if bf := BeeApp.Handlers.filters[k]; len(bf) > 0 { + filterType = fr + filterTypes = append(filterTypes, filterType) + resultList := new([][]string) + for _, f := range bf { + var result = []string{ + // void xss + template.HTMLEscapeString(f.pattern), + template.HTMLEscapeString(utils.GetFuncName(f.filterFunc)), + } + *resultList = append(*resultList, result) + } + filterTypeData[filterType] = resultList + } + } + } + + content["Data"] = filterTypeData + content["Methods"] = filterTypes + + data["Content"] = content + data["Title"] = "Filters" + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) + default: + rw.Write([]byte("command not support")) + } +} + +func list(root string, p interface{}, m M) { + pt := reflect.TypeOf(p) + pv := reflect.ValueOf(p) + if pt.Kind() == reflect.Ptr { + pt = pt.Elem() + pv = pv.Elem() + } + for i := 0; i < pv.NumField(); i++ { + var key string + if root == "" { + key = pt.Field(i).Name + } else { + key = root + "." + pt.Field(i).Name + } + if pv.Field(i).Kind() == reflect.Struct { + list(key, pv.Field(i).Interface(), m) + } else { + m[key] = pv.Field(i).Interface() + } + } +} + +// PrintTree prints all registered routers. +func PrintTree() M { + var ( + content = M{} + methods = []string{} + methodsData = make(M) + ) + for method, t := range BeeApp.Handlers.routers { + + resultList := new([][]string) + + printTree(resultList, t) + + methods = append(methods, template.HTMLEscapeString(method)) + methodsData[template.HTMLEscapeString(method)] = resultList + } + + content["Data"] = methodsData + content["Methods"] = methods + return content +} + +func printTree(resultList *[][]string, t *Tree) { + for _, tr := range t.fixrouters { + printTree(resultList, tr) + } + if t.wildcard != nil { + printTree(resultList, t.wildcard) + } + for _, l := range t.leaves { + if v, ok := l.runObject.(*ControllerInfo); ok { + if v.routerType == routerTypeBeego { + var result = []string{ + template.HTMLEscapeString(v.pattern), + template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), + template.HTMLEscapeString(v.controllerType.String()), + } + *resultList = append(*resultList, result) + } else if v.routerType == routerTypeRESTFul { + var result = []string{ + template.HTMLEscapeString(v.pattern), + template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), + "", + } + *resultList = append(*resultList, result) + } else if v.routerType == routerTypeHandler { + var result = []string{ + template.HTMLEscapeString(v.pattern), + "", + "", + } + *resultList = append(*resultList, result) + } + } + } +} + +// ProfIndex is a http.Handler for showing profile command. +// it's in url pattern "/prof" in admin module. +func profIndex(rw http.ResponseWriter, r *http.Request) { + r.ParseForm() + command := r.Form.Get("command") + if command == "" { + return + } + + var ( + format = r.Form.Get("format") + data = make(map[interface{}]interface{}) + result bytes.Buffer + ) + toolbox.ProcessInput(command, &result) + data["Content"] = template.HTMLEscapeString(result.String()) + + if format == "json" && command == "gc summary" { + dataJSON, err := json.Marshal(data) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(rw, dataJSON) + return + } + + data["Title"] = template.HTMLEscapeString(command) + defaultTpl := defaultScriptsTpl + if command == "gc summary" { + defaultTpl = gcAjaxTpl + } + writeTemplate(rw, data, profillingTpl, defaultTpl) +} + +// Healthcheck is a http.Handler calling health checking and showing the result. +// it's in "/healthcheck" pattern in admin module. +func healthcheck(rw http.ResponseWriter, r *http.Request) { + var ( + result []string + data = make(map[interface{}]interface{}) + resultList = new([][]string) + content = M{ + "Fields": []string{"Name", "Message", "Status"}, + } + ) + + for name, h := range toolbox.AdminCheckList { + if err := h.Check(); err != nil { + result = []string{ + "error", + template.HTMLEscapeString(name), + template.HTMLEscapeString(err.Error()), + } + } else { + result = []string{ + "success", + template.HTMLEscapeString(name), + "OK", + } + } + *resultList = append(*resultList, result) + } + + queryParams := r.URL.Query() + jsonFlag := queryParams.Get("json") + shouldReturnJSON, _ := strconv.ParseBool(jsonFlag) + + if shouldReturnJSON { + response := buildHealthCheckResponseList(resultList) + jsonResponse, err := json.Marshal(response) + + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } else { + writeJSON(rw, jsonResponse) + } + return + } + + content["Data"] = resultList + data["Content"] = content + data["Title"] = "Health Check" + + writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) +} + +func buildHealthCheckResponseList(healthCheckResults *[][]string) []map[string]interface{} { + response := make([]map[string]interface{}, len(*healthCheckResults)) + + for i, healthCheckResult := range *healthCheckResults { + currentResultMap := make(map[string]interface{}) + + currentResultMap["name"] = healthCheckResult[0] + currentResultMap["message"] = healthCheckResult[1] + currentResultMap["status"] = healthCheckResult[2] + + response[i] = currentResultMap + } + + return response + +} + +func writeJSON(rw http.ResponseWriter, jsonData []byte) { + rw.Header().Set("Content-Type", "application/json") + rw.Write(jsonData) +} + +// TaskStatus is a http.Handler with running task status (task name, status and the last execution). +// it's in "/task" pattern in admin module. +func taskStatus(rw http.ResponseWriter, req *http.Request) { + data := make(map[interface{}]interface{}) + + // Run Task + req.ParseForm() + taskname := req.Form.Get("taskname") + if taskname != "" { + if t, ok := toolbox.AdminTaskList[taskname]; ok { + if err := t.Run(); err != nil { + data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", err))} + } + data["Message"] = []string{"success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus()))} + } else { + data["Message"] = []string{"warning", template.HTMLEscapeString(fmt.Sprintf("there's no task which named: %s", taskname))} + } + } + + // List Tasks + content := make(M) + resultList := new([][]string) + var fields = []string{ + "Task Name", + "Task Spec", + "Task Status", + "Last Time", + "", + } + for tname, tk := range toolbox.AdminTaskList { + result := []string{ + template.HTMLEscapeString(tname), + template.HTMLEscapeString(tk.GetSpec()), + template.HTMLEscapeString(tk.GetStatus()), + template.HTMLEscapeString(tk.GetPrev().String()), + } + *resultList = append(*resultList, result) + } + + content["Fields"] = fields + content["Data"] = resultList + data["Content"] = content + data["Title"] = "Tasks" + writeTemplate(rw, data, tasksTpl, defaultScriptsTpl) +} + +func writeTemplate(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { + tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) + for _, tpl := range tpls { + tmpl = template.Must(tmpl.Parse(tpl)) + } + tmpl.Execute(rw, data) +} + +// adminApp is an http.HandlerFunc map used as beeAdminApp. +type adminApp struct { + routers map[string]http.HandlerFunc +} + +// Route adds http.HandlerFunc to adminApp with url pattern. +func (admin *adminApp) Route(pattern string, f http.HandlerFunc) { + admin.routers[pattern] = f +} + +// Run adminApp http server. +// Its addr is defined in configuration file as adminhttpaddr and adminhttpport. +func (admin *adminApp) Run() { + if len(toolbox.AdminTaskList) > 0 { + toolbox.StartTask() + } + addr := BConfig.Listen.AdminAddr + + if BConfig.Listen.AdminPort != 0 { + addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort) + } + for p, f := range admin.routers { + http.Handle(p, f) + } + logs.Info("Admin server Running on %s", addr) + + var err error + if BConfig.Listen.Graceful { + err = grace.ListenAndServe(addr, nil) + } else { + err = http.ListenAndServe(addr, nil) + } + if err != nil { + logs.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) + } +} diff --git a/pkg/admin_test.go b/pkg/admin_test.go new file mode 100644 index 00000000..3f3612e4 --- /dev/null +++ b/pkg/admin_test.go @@ -0,0 +1,239 @@ +package beego + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + + "github.com/astaxie/beego/toolbox" +) + +type SampleDatabaseCheck struct { +} + +type SampleCacheCheck struct { +} + +func (dc *SampleDatabaseCheck) Check() error { + return nil +} + +func (cc *SampleCacheCheck) Check() error { + return errors.New("no cache detected") +} + +func TestList_01(t *testing.T) { + m := make(M) + list("BConfig", BConfig, m) + t.Log(m) + om := oldMap() + for k, v := range om { + if fmt.Sprint(m[k]) != fmt.Sprint(v) { + t.Log(k, "old-key", v, "new-key", m[k]) + t.FailNow() + } + } +} + +func oldMap() M { + m := make(M) + m["BConfig.AppName"] = BConfig.AppName + m["BConfig.RunMode"] = BConfig.RunMode + m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive + m["BConfig.ServerName"] = BConfig.ServerName + m["BConfig.RecoverPanic"] = BConfig.RecoverPanic + m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody + m["BConfig.EnableGzip"] = BConfig.EnableGzip + m["BConfig.MaxMemory"] = BConfig.MaxMemory + m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow + m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful + m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut + m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4 + m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP + m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr + m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort + m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS + m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr + m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort + m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile + m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile + m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin + m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr + m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort + m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi + m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo + m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender + m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs + m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName + m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator + m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex + m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir + m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip + m["BConfig.WebConfig.StaticCacheFileSize"] = BConfig.WebConfig.StaticCacheFileSize + m["BConfig.WebConfig.StaticCacheFileNum"] = BConfig.WebConfig.StaticCacheFileNum + m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft + m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight + m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath + m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF + m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire + m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn + m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider + m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName + m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime + m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig + m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime + m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie + m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain + m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly + m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs + m["BConfig.Log.EnableStaticLogs"] = BConfig.Log.EnableStaticLogs + m["BConfig.Log.AccessLogsFormat"] = BConfig.Log.AccessLogsFormat + m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum + m["BConfig.Log.Outputs"] = BConfig.Log.Outputs + return m +} + +func TestWriteJSON(t *testing.T) { + t.Log("Testing the adding of JSON to the response") + + w := httptest.NewRecorder() + originalBody := []int{1, 2, 3} + + res, _ := json.Marshal(originalBody) + + writeJSON(w, res) + + decodedBody := []int{} + err := json.NewDecoder(w.Body).Decode(&decodedBody) + + if err != nil { + t.Fatal("Could not decode response body into slice.") + } + + for i := range decodedBody { + if decodedBody[i] != originalBody[i] { + t.Fatalf("Expected %d but got %d in decoded body slice", originalBody[i], decodedBody[i]) + } + } +} + +func TestHealthCheckHandlerDefault(t *testing.T) { + endpointPath := "/healthcheck" + + toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) + toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) + + req, err := http.NewRequest("GET", endpointPath, nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + handler := http.HandlerFunc(healthcheck) + + handler.ServeHTTP(w, req) + + if status := w.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + if !strings.Contains(w.Body.String(), "database") { + t.Errorf("Expected 'database' in generated template.") + } + +} + +func TestBuildHealthCheckResponseList(t *testing.T) { + healthCheckResults := [][]string{ + []string{ + "error", + "Database", + "Error occured whie starting the db", + }, + []string{ + "success", + "Cache", + "Cache started successfully", + }, + } + + responseList := buildHealthCheckResponseList(&healthCheckResults) + + if len(responseList) != len(healthCheckResults) { + t.Errorf("invalid response map length: got %d want %d", + len(responseList), len(healthCheckResults)) + } + + responseFields := []string{"name", "message", "status"} + + for _, response := range responseList { + for _, field := range responseFields { + _, ok := response[field] + if !ok { + t.Errorf("expected %s to be in the response %v", field, response) + } + } + + } + +} + +func TestHealthCheckHandlerReturnsJSON(t *testing.T) { + + toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) + toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) + + req, err := http.NewRequest("GET", "/healthcheck?json=true", nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + handler := http.HandlerFunc(healthcheck) + + handler.ServeHTTP(w, req) + if status := w.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + decodedResponseBody := []map[string]interface{}{} + expectedResponseBody := []map[string]interface{}{} + + expectedJSONString := []byte(` + [ + { + "message":"database", + "name":"success", + "status":"OK" + }, + { + "message":"cache", + "name":"error", + "status":"no cache detected" + } + ] + `) + + json.Unmarshal(expectedJSONString, &expectedResponseBody) + + json.Unmarshal(w.Body.Bytes(), &decodedResponseBody) + + if len(expectedResponseBody) != len(decodedResponseBody) { + t.Errorf("invalid response map length: got %d want %d", + len(decodedResponseBody), len(expectedResponseBody)) + } + + if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { + t.Errorf("handler returned unexpected body: got %v want %v", + decodedResponseBody, expectedResponseBody) + } + +} diff --git a/pkg/adminui.go b/pkg/adminui.go new file mode 100644 index 00000000..cdcdef33 --- /dev/null +++ b/pkg/adminui.go @@ -0,0 +1,356 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +var indexTpl = ` +{{define "content"}} +

Beego Admin Dashboard

+

+For detail usage please check our document: +

+

+Toolbox +

+

+Live Monitor +

+{{.Content}} +{{end}}` + +var profillingTpl = ` +{{define "content"}} +

{{.Title}}

+
+
{{.Content}}
+
+{{end}}` + +var defaultScriptsTpl = `` + +var gcAjaxTpl = ` +{{define "scripts"}} + +{{end}} +` + +var qpsTpl = `{{define "content"}} +

Requests statistics

+ + + + {{range .Content.Fields}} + + {{end}} + + + + + {{range $i, $elem := .Content.Data}} + + + + + + + + + + + {{end}} + + +
+ {{.}} +
{{index $elem 0}}{{index $elem 1}}{{index $elem 2}}{{index $elem 4}}{{index $elem 6}}{{index $elem 8}}{{index $elem 10}}
+{{end}}` + +var configTpl = ` +{{define "content"}} +

Configurations

+
+{{range $index, $elem := .Content}}
+{{$index}}={{$elem}}
+{{end}}
+
+{{end}} +` + +var routerAndFilterTpl = `{{define "content"}} + + +

{{.Title}}

+ +{{range .Content.Methods}} + +
+
{{.}}
+
+ + + + {{range $.Content.Fields}} + + {{end}} + + + + + {{$slice := index $.Content.Data .}} + {{range $i, $elem := $slice}} + + + {{range $elem}} + + {{end}} + + + {{end}} + + +
+ {{.}} +
+ {{.}} +
+
+
+{{end}} + + +{{end}}` + +var tasksTpl = `{{define "content"}} + +

{{.Title}}

+ +{{if .Message }} +{{ $messageType := index .Message 0}} +

+{{index .Message 1}} +

+{{end}} + + + + + +{{range .Content.Fields}} + +{{end}} + + + + +{{range $i, $slice := .Content.Data}} + + {{range $slice}} + + {{end}} + + +{{end}} + +
+{{.}} +
+ {{.}} + + Run +
+ +{{end}}` + +var healthCheckTpl = ` +{{define "content"}} + +

{{.Title}}

+ + + +{{range .Content.Fields}} + +{{end}} + + + +{{range $i, $slice := .Content.Data}} + {{ $header := index $slice 0}} + {{ if eq "success" $header}} + + {{else if eq "error" $header}} + + {{else}} + + {{end}} + {{range $j, $elem := $slice}} + {{if ne $j 0}} + + {{end}} + {{end}} + + +{{end}} + + +
+ {{.}} +
+ {{$elem}} + + {{$header}} +
+{{end}}` + +// The base dashboardTpl +var dashboardTpl = ` + + + + + + + + + + +Welcome to Beego Admin Dashboard + + + + + + + + + + + + + +
+{{template "content" .}} +
+ + + + + + + +{{template "scripts" .}} + + +` diff --git a/pkg/app.go b/pkg/app.go new file mode 100644 index 00000000..f3fe6f7b --- /dev/null +++ b/pkg/app.go @@ -0,0 +1,496 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/fcgi" + "os" + "path" + "strings" + "time" + + "github.com/astaxie/beego/grace" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/utils" + "golang.org/x/crypto/acme/autocert" +) + +var ( + // BeeApp is an application instance + BeeApp *App +) + +func init() { + // create beego application + BeeApp = NewApp() +} + +// App defines beego application with a new PatternServeMux. +type App struct { + Handlers *ControllerRegister + Server *http.Server +} + +// NewApp returns a new beego application. +func NewApp() *App { + cr := NewControllerRegister() + app := &App{Handlers: cr, Server: &http.Server{}} + return app +} + +// MiddleWare function for http.Handler +type MiddleWare func(http.Handler) http.Handler + +// Run beego application. +func (app *App) Run(mws ...MiddleWare) { + addr := BConfig.Listen.HTTPAddr + + if BConfig.Listen.HTTPPort != 0 { + addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPAddr, BConfig.Listen.HTTPPort) + } + + var ( + err error + l net.Listener + endRunning = make(chan bool, 1) + ) + + // run cgi server + if BConfig.Listen.EnableFcgi { + if BConfig.Listen.EnableStdIo { + if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O + logs.Info("Use FCGI via standard I/O") + } else { + logs.Critical("Cannot use FCGI via standard I/O", err) + } + return + } + if BConfig.Listen.HTTPPort == 0 { + // remove the Socket file before start + if utils.FileExists(addr) { + os.Remove(addr) + } + l, err = net.Listen("unix", addr) + } else { + l, err = net.Listen("tcp", addr) + } + if err != nil { + logs.Critical("Listen: ", err) + } + if err = fcgi.Serve(l, app.Handlers); err != nil { + logs.Critical("fcgi.Serve: ", err) + } + return + } + + app.Server.Handler = app.Handlers + for i := len(mws) - 1; i >= 0; i-- { + if mws[i] == nil { + continue + } + app.Server.Handler = mws[i](app.Server.Handler) + } + app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second + app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second + app.Server.ErrorLog = logs.GetLogger("HTTP") + + // run graceful mode + if BConfig.Listen.Graceful { + httpsAddr := BConfig.Listen.HTTPSAddr + app.Server.Addr = httpsAddr + if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { + go func() { + time.Sleep(1000 * time.Microsecond) + if BConfig.Listen.HTTPSPort != 0 { + httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) + app.Server.Addr = httpsAddr + } + server := grace.NewServer(httpsAddr, app.Server.Handler) + server.Server.ReadTimeout = app.Server.ReadTimeout + server.Server.WriteTimeout = app.Server.WriteTimeout + if BConfig.Listen.EnableMutualHTTPS { + if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } + } else { + if BConfig.Listen.AutoTLS { + m := autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), + Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), + } + app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} + BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" + } + if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } + } + endRunning <- true + }() + } + if BConfig.Listen.EnableHTTP { + go func() { + server := grace.NewServer(addr, app.Server.Handler) + server.Server.ReadTimeout = app.Server.ReadTimeout + server.Server.WriteTimeout = app.Server.WriteTimeout + if BConfig.Listen.ListenTCP4 { + server.Network = "tcp4" + } + if err := server.ListenAndServe(); err != nil { + logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } + endRunning <- true + }() + } + <-endRunning + return + } + + // run normal mode + if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { + go func() { + time.Sleep(1000 * time.Microsecond) + if BConfig.Listen.HTTPSPort != 0 { + app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) + } else if BConfig.Listen.EnableHTTP { + logs.Info("Start https server error, conflict with http. Please reset https port") + return + } + logs.Info("https server Running on https://%s", app.Server.Addr) + if BConfig.Listen.AutoTLS { + m := autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), + Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), + } + app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} + BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" + } else if BConfig.Listen.EnableMutualHTTPS { + pool := x509.NewCertPool() + data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile) + if err != nil { + logs.Info("MutualHTTPS should provide TrustCaFile") + return + } + pool.AppendCertsFromPEM(data) + app.Server.TLSConfig = &tls.Config{ + ClientCAs: pool, + ClientAuth: tls.RequireAndVerifyClientCert, + } + } + if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + }() + + } + if BConfig.Listen.EnableHTTP { + go func() { + app.Server.Addr = addr + logs.Info("http server Running on http://%s", app.Server.Addr) + if BConfig.Listen.ListenTCP4 { + ln, err := net.Listen("tcp4", app.Server.Addr) + if err != nil { + logs.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + if err = app.Server.Serve(ln); err != nil { + logs.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + } else { + if err := app.Server.ListenAndServe(); err != nil { + logs.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + } + }() + } + <-endRunning +} + +// Router adds a patterned controller handler to BeeApp. +// it's an alias method of App.Router. +// usage: +// simple router +// beego.Router("/admin", &admin.UserController{}) +// beego.Router("/admin/index", &admin.ArticleController{}) +// +// regex router +// +// beego.Router("/api/:id([0-9]+)", &controllers.RController{}) +// +// custom rules +// beego.Router("/api/list",&RestController{},"*:ListFood") +// beego.Router("/api/create",&RestController{},"post:CreateFood") +// beego.Router("/api/update",&RestController{},"put:UpdateFood") +// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") +func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { + BeeApp.Handlers.Add(rootpath, c, mappingMethods...) + return BeeApp +} + +// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful +// in web applications that inherit most routes from a base webapp via the underscore +// import, and aim to overwrite only certain paths. +// The method parameter can be empty or "*" for all HTTP methods, or a particular +// method type (e.g. "GET" or "POST") for selective removal. +// +// Usage (replace "GET" with "*" for all methods): +// beego.UnregisterFixedRoute("/yourpreviouspath", "GET") +// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage") +func UnregisterFixedRoute(fixedRoute string, method string) *App { + subPaths := splitPath(fixedRoute) + if method == "" || method == "*" { + for m := range HTTPMETHOD { + if _, ok := BeeApp.Handlers.routers[m]; !ok { + continue + } + if BeeApp.Handlers.routers[m].prefix == strings.Trim(fixedRoute, "/ ") { + findAndRemoveSingleTree(BeeApp.Handlers.routers[m]) + continue + } + findAndRemoveTree(subPaths, BeeApp.Handlers.routers[m], m) + } + return BeeApp + } + // Single HTTP method + um := strings.ToUpper(method) + if _, ok := BeeApp.Handlers.routers[um]; ok { + if BeeApp.Handlers.routers[um].prefix == strings.Trim(fixedRoute, "/ ") { + findAndRemoveSingleTree(BeeApp.Handlers.routers[um]) + return BeeApp + } + findAndRemoveTree(subPaths, BeeApp.Handlers.routers[um], um) + } + return BeeApp +} + +func findAndRemoveTree(paths []string, entryPointTree *Tree, method string) { + for i := range entryPointTree.fixrouters { + if entryPointTree.fixrouters[i].prefix == paths[0] { + if len(paths) == 1 { + if len(entryPointTree.fixrouters[i].fixrouters) > 0 { + // If the route had children subtrees, remove just the functional leaf, + // to allow children to function as before + if len(entryPointTree.fixrouters[i].leaves) > 0 { + entryPointTree.fixrouters[i].leaves[0] = nil + entryPointTree.fixrouters[i].leaves = entryPointTree.fixrouters[i].leaves[1:] + } + } else { + // Remove the *Tree from the fixrouters slice + entryPointTree.fixrouters[i] = nil + + if i == len(entryPointTree.fixrouters)-1 { + entryPointTree.fixrouters = entryPointTree.fixrouters[:i] + } else { + entryPointTree.fixrouters = append(entryPointTree.fixrouters[:i], entryPointTree.fixrouters[i+1:len(entryPointTree.fixrouters)]...) + } + } + return + } + findAndRemoveTree(paths[1:], entryPointTree.fixrouters[i], method) + } + } +} + +func findAndRemoveSingleTree(entryPointTree *Tree) { + if entryPointTree == nil { + return + } + if len(entryPointTree.fixrouters) > 0 { + // If the route had children subtrees, remove just the functional leaf, + // to allow children to function as before + if len(entryPointTree.leaves) > 0 { + entryPointTree.leaves[0] = nil + entryPointTree.leaves = entryPointTree.leaves[1:] + } + } +} + +// Include will generate router file in the router/xxx.go from the controller's comments +// usage: +// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +// type BankAccount struct{ +// beego.Controller +// } +// +// register the function +// func (b *BankAccount)Mapping(){ +// b.Mapping("ShowAccount" , b.ShowAccount) +// b.Mapping("ModifyAccount", b.ModifyAccount) +//} +// +// //@router /account/:id [get] +// func (b *BankAccount) ShowAccount(){ +// //logic +// } +// +// +// //@router /account/:id [post] +// func (b *BankAccount) ModifyAccount(){ +// //logic +// } +// +// the comments @router url methodlist +// url support all the function Router's pattern +// methodlist [get post head put delete options *] +func Include(cList ...ControllerInterface) *App { + BeeApp.Handlers.Include(cList...) + return BeeApp +} + +// RESTRouter adds a restful controller handler to BeeApp. +// its' controller implements beego.ControllerInterface and +// defines a param "pattern/:objectId" to visit each resource. +func RESTRouter(rootpath string, c ControllerInterface) *App { + Router(rootpath, c) + Router(path.Join(rootpath, ":objectId"), c) + return BeeApp +} + +// AutoRouter adds defined controller handler to BeeApp. +// it's same to App.AutoRouter. +// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, +// visit the url /main/list to exec List function or /main/page to exec Page function. +func AutoRouter(c ControllerInterface) *App { + BeeApp.Handlers.AddAuto(c) + return BeeApp +} + +// AutoPrefix adds controller handler to BeeApp with prefix. +// it's same to App.AutoRouterWithPrefix. +// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, +// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. +func AutoPrefix(prefix string, c ControllerInterface) *App { + BeeApp.Handlers.AddAutoPrefix(prefix, c) + return BeeApp +} + +// Get used to register router for Get method +// usage: +// beego.Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Get(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Get(rootpath, f) + return BeeApp +} + +// Post used to register router for Post method +// usage: +// beego.Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Post(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Post(rootpath, f) + return BeeApp +} + +// Delete used to register router for Delete method +// usage: +// beego.Delete("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Delete(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Delete(rootpath, f) + return BeeApp +} + +// Put used to register router for Put method +// usage: +// beego.Put("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Put(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Put(rootpath, f) + return BeeApp +} + +// Head used to register router for Head method +// usage: +// beego.Head("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Head(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Head(rootpath, f) + return BeeApp +} + +// Options used to register router for Options method +// usage: +// beego.Options("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Options(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Options(rootpath, f) + return BeeApp +} + +// Patch used to register router for Patch method +// usage: +// beego.Patch("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Patch(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Patch(rootpath, f) + return BeeApp +} + +// Any used to register router for all methods +// usage: +// beego.Any("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Any(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Any(rootpath, f) + return BeeApp +} + +// Handler used to register a Handler router +// usage: +// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { +// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) +// })) +func Handler(rootpath string, h http.Handler, options ...interface{}) *App { + BeeApp.Handlers.Handler(rootpath, h, options...) + return BeeApp +} + +// InsertFilter adds a FilterFunc with pattern condition and action constant. +// The pos means action constant including +// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. +// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) +func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { + BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) + return BeeApp +} diff --git a/pkg/beego.go b/pkg/beego.go new file mode 100644 index 00000000..8ebe0bab --- /dev/null +++ b/pkg/beego.go @@ -0,0 +1,123 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "os" + "path/filepath" + "strconv" + "strings" +) + +const ( + // VERSION represent beego web framework version. + VERSION = "1.12.2" + + // DEV is for develop + DEV = "dev" + // PROD is for production + PROD = "prod" +) + +// M is Map shortcut +type M map[string]interface{} + +// Hook function to run +type hookfunc func() error + +var ( + hooks = make([]hookfunc, 0) //hook function slice to store the hookfunc +) + +// AddAPPStartHook is used to register the hookfunc +// The hookfuncs will run in beego.Run() +// such as initiating session , starting middleware , building template, starting admin control and so on. +func AddAPPStartHook(hf ...hookfunc) { + hooks = append(hooks, hf...) +} + +// Run beego application. +// beego.Run() default run on HttpPort +// beego.Run("localhost") +// beego.Run(":8089") +// beego.Run("127.0.0.1:8089") +func Run(params ...string) { + + initBeforeHTTPRun() + + if len(params) > 0 && params[0] != "" { + strs := strings.Split(params[0], ":") + if len(strs) > 0 && strs[0] != "" { + BConfig.Listen.HTTPAddr = strs[0] + } + if len(strs) > 1 && strs[1] != "" { + BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) + } + + BConfig.Listen.Domains = params + } + + BeeApp.Run() +} + +// RunWithMiddleWares Run beego application with middlewares. +func RunWithMiddleWares(addr string, mws ...MiddleWare) { + initBeforeHTTPRun() + + strs := strings.Split(addr, ":") + if len(strs) > 0 && strs[0] != "" { + BConfig.Listen.HTTPAddr = strs[0] + BConfig.Listen.Domains = []string{strs[0]} + } + if len(strs) > 1 && strs[1] != "" { + BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) + } + + BeeApp.Run(mws...) +} + +func initBeforeHTTPRun() { + //init hooks + AddAPPStartHook( + registerMime, + registerDefaultErrorHandler, + registerSession, + registerTemplate, + registerAdmin, + registerGzip, + ) + + for _, hk := range hooks { + if err := hk(); err != nil { + panic(err) + } + } +} + +// TestBeegoInit is for test package init +func TestBeegoInit(ap string) { + path := filepath.Join(ap, "conf", "app.conf") + os.Chdir(ap) + InitBeegoBeforeTest(path) +} + +// InitBeegoBeforeTest is for test package init +func InitBeegoBeforeTest(appConfigPath string) { + if err := LoadAppConfig(appConfigProvider, appConfigPath); err != nil { + panic(err) + } + BConfig.RunMode = "test" + initBeforeHTTPRun() +} diff --git a/pkg/build_info.go b/pkg/build_info.go new file mode 100644 index 00000000..6dc2835e --- /dev/null +++ b/pkg/build_info.go @@ -0,0 +1,27 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +var ( + BuildVersion string + BuildGitRevision string + BuildStatus string + BuildTag string + BuildTime string + + GoVersion string + + GitBranch string +) diff --git a/pkg/cache/README.md b/pkg/cache/README.md new file mode 100644 index 00000000..b467760a --- /dev/null +++ b/pkg/cache/README.md @@ -0,0 +1,59 @@ +## cache +cache is a Go cache manager. It can use many cache adapters. The repo is inspired by `database/sql` . + + +## How to install? + + go get github.com/astaxie/beego/cache + + +## What adapters are supported? + +As of now this cache support memory, Memcache and Redis. + + +## How to use it? + +First you must import it + + import ( + "github.com/astaxie/beego/cache" + ) + +Then init a Cache (example with memory adapter) + + bm, err := cache.NewCache("memory", `{"interval":60}`) + +Use it like this: + + bm.Put("astaxie", 1, 10 * time.Second) + bm.Get("astaxie") + bm.IsExist("astaxie") + bm.Delete("astaxie") + + +## Memory adapter + +Configure memory adapter like this: + + {"interval":60} + +interval means the gc time. The cache will check at each time interval, whether item has expired. + + +## Memcache adapter + +Memcache adapter use the [gomemcache](http://github.com/bradfitz/gomemcache) client. + +Configure like this: + + {"conn":"127.0.0.1:11211"} + + +## Redis adapter + +Redis adapter use the [redigo](http://github.com/gomodule/redigo) client. + +Configure like this: + + {"conn":":6039"} diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go new file mode 100644 index 00000000..82585c4e --- /dev/null +++ b/pkg/cache/cache.go @@ -0,0 +1,103 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cache provide a Cache interface and some implement engine +// Usage: +// +// import( +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("memory", `{"interval":60}`) +// +// Use it like this: +// +// bm.Put("astaxie", 1, 10 * time.Second) +// bm.Get("astaxie") +// bm.IsExist("astaxie") +// bm.Delete("astaxie") +// +// more docs http://beego.me/docs/module/cache.md +package cache + +import ( + "fmt" + "time" +) + +// Cache interface contains all behaviors for cache adapter. +// usage: +// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go. +// c,err := cache.NewCache("file","{....}") +// c.Put("key",value, 3600 * time.Second) +// v := c.Get("key") +// +// c.Incr("counter") // now is 1 +// c.Incr("counter") // now is 2 +// count := c.Get("counter").(int) +type Cache interface { + // get cached value by key. + Get(key string) interface{} + // GetMulti is a batch version of Get. + GetMulti(keys []string) []interface{} + // set cached value with key and expire time. + Put(key string, val interface{}, timeout time.Duration) error + // delete cached value by key. + Delete(key string) error + // increase cached int value by key, as a counter. + Incr(key string) error + // decrease cached int value by key, as a counter. + Decr(key string) error + // check if cached value exists or not. + IsExist(key string) bool + // clear all cache. + ClearAll() error + // start gc routine based on config string settings. + StartAndGC(config string) error +} + +// Instance is a function create a new Cache Instance +type Instance func() Cache + +var adapters = make(map[string]Instance) + +// Register makes a cache adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Instance) { + if adapter == nil { + panic("cache: Register adapter is nil") + } + if _, ok := adapters[name]; ok { + panic("cache: Register called twice for adapter " + name) + } + adapters[name] = adapter +} + +// NewCache Create a new cache driver by adapter name and config string. +// config need to be correct JSON as string: {"interval":360}. +// it will start gc automatically. +func NewCache(adapterName, config string) (adapter Cache, err error) { + instanceFunc, ok := adapters[adapterName] + if !ok { + err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + return + } + adapter = instanceFunc() + err = adapter.StartAndGC(config) + if err != nil { + adapter = nil + } + return +} diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go new file mode 100644 index 00000000..470c0a43 --- /dev/null +++ b/pkg/cache/cache_test.go @@ -0,0 +1,191 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "os" + "sync" + "testing" + "time" +) + +func TestCacheIncr(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + //timeoutDuration := 10 * time.Second + + bm.Put("edwardhey", 0, time.Second*20) + wg := sync.WaitGroup{} + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + bm.Incr("edwardhey") + }() + } + wg.Wait() + if bm.Get("edwardhey").(int) != 10 { + t.Error("Incr err") + } +} + +func TestCache(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + time.Sleep(30 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test GetMulti + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } + + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if vv[0].(string) != "author" { + t.Error("GetMulti ERROR") + } + if vv[1].(string) != "author1" { + t.Error("GetMulti ERROR") + } +} + +func TestFileCache(t *testing.T) { + bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if vv[0].(string) != "author" { + t.Error("GetMulti ERROR") + } + if vv[1].(string) != "author1" { + t.Error("GetMulti ERROR") + } + + os.RemoveAll("cache") +} diff --git a/pkg/cache/conv.go b/pkg/cache/conv.go new file mode 100644 index 00000000..87800586 --- /dev/null +++ b/pkg/cache/conv.go @@ -0,0 +1,100 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "fmt" + "strconv" +) + +// GetString convert interface to string. +func GetString(v interface{}) string { + switch result := v.(type) { + case string: + return result + case []byte: + return string(result) + default: + if v != nil { + return fmt.Sprint(result) + } + } + return "" +} + +// GetInt convert interface to int. +func GetInt(v interface{}) int { + switch result := v.(type) { + case int: + return result + case int32: + return int(result) + case int64: + return int(result) + default: + if d := GetString(v); d != "" { + value, _ := strconv.Atoi(d) + return value + } + } + return 0 +} + +// GetInt64 convert interface to int64. +func GetInt64(v interface{}) int64 { + switch result := v.(type) { + case int: + return int64(result) + case int32: + return int64(result) + case int64: + return result + default: + + if d := GetString(v); d != "" { + value, _ := strconv.ParseInt(d, 10, 64) + return value + } + } + return 0 +} + +// GetFloat64 convert interface to float64. +func GetFloat64(v interface{}) float64 { + switch result := v.(type) { + case float64: + return result + default: + if d := GetString(v); d != "" { + value, _ := strconv.ParseFloat(d, 64) + return value + } + } + return 0 +} + +// GetBool convert interface to bool. +func GetBool(v interface{}) bool { + switch result := v.(type) { + case bool: + return result + default: + if d := GetString(v); d != "" { + value, _ := strconv.ParseBool(d) + return value + } + } + return false +} diff --git a/pkg/cache/conv_test.go b/pkg/cache/conv_test.go new file mode 100644 index 00000000..b90e224a --- /dev/null +++ b/pkg/cache/conv_test.go @@ -0,0 +1,143 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "testing" +) + +func TestGetString(t *testing.T) { + var t1 = "test1" + if "test1" != GetString(t1) { + t.Error("get string from string error") + } + var t2 = []byte("test2") + if "test2" != GetString(t2) { + t.Error("get string from byte array error") + } + var t3 = 1 + if "1" != GetString(t3) { + t.Error("get string from int error") + } + var t4 int64 = 1 + if "1" != GetString(t4) { + t.Error("get string from int64 error") + } + var t5 = 1.1 + if "1.1" != GetString(t5) { + t.Error("get string from float64 error") + } + + if "" != GetString(nil) { + t.Error("get string from nil error") + } +} + +func TestGetInt(t *testing.T) { + var t1 = 1 + if 1 != GetInt(t1) { + t.Error("get int from int error") + } + var t2 int32 = 32 + if 32 != GetInt(t2) { + t.Error("get int from int32 error") + } + var t3 int64 = 64 + if 64 != GetInt(t3) { + t.Error("get int from int64 error") + } + var t4 = "128" + if 128 != GetInt(t4) { + t.Error("get int from num string error") + } + if 0 != GetInt(nil) { + t.Error("get int from nil error") + } +} + +func TestGetInt64(t *testing.T) { + var i int64 = 1 + var t1 = 1 + if i != GetInt64(t1) { + t.Error("get int64 from int error") + } + var t2 int32 = 1 + if i != GetInt64(t2) { + t.Error("get int64 from int32 error") + } + var t3 int64 = 1 + if i != GetInt64(t3) { + t.Error("get int64 from int64 error") + } + var t4 = "1" + if i != GetInt64(t4) { + t.Error("get int64 from num string error") + } + if 0 != GetInt64(nil) { + t.Error("get int64 from nil") + } +} + +func TestGetFloat64(t *testing.T) { + var f = 1.11 + var t1 float32 = 1.11 + if f != GetFloat64(t1) { + t.Error("get float64 from float32 error") + } + var t2 = 1.11 + if f != GetFloat64(t2) { + t.Error("get float64 from float64 error") + } + var t3 = "1.11" + if f != GetFloat64(t3) { + t.Error("get float64 from string error") + } + + var f2 float64 = 1 + var t4 = 1 + if f2 != GetFloat64(t4) { + t.Error("get float64 from int error") + } + + if 0 != GetFloat64(nil) { + t.Error("get float64 from nil error") + } +} + +func TestGetBool(t *testing.T) { + var t1 = true + if !GetBool(t1) { + t.Error("get bool from bool error") + } + var t2 = "true" + if !GetBool(t2) { + t.Error("get bool from string error") + } + if GetBool(nil) { + t.Error("get bool from nil error") + } +} + +func byteArrayEquals(a []byte, b []byte) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} diff --git a/pkg/cache/file.go b/pkg/cache/file.go new file mode 100644 index 00000000..6f12d3ee --- /dev/null +++ b/pkg/cache/file.go @@ -0,0 +1,258 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "bytes" + "crypto/md5" + "encoding/gob" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "strconv" + "time" +) + +// FileCacheItem is basic unit of file cache adapter. +// it contains data and expire time. +type FileCacheItem struct { + Data interface{} + Lastaccess time.Time + Expired time.Time +} + +// FileCache Config +var ( + FileCachePath = "cache" // cache directory + FileCacheFileSuffix = ".bin" // cache file suffix + FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files. + FileCacheEmbedExpiry time.Duration // cache expire time, default is no expire forever. +) + +// FileCache is cache adapter for file storage. +type FileCache struct { + CachePath string + FileSuffix string + DirectoryLevel int + EmbedExpiry int +} + +// NewFileCache Create new file cache with no config. +// the level and expiry need set in method StartAndGC as config string. +func NewFileCache() Cache { + // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} + return &FileCache{} +} + +// StartAndGC will start and begin gc for file cache. +// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"} +func (fc *FileCache) StartAndGC(config string) error { + + cfg := make(map[string]string) + err := json.Unmarshal([]byte(config), &cfg) + if err != nil { + return err + } + if _, ok := cfg["CachePath"]; !ok { + cfg["CachePath"] = FileCachePath + } + if _, ok := cfg["FileSuffix"]; !ok { + cfg["FileSuffix"] = FileCacheFileSuffix + } + if _, ok := cfg["DirectoryLevel"]; !ok { + cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel) + } + if _, ok := cfg["EmbedExpiry"]; !ok { + cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) + } + fc.CachePath = cfg["CachePath"] + fc.FileSuffix = cfg["FileSuffix"] + fc.DirectoryLevel, _ = strconv.Atoi(cfg["DirectoryLevel"]) + fc.EmbedExpiry, _ = strconv.Atoi(cfg["EmbedExpiry"]) + + fc.Init() + return nil +} + +// Init will make new dir for file cache if not exist. +func (fc *FileCache) Init() { + if ok, _ := exists(fc.CachePath); !ok { // todo : error handle + _ = os.MkdirAll(fc.CachePath, os.ModePerm) // todo : error handle + } +} + +// get cached file name. it's md5 encoded. +func (fc *FileCache) getCacheFileName(key string) string { + m := md5.New() + io.WriteString(m, key) + keyMd5 := hex.EncodeToString(m.Sum(nil)) + cachePath := fc.CachePath + switch fc.DirectoryLevel { + case 2: + cachePath = filepath.Join(cachePath, keyMd5[0:2], keyMd5[2:4]) + case 1: + cachePath = filepath.Join(cachePath, keyMd5[0:2]) + } + + if ok, _ := exists(cachePath); !ok { // todo : error handle + _ = os.MkdirAll(cachePath, os.ModePerm) // todo : error handle + } + + return filepath.Join(cachePath, fmt.Sprintf("%s%s", keyMd5, fc.FileSuffix)) +} + +// Get value from file cache. +// if non-exist or expired, return empty string. +func (fc *FileCache) Get(key string) interface{} { + fileData, err := FileGetContents(fc.getCacheFileName(key)) + if err != nil { + return "" + } + var to FileCacheItem + GobDecode(fileData, &to) + if to.Expired.Before(time.Now()) { + return "" + } + return to.Data +} + +// GetMulti gets values from file cache. +// if non-exist or expired, return empty string. +func (fc *FileCache) GetMulti(keys []string) []interface{} { + var rc []interface{} + for _, key := range keys { + rc = append(rc, fc.Get(key)) + } + return rc +} + +// Put value into file cache. +// timeout means how long to keep this file, unit of ms. +// if timeout equals fc.EmbedExpiry(default is 0), cache this item forever. +func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error { + gob.Register(val) + + item := FileCacheItem{Data: val} + if timeout == time.Duration(fc.EmbedExpiry) { + item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years + } else { + item.Expired = time.Now().Add(timeout) + } + item.Lastaccess = time.Now() + data, err := GobEncode(item) + if err != nil { + return err + } + return FilePutContents(fc.getCacheFileName(key), data) +} + +// Delete file cache value. +func (fc *FileCache) Delete(key string) error { + filename := fc.getCacheFileName(key) + if ok, _ := exists(filename); ok { + return os.Remove(filename) + } + return nil +} + +// Incr will increase cached int value. +// fc value is saving forever unless Delete. +func (fc *FileCache) Incr(key string) error { + data := fc.Get(key) + var incr int + if reflect.TypeOf(data).Name() != "int" { + incr = 0 + } else { + incr = data.(int) + 1 + } + fc.Put(key, incr, time.Duration(fc.EmbedExpiry)) + return nil +} + +// Decr will decrease cached int value. +func (fc *FileCache) Decr(key string) error { + data := fc.Get(key) + var decr int + if reflect.TypeOf(data).Name() != "int" || data.(int)-1 <= 0 { + decr = 0 + } else { + decr = data.(int) - 1 + } + fc.Put(key, decr, time.Duration(fc.EmbedExpiry)) + return nil +} + +// IsExist check value is exist. +func (fc *FileCache) IsExist(key string) bool { + ret, _ := exists(fc.getCacheFileName(key)) + return ret +} + +// ClearAll will clean cached files. +// not implemented. +func (fc *FileCache) ClearAll() error { + return nil +} + +// check file exist. +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +// FileGetContents Get bytes to file. +// if non-exist, create this file. +func FileGetContents(filename string) (data []byte, e error) { + return ioutil.ReadFile(filename) +} + +// FilePutContents Put bytes to file. +// if non-exist, create this file. +func FilePutContents(filename string, content []byte) error { + return ioutil.WriteFile(filename, content, os.ModePerm) +} + +// GobEncode Gob encodes file cache item. +func GobEncode(data interface{}) ([]byte, error) { + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + err := enc.Encode(data) + if err != nil { + return nil, err + } + return buf.Bytes(), err +} + +// GobDecode Gob decodes file cache item. +func GobDecode(data []byte, to *FileCacheItem) error { + buf := bytes.NewBuffer(data) + dec := gob.NewDecoder(buf) + return dec.Decode(&to) +} + +func init() { + Register("file", NewFileCache) +} diff --git a/pkg/cache/memcache/memcache.go b/pkg/cache/memcache/memcache.go new file mode 100644 index 00000000..19116bfa --- /dev/null +++ b/pkg/cache/memcache/memcache.go @@ -0,0 +1,188 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memcache for cache provider +// +// depend on github.com/bradfitz/gomemcache/memcache +// +// go install github.com/bradfitz/gomemcache/memcache +// +// Usage: +// import( +// _ "github.com/astaxie/beego/cache/memcache" +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`) +// +// more docs http://beego.me/docs/module/cache.md +package memcache + +import ( + "encoding/json" + "errors" + "strings" + "time" + + "github.com/astaxie/beego/cache" + "github.com/bradfitz/gomemcache/memcache" +) + +// Cache Memcache adapter. +type Cache struct { + conn *memcache.Client + conninfo []string +} + +// NewMemCache create new memcache adapter. +func NewMemCache() cache.Cache { + return &Cache{} +} + +// Get get value from memcache. +func (rc *Cache) Get(key string) interface{} { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + if item, err := rc.conn.Get(key); err == nil { + return item.Value + } + return nil +} + +// GetMulti get value from memcache. +func (rc *Cache) GetMulti(keys []string) []interface{} { + size := len(keys) + var rv []interface{} + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + for i := 0; i < size; i++ { + rv = append(rv, err) + } + return rv + } + } + mv, err := rc.conn.GetMulti(keys) + if err == nil { + for _, v := range mv { + rv = append(rv, v.Value) + } + return rv + } + for i := 0; i < size; i++ { + rv = append(rv, err) + } + return rv +} + +// Put put value to memcache. +func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + item := memcache.Item{Key: key, Expiration: int32(timeout / time.Second)} + if v, ok := val.([]byte); ok { + item.Value = v + } else if str, ok := val.(string); ok { + item.Value = []byte(str) + } else { + return errors.New("val only support string and []byte") + } + return rc.conn.Set(&item) +} + +// Delete delete value in memcache. +func (rc *Cache) Delete(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return rc.conn.Delete(key) +} + +// Incr increase counter. +func (rc *Cache) Incr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Increment(key, 1) + return err +} + +// Decr decrease counter. +func (rc *Cache) Decr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Decrement(key, 1) + return err +} + +// IsExist check value exists in memcache. +func (rc *Cache) IsExist(key string) bool { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return false + } + } + _, err := rc.conn.Get(key) + return err == nil +} + +// ClearAll clear all cached in memcache. +func (rc *Cache) ClearAll() error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return rc.conn.FlushAll() +} + +// StartAndGC start memcache adapter. +// config string is like {"conn":"connection info"}. +// if connecting error, return. +func (rc *Cache) StartAndGC(config string) error { + var cf map[string]string + json.Unmarshal([]byte(config), &cf) + if _, ok := cf["conn"]; !ok { + return errors.New("config has no conn key") + } + rc.conninfo = strings.Split(cf["conn"], ";") + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return nil +} + +// connect to memcache and keep the connection. +func (rc *Cache) connectInit() error { + rc.conn = memcache.New(rc.conninfo...) + return nil +} + +func init() { + cache.Register("memcache", NewMemCache) +} diff --git a/pkg/cache/memcache/memcache_test.go b/pkg/cache/memcache/memcache_test.go new file mode 100644 index 00000000..d9129b69 --- /dev/null +++ b/pkg/cache/memcache/memcache_test.go @@ -0,0 +1,108 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memcache + +import ( + _ "github.com/bradfitz/gomemcache/memcache" + + "strconv" + "testing" + "time" + + "github.com/astaxie/beego/cache" +) + +func TestMemcacheCache(t *testing.T) { + bm, err := cache.NewCache("memcache", `{"conn": "127.0.0.1:11211"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + time.Sleep(11 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie").([]byte); string(v) != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" { + t.Error("GetMulti ERROR") + } + if string(vv[1].([]byte)) != "author1" && string(vv[1].([]byte)) != "author" { + t.Error("GetMulti ERROR") + } + + // test clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } +} diff --git a/pkg/cache/memory.go b/pkg/cache/memory.go new file mode 100644 index 00000000..d8314e3c --- /dev/null +++ b/pkg/cache/memory.go @@ -0,0 +1,256 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "encoding/json" + "errors" + "sync" + "time" +) + +var ( + // DefaultEvery means the clock time of recycling the expired cache items in memory. + DefaultEvery = 60 // 1 minute +) + +// MemoryItem store memory cache item. +type MemoryItem struct { + val interface{} + createdTime time.Time + lifespan time.Duration +} + +func (mi *MemoryItem) isExpire() bool { + // 0 means forever + if mi.lifespan == 0 { + return false + } + return time.Now().Sub(mi.createdTime) > mi.lifespan +} + +// MemoryCache is Memory cache adapter. +// it contains a RW locker for safe map storage. +type MemoryCache struct { + sync.RWMutex + dur time.Duration + items map[string]*MemoryItem + Every int // run an expiration check Every clock time +} + +// NewMemoryCache returns a new MemoryCache. +func NewMemoryCache() Cache { + cache := MemoryCache{items: make(map[string]*MemoryItem)} + return &cache +} + +// Get cache from memory. +// if non-existed or expired, return nil. +func (bc *MemoryCache) Get(name string) interface{} { + bc.RLock() + defer bc.RUnlock() + if itm, ok := bc.items[name]; ok { + if itm.isExpire() { + return nil + } + return itm.val + } + return nil +} + +// GetMulti gets caches from memory. +// if non-existed or expired, return nil. +func (bc *MemoryCache) GetMulti(names []string) []interface{} { + var rc []interface{} + for _, name := range names { + rc = append(rc, bc.Get(name)) + } + return rc +} + +// Put cache to memory. +// if lifespan is 0, it will be forever till restart. +func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error { + bc.Lock() + defer bc.Unlock() + bc.items[name] = &MemoryItem{ + val: value, + createdTime: time.Now(), + lifespan: lifespan, + } + return nil +} + +// Delete cache in memory. +func (bc *MemoryCache) Delete(name string) error { + bc.Lock() + defer bc.Unlock() + if _, ok := bc.items[name]; !ok { + return errors.New("key not exist") + } + delete(bc.items, name) + if _, ok := bc.items[name]; ok { + return errors.New("delete key error") + } + return nil +} + +// Incr increase cache counter in memory. +// it supports int,int32,int64,uint,uint32,uint64. +func (bc *MemoryCache) Incr(key string) error { + bc.Lock() + defer bc.Unlock() + itm, ok := bc.items[key] + if !ok { + return errors.New("key not exist") + } + switch val := itm.val.(type) { + case int: + itm.val = val + 1 + case int32: + itm.val = val + 1 + case int64: + itm.val = val + 1 + case uint: + itm.val = val + 1 + case uint32: + itm.val = val + 1 + case uint64: + itm.val = val + 1 + default: + return errors.New("item val is not (u)int (u)int32 (u)int64") + } + return nil +} + +// Decr decrease counter in memory. +func (bc *MemoryCache) Decr(key string) error { + bc.Lock() + defer bc.Unlock() + itm, ok := bc.items[key] + if !ok { + return errors.New("key not exist") + } + switch val := itm.val.(type) { + case int: + itm.val = val - 1 + case int64: + itm.val = val - 1 + case int32: + itm.val = val - 1 + case uint: + if val > 0 { + itm.val = val - 1 + } else { + return errors.New("item val is less than 0") + } + case uint32: + if val > 0 { + itm.val = val - 1 + } else { + return errors.New("item val is less than 0") + } + case uint64: + if val > 0 { + itm.val = val - 1 + } else { + return errors.New("item val is less than 0") + } + default: + return errors.New("item val is not int int64 int32") + } + return nil +} + +// IsExist check cache exist in memory. +func (bc *MemoryCache) IsExist(name string) bool { + bc.RLock() + defer bc.RUnlock() + if v, ok := bc.items[name]; ok { + return !v.isExpire() + } + return false +} + +// ClearAll will delete all cache in memory. +func (bc *MemoryCache) ClearAll() error { + bc.Lock() + defer bc.Unlock() + bc.items = make(map[string]*MemoryItem) + return nil +} + +// StartAndGC start memory cache. it will check expiration in every clock time. +func (bc *MemoryCache) StartAndGC(config string) error { + var cf map[string]int + json.Unmarshal([]byte(config), &cf) + if _, ok := cf["interval"]; !ok { + cf = make(map[string]int) + cf["interval"] = DefaultEvery + } + dur := time.Duration(cf["interval"]) * time.Second + bc.Every = cf["interval"] + bc.dur = dur + go bc.vacuum() + return nil +} + +// check expiration. +func (bc *MemoryCache) vacuum() { + bc.RLock() + every := bc.Every + bc.RUnlock() + + if every < 1 { + return + } + for { + <-time.After(bc.dur) + bc.RLock() + if bc.items == nil { + bc.RUnlock() + return + } + bc.RUnlock() + if keys := bc.expiredKeys(); len(keys) != 0 { + bc.clearItems(keys) + } + } +} + +// expiredKeys returns key list which are expired. +func (bc *MemoryCache) expiredKeys() (keys []string) { + bc.RLock() + defer bc.RUnlock() + for key, itm := range bc.items { + if itm.isExpire() { + keys = append(keys, key) + } + } + return +} + +// clearItems removes all the items which key in keys. +func (bc *MemoryCache) clearItems(keys []string) { + bc.Lock() + defer bc.Unlock() + for _, key := range keys { + delete(bc.items, key) + } +} + +func init() { + Register("memory", NewMemoryCache) +} diff --git a/pkg/cache/redis/redis.go b/pkg/cache/redis/redis.go new file mode 100644 index 00000000..56faf211 --- /dev/null +++ b/pkg/cache/redis/redis.go @@ -0,0 +1,272 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for cache provider +// +// depend on github.com/gomodule/redigo/redis +// +// go install github.com/gomodule/redigo/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/cache/redis" +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("redis", `{"conn":"127.0.0.1:11211"}`) +// +// more docs http://beego.me/docs/module/cache.md +package redis + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" + "time" + + "github.com/gomodule/redigo/redis" + + "github.com/astaxie/beego/cache" + "strings" +) + +var ( + // DefaultKey the collection name of redis for cache adapter. + DefaultKey = "beecacheRedis" +) + +// Cache is Redis cache adapter. +type Cache struct { + p *redis.Pool // redis connection pool + conninfo string + dbNum int + key string + password string + maxIdle int + + //the timeout to a value less than the redis server's timeout. + timeout time.Duration +} + +// NewRedisCache create new redis cache with default collection name. +func NewRedisCache() cache.Cache { + return &Cache{key: DefaultKey} +} + +// actually do the redis cmds, args[0] must be the key name. +func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) { + if len(args) < 1 { + return nil, errors.New("missing required arguments") + } + args[0] = rc.associate(args[0]) + c := rc.p.Get() + defer c.Close() + + return c.Do(commandName, args...) +} + +// associate with config key. +func (rc *Cache) associate(originKey interface{}) string { + return fmt.Sprintf("%s:%s", rc.key, originKey) +} + +// Get cache from redis. +func (rc *Cache) Get(key string) interface{} { + if v, err := rc.do("GET", key); err == nil { + return v + } + return nil +} + +// GetMulti get cache from redis. +func (rc *Cache) GetMulti(keys []string) []interface{} { + c := rc.p.Get() + defer c.Close() + var args []interface{} + for _, key := range keys { + args = append(args, rc.associate(key)) + } + values, err := redis.Values(c.Do("MGET", args...)) + if err != nil { + return nil + } + return values +} + +// Put put cache to redis. +func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { + _, err := rc.do("SETEX", key, int64(timeout/time.Second), val) + return err +} + +// Delete delete cache in redis. +func (rc *Cache) Delete(key string) error { + _, err := rc.do("DEL", key) + return err +} + +// IsExist check cache's existence in redis. +func (rc *Cache) IsExist(key string) bool { + v, err := redis.Bool(rc.do("EXISTS", key)) + if err != nil { + return false + } + return v +} + +// Incr increase counter in redis. +func (rc *Cache) Incr(key string) error { + _, err := redis.Bool(rc.do("INCRBY", key, 1)) + return err +} + +// Decr decrease counter in redis. +func (rc *Cache) Decr(key string) error { + _, err := redis.Bool(rc.do("INCRBY", key, -1)) + return err +} + +// ClearAll clean all cache in redis. delete this redis collection. +func (rc *Cache) ClearAll() error { + cachedKeys, err := rc.Scan(rc.key + ":*") + if err != nil { + return err + } + c := rc.p.Get() + defer c.Close() + for _, str := range cachedKeys { + if _, err = c.Do("DEL", str); err != nil { + return err + } + } + return err +} + +// Scan scan all keys matching the pattern. a better choice than `keys` +func (rc *Cache) Scan(pattern string) (keys []string, err error) { + c := rc.p.Get() + defer c.Close() + var ( + cursor uint64 = 0 // start + result []interface{} + list []string + ) + for { + result, err = redis.Values(c.Do("SCAN", cursor, "MATCH", pattern, "COUNT", 1024)) + if err != nil { + return + } + list, err = redis.Strings(result[1], nil) + if err != nil { + return + } + keys = append(keys, list...) + cursor, err = redis.Uint64(result[0], nil) + if err != nil { + return + } + if cursor == 0 { // over + return + } + } +} + +// StartAndGC start redis cache adapter. +// config is like {"key":"collection key","conn":"connection info","dbNum":"0"} +// the cache item in redis are stored forever, +// so no gc operation. +func (rc *Cache) StartAndGC(config string) error { + var cf map[string]string + json.Unmarshal([]byte(config), &cf) + + if _, ok := cf["key"]; !ok { + cf["key"] = DefaultKey + } + if _, ok := cf["conn"]; !ok { + return errors.New("config has no conn key") + } + + // Format redis://@: + cf["conn"] = strings.Replace(cf["conn"], "redis://", "", 1) + if i := strings.Index(cf["conn"], "@"); i > -1 { + cf["password"] = cf["conn"][0:i] + cf["conn"] = cf["conn"][i+1:] + } + + if _, ok := cf["dbNum"]; !ok { + cf["dbNum"] = "0" + } + if _, ok := cf["password"]; !ok { + cf["password"] = "" + } + if _, ok := cf["maxIdle"]; !ok { + cf["maxIdle"] = "3" + } + if _, ok := cf["timeout"]; !ok { + cf["timeout"] = "180s" + } + rc.key = cf["key"] + rc.conninfo = cf["conn"] + rc.dbNum, _ = strconv.Atoi(cf["dbNum"]) + rc.password = cf["password"] + rc.maxIdle, _ = strconv.Atoi(cf["maxIdle"]) + + if v, err := time.ParseDuration(cf["timeout"]); err == nil { + rc.timeout = v + } else { + rc.timeout = 180 * time.Second + } + + rc.connectInit() + + c := rc.p.Get() + defer c.Close() + + return c.Err() +} + +// connect to redis. +func (rc *Cache) connectInit() { + dialFunc := func() (c redis.Conn, err error) { + c, err = redis.Dial("tcp", rc.conninfo) + if err != nil { + return nil, err + } + + if rc.password != "" { + if _, err := c.Do("AUTH", rc.password); err != nil { + c.Close() + return nil, err + } + } + + _, selecterr := c.Do("SELECT", rc.dbNum) + if selecterr != nil { + c.Close() + return nil, selecterr + } + return + } + // initialize a new pool + rc.p = &redis.Pool{ + MaxIdle: rc.maxIdle, + IdleTimeout: rc.timeout, + Dial: dialFunc, + } +} + +func init() { + cache.Register("redis", NewRedisCache) +} diff --git a/pkg/cache/redis/redis_test.go b/pkg/cache/redis/redis_test.go new file mode 100644 index 00000000..60a19180 --- /dev/null +++ b/pkg/cache/redis/redis_test.go @@ -0,0 +1,144 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package redis + +import ( + "fmt" + "testing" + "time" + + "github.com/astaxie/beego/cache" + "github.com/gomodule/redigo/redis" + "github.com/stretchr/testify/assert" +) + +func TestRedisCache(t *testing.T) { + bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + time.Sleep(11 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v, _ := redis.String(bm.Get("astaxie"), err); v != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if v, _ := redis.String(vv[0], nil); v != "author" { + t.Error("GetMulti ERROR") + } + if v, _ := redis.String(vv[1], nil); v != "author1" { + t.Error("GetMulti ERROR") + } + + // test clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } +} + +func TestCache_Scan(t *testing.T) { + timeoutDuration := 10 * time.Second + // init + bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) + if err != nil { + t.Error("init err") + } + // insert all + for i := 0; i < 10000; i++ { + if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { + t.Error("set Error", err) + } + } + // scan all for the first time + keys, err := bm.(*Cache).Scan(DefaultKey + ":*") + if err != nil { + t.Error("scan Error", err) + } + + assert.Equal(t, 10000, len(keys), "scan all error") + + // clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } + + // scan all for the second time + keys, err = bm.(*Cache).Scan(DefaultKey + ":*") + if err != nil { + t.Error("scan Error", err) + } + if len(keys) != 0 { + t.Error("scan all err") + } +} diff --git a/pkg/cache/ssdb/ssdb.go b/pkg/cache/ssdb/ssdb.go new file mode 100644 index 00000000..fa2ce04b --- /dev/null +++ b/pkg/cache/ssdb/ssdb.go @@ -0,0 +1,231 @@ +package ssdb + +import ( + "encoding/json" + "errors" + "strconv" + "strings" + "time" + + "github.com/ssdb/gossdb/ssdb" + + "github.com/astaxie/beego/cache" +) + +// Cache SSDB adapter +type Cache struct { + conn *ssdb.Client + conninfo []string +} + +//NewSsdbCache create new ssdb adapter. +func NewSsdbCache() cache.Cache { + return &Cache{} +} + +// Get get value from memcache. +func (rc *Cache) Get(key string) interface{} { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return nil + } + } + value, err := rc.conn.Get(key) + if err == nil { + return value + } + return nil +} + +// GetMulti get value from memcache. +func (rc *Cache) GetMulti(keys []string) []interface{} { + size := len(keys) + var values []interface{} + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + for i := 0; i < size; i++ { + values = append(values, err) + } + return values + } + } + res, err := rc.conn.Do("multi_get", keys) + resSize := len(res) + if err == nil { + for i := 1; i < resSize; i += 2 { + values = append(values, res[i+1]) + } + return values + } + for i := 0; i < size; i++ { + values = append(values, err) + } + return values +} + +// DelMulti get value from memcache. +func (rc *Cache) DelMulti(keys []string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Do("multi_del", keys) + return err +} + +// Put put value to memcache. only support string. +func (rc *Cache) Put(key string, value interface{}, timeout time.Duration) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + v, ok := value.(string) + if !ok { + return errors.New("value must string") + } + var resp []string + var err error + ttl := int(timeout / time.Second) + if ttl < 0 { + resp, err = rc.conn.Do("set", key, v) + } else { + resp, err = rc.conn.Do("setx", key, v, ttl) + } + if err != nil { + return err + } + if len(resp) == 2 && resp[0] == "ok" { + return nil + } + return errors.New("bad response") +} + +// Delete delete value in memcache. +func (rc *Cache) Delete(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Del(key) + return err +} + +// Incr increase counter. +func (rc *Cache) Incr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Do("incr", key, 1) + return err +} + +// Decr decrease counter. +func (rc *Cache) Decr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Do("incr", key, -1) + return err +} + +// IsExist check value exists in memcache. +func (rc *Cache) IsExist(key string) bool { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return false + } + } + resp, err := rc.conn.Do("exists", key) + if err != nil { + return false + } + if len(resp) == 2 && resp[1] == "1" { + return true + } + return false + +} + +// ClearAll clear all cached in memcache. +func (rc *Cache) ClearAll() error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + keyStart, keyEnd, limit := "", "", 50 + resp, err := rc.Scan(keyStart, keyEnd, limit) + for err == nil { + size := len(resp) + if size == 1 { + return nil + } + keys := []string{} + for i := 1; i < size; i += 2 { + keys = append(keys, resp[i]) + } + _, e := rc.conn.Do("multi_del", keys) + if e != nil { + return e + } + keyStart = resp[size-2] + resp, err = rc.Scan(keyStart, keyEnd, limit) + } + return err +} + +// Scan key all cached in ssdb. +func (rc *Cache) Scan(keyStart string, keyEnd string, limit int) ([]string, error) { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return nil, err + } + } + resp, err := rc.conn.Do("scan", keyStart, keyEnd, limit) + if err != nil { + return nil, err + } + return resp, nil +} + +// StartAndGC start memcache adapter. +// config string is like {"conn":"connection info"}. +// if connecting error, return. +func (rc *Cache) StartAndGC(config string) error { + var cf map[string]string + json.Unmarshal([]byte(config), &cf) + if _, ok := cf["conn"]; !ok { + return errors.New("config has no conn key") + } + rc.conninfo = strings.Split(cf["conn"], ";") + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return nil +} + +// connect to memcache and keep the connection. +func (rc *Cache) connectInit() error { + conninfoArray := strings.Split(rc.conninfo[0], ":") + host := conninfoArray[0] + port, e := strconv.Atoi(conninfoArray[1]) + if e != nil { + return e + } + var err error + rc.conn, err = ssdb.Connect(host, port) + return err +} + +func init() { + cache.Register("ssdb", NewSsdbCache) +} diff --git a/pkg/cache/ssdb/ssdb_test.go b/pkg/cache/ssdb/ssdb_test.go new file mode 100644 index 00000000..dd474960 --- /dev/null +++ b/pkg/cache/ssdb/ssdb_test.go @@ -0,0 +1,104 @@ +package ssdb + +import ( + "strconv" + "testing" + "time" + + "github.com/astaxie/beego/cache" +) + +func TestSsdbcacheCache(t *testing.T) { + ssdb, err := cache.NewCache("ssdb", `{"conn": "127.0.0.1:8888"}`) + if err != nil { + t.Error("init err") + } + + // test put and exist + if ssdb.IsExist("ssdb") { + t.Error("check err") + } + timeoutDuration := 10 * time.Second + //timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent + if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb") { + t.Error("check err") + } + + // Get test done + if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v := ssdb.Get("ssdb"); v != "ssdb" { + t.Error("get Error") + } + + //inc/dec test done + if err = ssdb.Put("ssdb", "2", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if err = ssdb.Incr("ssdb"); err != nil { + t.Error("incr Error", err) + } + + if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { + t.Error("get err") + } + + if err = ssdb.Decr("ssdb"); err != nil { + t.Error("decr error") + } + + // test del + if err = ssdb.Put("ssdb", "3", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { + t.Error("get err") + } + if err := ssdb.Delete("ssdb"); err == nil { + if ssdb.IsExist("ssdb") { + t.Error("delete err") + } + } + + //test string + if err = ssdb.Put("ssdb", "ssdb", -10*time.Second); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb") { + t.Error("check err") + } + if v := ssdb.Get("ssdb").(string); v != "ssdb" { + t.Error("get err") + } + + //test GetMulti done + if err = ssdb.Put("ssdb1", "ssdb1", -10*time.Second); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb1") { + t.Error("check err") + } + vv := ssdb.GetMulti([]string{"ssdb", "ssdb1"}) + if len(vv) != 2 { + t.Error("getmulti error") + } + if vv[0].(string) != "ssdb" { + t.Error("getmulti error") + } + if vv[1].(string) != "ssdb1" { + t.Error("getmulti error") + } + + // test clear all done + if err = ssdb.ClearAll(); err != nil { + t.Error("clear all err") + } + if ssdb.IsExist("ssdb") || ssdb.IsExist("ssdb1") { + t.Error("check err") + } +} diff --git a/pkg/config.go b/pkg/config.go new file mode 100644 index 00000000..b6c9a99c --- /dev/null +++ b/pkg/config.go @@ -0,0 +1,524 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "runtime" + "strings" + + "github.com/astaxie/beego/config" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/session" + "github.com/astaxie/beego/utils" +) + +// Config is the main struct for BConfig +type Config struct { + AppName string //Application name + RunMode string //Running Mode: dev | prod + RouterCaseSensitive bool + ServerName string + RecoverPanic bool + RecoverFunc func(*context.Context) + CopyRequestBody bool + EnableGzip bool + MaxMemory int64 + EnableErrorsShow bool + EnableErrorsRender bool + Listen Listen + WebConfig WebConfig + Log LogConfig +} + +// Listen holds for http and https related config +type Listen struct { + Graceful bool // Graceful means use graceful module to start the server + ServerTimeOut int64 + ListenTCP4 bool + EnableHTTP bool + HTTPAddr string + HTTPPort int + AutoTLS bool + Domains []string + TLSCacheDir string + EnableHTTPS bool + EnableMutualHTTPS bool + HTTPSAddr string + HTTPSPort int + HTTPSCertFile string + HTTPSKeyFile string + TrustCaFile string + EnableAdmin bool + AdminAddr string + AdminPort int + EnableFcgi bool + EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O +} + +// WebConfig holds web related config +type WebConfig struct { + AutoRender bool + EnableDocs bool + FlashName string + FlashSeparator string + DirectoryIndex bool + StaticDir map[string]string + StaticExtensionsToGzip []string + StaticCacheFileSize int + StaticCacheFileNum int + TemplateLeft string + TemplateRight string + ViewsPath string + EnableXSRF bool + XSRFKey string + XSRFExpire int + Session SessionConfig +} + +// SessionConfig holds session related config +type SessionConfig struct { + SessionOn bool + SessionProvider string + SessionName string + SessionGCMaxLifetime int64 + SessionProviderConfig string + SessionCookieLifeTime int + SessionAutoSetCookie bool + SessionDomain string + SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies. + SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers + SessionNameInHTTPHeader string + SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params +} + +// LogConfig holds Log related config +type LogConfig struct { + AccessLogs bool + EnableStaticLogs bool //log static files requests default: false + AccessLogsFormat string //access log format: JSON_FORMAT, APACHE_FORMAT or empty string + FileLineNum bool + Outputs map[string]string // Store Adaptor : config +} + +var ( + // BConfig is the default config for Application + BConfig *Config + // AppConfig is the instance of Config, store the config information from file + AppConfig *beegoAppConfig + // AppPath is the absolute path to the app + AppPath string + // GlobalSessions is the instance for the session manager + GlobalSessions *session.Manager + + // appConfigPath is the path to the config files + appConfigPath string + // appConfigProvider is the provider for the config, default is ini + appConfigProvider = "ini" + // WorkPath is the absolute path to project root directory + WorkPath string +) + +func init() { + BConfig = newBConfig() + var err error + if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil { + panic(err) + } + WorkPath, err = os.Getwd() + if err != nil { + panic(err) + } + var filename = "app.conf" + if os.Getenv("BEEGO_RUNMODE") != "" { + filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf" + } + appConfigPath = filepath.Join(WorkPath, "conf", filename) + if !utils.FileExists(appConfigPath) { + appConfigPath = filepath.Join(AppPath, "conf", filename) + if !utils.FileExists(appConfigPath) { + AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()} + return + } + } + if err = parseConfig(appConfigPath); err != nil { + panic(err) + } +} + +func recoverPanic(ctx *context.Context) { + if err := recover(); err != nil { + if err == ErrAbort { + return + } + if !BConfig.RecoverPanic { + panic(err) + } + if BConfig.EnableErrorsShow { + if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { + exception(fmt.Sprint(err), ctx) + return + } + } + var stack string + logs.Critical("the request url is ", ctx.Input.URL()) + logs.Critical("Handler crashed with error", err) + for i := 1; ; i++ { + _, file, line, ok := runtime.Caller(i) + if !ok { + break + } + logs.Critical(fmt.Sprintf("%s:%d", file, line)) + stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) + } + if BConfig.RunMode == DEV && BConfig.EnableErrorsRender { + showErr(err, ctx, stack) + } + if ctx.Output.Status != 0 { + ctx.ResponseWriter.WriteHeader(ctx.Output.Status) + } else { + ctx.ResponseWriter.WriteHeader(500) + } + } +} + +func newBConfig() *Config { + return &Config{ + AppName: "beego", + RunMode: PROD, + RouterCaseSensitive: true, + ServerName: "beegoServer:" + VERSION, + RecoverPanic: true, + RecoverFunc: recoverPanic, + CopyRequestBody: false, + EnableGzip: false, + MaxMemory: 1 << 26, //64MB + EnableErrorsShow: true, + EnableErrorsRender: true, + Listen: Listen{ + Graceful: false, + ServerTimeOut: 0, + ListenTCP4: false, + EnableHTTP: true, + AutoTLS: false, + Domains: []string{}, + TLSCacheDir: ".", + HTTPAddr: "", + HTTPPort: 8080, + EnableHTTPS: false, + HTTPSAddr: "", + HTTPSPort: 10443, + HTTPSCertFile: "", + HTTPSKeyFile: "", + EnableAdmin: false, + AdminAddr: "", + AdminPort: 8088, + EnableFcgi: false, + EnableStdIo: false, + }, + WebConfig: WebConfig{ + AutoRender: true, + EnableDocs: false, + FlashName: "BEEGO_FLASH", + FlashSeparator: "BEEGOFLASH", + DirectoryIndex: false, + StaticDir: map[string]string{"/static": "static"}, + StaticExtensionsToGzip: []string{".css", ".js"}, + StaticCacheFileSize: 1024 * 100, + StaticCacheFileNum: 1000, + TemplateLeft: "{{", + TemplateRight: "}}", + ViewsPath: "views", + EnableXSRF: false, + XSRFKey: "beegoxsrf", + XSRFExpire: 0, + Session: SessionConfig{ + SessionOn: false, + SessionProvider: "memory", + SessionName: "beegosessionID", + SessionGCMaxLifetime: 3600, + SessionProviderConfig: "", + SessionDisableHTTPOnly: false, + SessionCookieLifeTime: 0, //set cookie default is the browser life + SessionAutoSetCookie: true, + SessionDomain: "", + SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers + SessionNameInHTTPHeader: "Beegosessionid", + SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params + }, + }, + Log: LogConfig{ + AccessLogs: false, + EnableStaticLogs: false, + AccessLogsFormat: "APACHE_FORMAT", + FileLineNum: true, + Outputs: map[string]string{"console": ""}, + }, + } +} + +// now only support ini, next will support json. +func parseConfig(appConfigPath string) (err error) { + AppConfig, err = newAppConfig(appConfigProvider, appConfigPath) + if err != nil { + return err + } + return assignConfig(AppConfig) +} + +func assignConfig(ac config.Configer) error { + for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} { + assignSingleConfig(i, ac) + } + // set the run mode first + if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { + BConfig.RunMode = envRunMode + } else if runMode := ac.String("RunMode"); runMode != "" { + BConfig.RunMode = runMode + } + + if sd := ac.String("StaticDir"); sd != "" { + BConfig.WebConfig.StaticDir = map[string]string{} + sds := strings.Fields(sd) + for _, v := range sds { + if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { + BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[1] + } else { + BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[0] + } + } + } + + if sgz := ac.String("StaticExtensionsToGzip"); sgz != "" { + extensions := strings.Split(sgz, ",") + fileExts := []string{} + for _, ext := range extensions { + ext = strings.TrimSpace(ext) + if ext == "" { + continue + } + if !strings.HasPrefix(ext, ".") { + ext = "." + ext + } + fileExts = append(fileExts, ext) + } + if len(fileExts) > 0 { + BConfig.WebConfig.StaticExtensionsToGzip = fileExts + } + } + + if sfs, err := ac.Int("StaticCacheFileSize"); err == nil { + BConfig.WebConfig.StaticCacheFileSize = sfs + } + + if sfn, err := ac.Int("StaticCacheFileNum"); err == nil { + BConfig.WebConfig.StaticCacheFileNum = sfn + } + + if lo := ac.String("LogOutputs"); lo != "" { + // if lo is not nil or empty + // means user has set his own LogOutputs + // clear the default setting to BConfig.Log.Outputs + BConfig.Log.Outputs = make(map[string]string) + los := strings.Split(lo, ";") + for _, v := range los { + if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 { + BConfig.Log.Outputs[logType2Config[0]] = logType2Config[1] + } else { + continue + } + } + } + + //init log + logs.Reset() + for adaptor, config := range BConfig.Log.Outputs { + err := logs.SetLogger(adaptor, config) + if err != nil { + fmt.Fprintln(os.Stderr, fmt.Sprintf("%s with the config %q got err:%s", adaptor, config, err.Error())) + } + } + logs.SetLogFuncCall(BConfig.Log.FileLineNum) + + return nil +} + +func assignSingleConfig(p interface{}, ac config.Configer) { + pt := reflect.TypeOf(p) + if pt.Kind() != reflect.Ptr { + return + } + pt = pt.Elem() + if pt.Kind() != reflect.Struct { + return + } + pv := reflect.ValueOf(p).Elem() + + for i := 0; i < pt.NumField(); i++ { + pf := pv.Field(i) + if !pf.CanSet() { + continue + } + name := pt.Field(i).Name + switch pf.Kind() { + case reflect.String: + pf.SetString(ac.DefaultString(name, pf.String())) + case reflect.Int, reflect.Int64: + pf.SetInt(ac.DefaultInt64(name, pf.Int())) + case reflect.Bool: + pf.SetBool(ac.DefaultBool(name, pf.Bool())) + case reflect.Struct: + default: + //do nothing here + } + } + +} + +// LoadAppConfig allow developer to apply a config file +func LoadAppConfig(adapterName, configPath string) error { + absConfigPath, err := filepath.Abs(configPath) + if err != nil { + return err + } + + if !utils.FileExists(absConfigPath) { + return fmt.Errorf("the target config file: %s don't exist", configPath) + } + + appConfigPath = absConfigPath + appConfigProvider = adapterName + + return parseConfig(appConfigPath) +} + +type beegoAppConfig struct { + innerConfig config.Configer +} + +func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, error) { + ac, err := config.NewConfig(appConfigProvider, appConfigPath) + if err != nil { + return nil, err + } + return &beegoAppConfig{ac}, nil +} + +func (b *beegoAppConfig) Set(key, val string) error { + if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { + return b.innerConfig.Set(key, val) + } + return nil +} + +func (b *beegoAppConfig) String(key string) string { + if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" { + return v + } + return b.innerConfig.String(key) +} + +func (b *beegoAppConfig) Strings(key string) []string { + if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 { + return v + } + return b.innerConfig.Strings(key) +} + +func (b *beegoAppConfig) Int(key string) (int, error) { + if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Int(key) +} + +func (b *beegoAppConfig) Int64(key string) (int64, error) { + if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Int64(key) +} + +func (b *beegoAppConfig) Bool(key string) (bool, error) { + if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Bool(key) +} + +func (b *beegoAppConfig) Float(key string) (float64, error) { + if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Float(key) +} + +func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { + if v := b.String(key); v != "" { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { + if v := b.Strings(key); len(v) != 0 { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { + if v, err := b.Int(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { + if v, err := b.Int64(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { + if v, err := b.Bool(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { + if v, err := b.Float(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DIY(key string) (interface{}, error) { + return b.innerConfig.DIY(key) +} + +func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { + return b.innerConfig.GetSection(section) +} + +func (b *beegoAppConfig) SaveConfigFile(filename string) error { + return b.innerConfig.SaveConfigFile(filename) +} diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 00000000..bfd79e85 --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,242 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package config is used to parse config. +// Usage: +// import "github.com/astaxie/beego/config" +//Examples. +// +// cnf, err := config.NewConfig("ini", "config.conf") +// +// cnf APIS: +// +// cnf.Set(key, val string) error +// cnf.String(key string) string +// cnf.Strings(key string) []string +// cnf.Int(key string) (int, error) +// cnf.Int64(key string) (int64, error) +// cnf.Bool(key string) (bool, error) +// cnf.Float(key string) (float64, error) +// cnf.DefaultString(key string, defaultVal string) string +// cnf.DefaultStrings(key string, defaultVal []string) []string +// cnf.DefaultInt(key string, defaultVal int) int +// cnf.DefaultInt64(key string, defaultVal int64) int64 +// cnf.DefaultBool(key string, defaultVal bool) bool +// cnf.DefaultFloat(key string, defaultVal float64) float64 +// cnf.DIY(key string) (interface{}, error) +// cnf.GetSection(section string) (map[string]string, error) +// cnf.SaveConfigFile(filename string) error +//More docs http://beego.me/docs/module/config.md +package config + +import ( + "fmt" + "os" + "reflect" + "time" +) + +// Configer defines how to get and set value from configuration raw data. +type Configer interface { + Set(key, val string) error //support section::key type in given key when using ini type. + String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + Strings(key string) []string //get string slice + Int(key string) (int, error) + Int64(key string) (int64, error) + Bool(key string) (bool, error) + Float(key string) (float64, error) + DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + DefaultStrings(key string, defaultVal []string) []string //get string slice + DefaultInt(key string, defaultVal int) int + DefaultInt64(key string, defaultVal int64) int64 + DefaultBool(key string, defaultVal bool) bool + DefaultFloat(key string, defaultVal float64) float64 + DIY(key string) (interface{}, error) + GetSection(section string) (map[string]string, error) + SaveConfigFile(filename string) error +} + +// Config is the adapter interface for parsing config file to get raw data to Configer. +type Config interface { + Parse(key string) (Configer, error) + ParseData(data []byte) (Configer, error) +} + +var adapters = make(map[string]Config) + +// Register makes a config adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Config) { + if adapter == nil { + panic("config: Register adapter is nil") + } + if _, ok := adapters[name]; ok { + panic("config: Register called twice for adapter " + name) + } + adapters[name] = adapter +} + +// NewConfig adapterName is ini/json/xml/yaml. +// filename is the config file path. +func NewConfig(adapterName, filename string) (Configer, error) { + adapter, ok := adapters[adapterName] + if !ok { + return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) + } + return adapter.Parse(filename) +} + +// NewConfigData adapterName is ini/json/xml/yaml. +// data is the config data. +func NewConfigData(adapterName string, data []byte) (Configer, error) { + adapter, ok := adapters[adapterName] + if !ok { + return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) + } + return adapter.ParseData(data) +} + +// ExpandValueEnvForMap convert all string value with environment variable. +func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} { + for k, v := range m { + switch value := v.(type) { + case string: + m[k] = ExpandValueEnv(value) + case map[string]interface{}: + m[k] = ExpandValueEnvForMap(value) + case map[string]string: + for k2, v2 := range value { + value[k2] = ExpandValueEnv(v2) + } + m[k] = value + } + } + return m +} + +// ExpandValueEnv returns value of convert with environment variable. +// +// Return environment variable if value start with "${" and end with "}". +// Return default value if environment variable is empty or not exist. +// +// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue". +// Examples: +// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable. +// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/". +// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie". +func ExpandValueEnv(value string) (realValue string) { + realValue = value + + vLen := len(value) + // 3 = ${} + if vLen < 3 { + return + } + // Need start with "${" and end with "}", then return. + if value[0] != '$' || value[1] != '{' || value[vLen-1] != '}' { + return + } + + key := "" + defaultV := "" + // value start with "${" + for i := 2; i < vLen; i++ { + if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') { + key = value[2:i] + defaultV = value[i+2 : vLen-1] // other string is default value. + break + } else if value[i] == '}' { + key = value[2:i] + break + } + } + + realValue = os.Getenv(key) + if realValue == "" { + realValue = defaultV + } + + return +} + +// ParseBool returns the boolean value represented by the string. +// +// It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On, +// 0, 0.0, f, F, FALSE, false, False, NO, no, No, N,n, OFF, off, Off. +// Any other value returns an error. +func ParseBool(val interface{}) (value bool, err error) { + if val != nil { + switch v := val.(type) { + case bool: + return v, nil + case string: + switch v { + case "1", "t", "T", "true", "TRUE", "True", "YES", "yes", "Yes", "Y", "y", "ON", "on", "On": + return true, nil + case "0", "f", "F", "false", "FALSE", "False", "NO", "no", "No", "N", "n", "OFF", "off", "Off": + return false, nil + } + case int8, int32, int64: + strV := fmt.Sprintf("%d", v) + if strV == "1" { + return true, nil + } else if strV == "0" { + return false, nil + } + case float64: + if v == 1.0 { + return true, nil + } else if v == 0.0 { + return false, nil + } + } + return false, fmt.Errorf("parsing %q: invalid syntax", val) + } + return false, fmt.Errorf("parsing : invalid syntax") +} + +// ToString converts values of any type to string. +func ToString(x interface{}) string { + switch y := x.(type) { + + // Handle dates with special logic + // This needs to come above the fmt.Stringer + // test since time.Time's have a .String() + // method + case time.Time: + return y.Format("A Monday") + + // Handle type string + case string: + return y + + // Handle type with .String() method + case fmt.Stringer: + return y.String() + + // Handle type with .Error() method + case error: + return y.Error() + + } + + // Handle named string type + if v := reflect.ValueOf(x); v.Kind() == reflect.String { + return v.String() + } + + // Fallback to fmt package for anything else like numeric types + return fmt.Sprint(x) +} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 00000000..15d6ffa6 --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,55 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "os" + "testing" +) + +func TestExpandValueEnv(t *testing.T) { + + testCases := []struct { + item string + want string + }{ + {"", ""}, + {"$", "$"}, + {"{", "{"}, + {"{}", "{}"}, + {"${}", ""}, + {"${|}", ""}, + {"${}", ""}, + {"${{}}", ""}, + {"${{||}}", "}"}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}}", "}"}, + {"${pwd||{{||}}}", "{{||}}"}, + {"${GOPATH}", os.Getenv("GOPATH")}, + {"${GOPATH||}", os.Getenv("GOPATH")}, + {"${GOPATH||root}", os.Getenv("GOPATH")}, + {"${GOPATH_NOT||root}", "root"}, + {"${GOPATH_NOT||||root}", "||root"}, + } + + for _, c := range testCases { + if got := ExpandValueEnv(c.item); got != c.want { + t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got) + } + } + +} diff --git a/pkg/config/env/env.go b/pkg/config/env/env.go new file mode 100644 index 00000000..34f094fe --- /dev/null +++ b/pkg/config/env/env.go @@ -0,0 +1,87 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package env is used to parse environment. +package env + +import ( + "fmt" + "os" + "strings" + + "github.com/astaxie/beego/utils" +) + +var env *utils.BeeMap + +func init() { + env = utils.NewBeeMap() + for _, e := range os.Environ() { + splits := strings.Split(e, "=") + env.Set(splits[0], os.Getenv(splits[0])) + } +} + +// Get returns a value by key. +// If the key does not exist, the default value will be returned. +func Get(key string, defVal string) string { + if val := env.Get(key); val != nil { + return val.(string) + } + return defVal +} + +// MustGet returns a value by key. +// If the key does not exist, it will return an error. +func MustGet(key string) (string, error) { + if val := env.Get(key); val != nil { + return val.(string), nil + } + return "", fmt.Errorf("no env variable with %s", key) +} + +// Set sets a value in the ENV copy. +// This does not affect the child process environment. +func Set(key string, value string) { + env.Set(key, value) +} + +// MustSet sets a value in the ENV copy and the child process environment. +// It returns an error in case the set operation failed. +func MustSet(key string, value string) error { + err := os.Setenv(key, value) + if err != nil { + return err + } + env.Set(key, value) + return nil +} + +// GetAll returns all keys/values in the current child process environment. +func GetAll() map[string]string { + items := env.Items() + envs := make(map[string]string, env.Count()) + + for key, val := range items { + switch key := key.(type) { + case string: + switch val := val.(type) { + case string: + envs[key] = val + } + } + } + return envs +} diff --git a/pkg/config/env/env_test.go b/pkg/config/env/env_test.go new file mode 100644 index 00000000..3f1d4dba --- /dev/null +++ b/pkg/config/env/env_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package env + +import ( + "os" + "testing" +) + +func TestEnvGet(t *testing.T) { + gopath := Get("GOPATH", "") + if gopath != os.Getenv("GOPATH") { + t.Error("expected GOPATH not empty.") + } + + noExistVar := Get("NOEXISTVAR", "foo") + if noExistVar != "foo" { + t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar) + } +} + +func TestEnvMustGet(t *testing.T) { + gopath, err := MustGet("GOPATH") + if err != nil { + t.Error(err) + } + + if gopath != os.Getenv("GOPATH") { + t.Errorf("expected GOPATH to be the same, got %s.", gopath) + } + + _, err = MustGet("NOEXISTVAR") + if err == nil { + t.Error("expected error to be non-nil") + } +} + +func TestEnvSet(t *testing.T) { + Set("MYVAR", "foo") + myVar := Get("MYVAR", "bar") + if myVar != "foo" { + t.Errorf("expected MYVAR to equal foo, got %s.", myVar) + } +} + +func TestEnvMustSet(t *testing.T) { + err := MustSet("FOO", "bar") + if err != nil { + t.Error(err) + } + + fooVar := os.Getenv("FOO") + if fooVar != "bar" { + t.Errorf("expected FOO variable to equal bar, got %s.", fooVar) + } +} + +func TestEnvGetAll(t *testing.T) { + envMap := GetAll() + if len(envMap) == 0 { + t.Error("expected environment not empty.") + } +} diff --git a/pkg/config/fake.go b/pkg/config/fake.go new file mode 100644 index 00000000..d21ab820 --- /dev/null +++ b/pkg/config/fake.go @@ -0,0 +1,134 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "errors" + "strconv" + "strings" +) + +type fakeConfigContainer struct { + data map[string]string +} + +func (c *fakeConfigContainer) getData(key string) string { + return c.data[strings.ToLower(key)] +} + +func (c *fakeConfigContainer) Set(key, val string) error { + c.data[strings.ToLower(key)] = val + return nil +} + +func (c *fakeConfigContainer) String(key string) string { + return c.getData(key) +} + +func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Int(key string) (int, error) { + return strconv.Atoi(c.getData(key)) +} + +func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Int64(key string) (int64, error) { + return strconv.ParseInt(c.getData(key), 10, 64) +} + +func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Bool(key string) (bool, error) { + return ParseBool(c.getData(key)) +} + +func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Float(key string) (float64, error) { + return strconv.ParseFloat(c.getData(key), 64) +} + +func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { + if v, ok := c.data[strings.ToLower(key)]; ok { + return v, nil + } + return nil, errors.New("key not find") +} + +func (c *fakeConfigContainer) GetSection(section string) (map[string]string, error) { + return nil, errors.New("not implement in the fakeConfigContainer") +} + +func (c *fakeConfigContainer) SaveConfigFile(filename string) error { + return errors.New("not implement in the fakeConfigContainer") +} + +var _ Configer = new(fakeConfigContainer) + +// NewFakeConfig return a fake Configer +func NewFakeConfig() Configer { + return &fakeConfigContainer{ + data: make(map[string]string), + } +} diff --git a/pkg/config/ini.go b/pkg/config/ini.go new file mode 100644 index 00000000..002e5e05 --- /dev/null +++ b/pkg/config/ini.go @@ -0,0 +1,504 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "bufio" + "bytes" + "errors" + "io" + "io/ioutil" + "os" + "os/user" + "path/filepath" + "strconv" + "strings" + "sync" +) + +var ( + defaultSection = "default" // default section means if some ini items not in a section, make them in default section, + bNumComment = []byte{'#'} // number signal + bSemComment = []byte{';'} // semicolon signal + bEmpty = []byte{} + bEqual = []byte{'='} // equal signal + bDQuote = []byte{'"'} // quote signal + sectionStart = []byte{'['} // section start signal + sectionEnd = []byte{']'} // section end signal + lineBreak = "\n" +) + +// IniConfig implements Config to parse ini file. +type IniConfig struct { +} + +// Parse creates a new Config and parses the file configuration from the named file. +func (ini *IniConfig) Parse(name string) (Configer, error) { + return ini.parseFile(name) +} + +func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { + data, err := ioutil.ReadFile(name) + if err != nil { + return nil, err + } + + return ini.parseData(filepath.Dir(name), data) +} + +func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, error) { + cfg := &IniConfigContainer{ + data: make(map[string]map[string]string), + sectionComment: make(map[string]string), + keyComment: make(map[string]string), + RWMutex: sync.RWMutex{}, + } + cfg.Lock() + defer cfg.Unlock() + + var comment bytes.Buffer + buf := bufio.NewReader(bytes.NewBuffer(data)) + // check the BOM + head, err := buf.Peek(3) + if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 { + for i := 1; i <= 3; i++ { + buf.ReadByte() + } + } + section := defaultSection + tmpBuf := bytes.NewBuffer(nil) + for { + tmpBuf.Reset() + + shouldBreak := false + for { + tmp, isPrefix, err := buf.ReadLine() + if err == io.EOF { + shouldBreak = true + break + } + + //It might be a good idea to throw a error on all unknonw errors? + if _, ok := err.(*os.PathError); ok { + return nil, err + } + + tmpBuf.Write(tmp) + if isPrefix { + continue + } + + if !isPrefix { + break + } + } + if shouldBreak { + break + } + + line := tmpBuf.Bytes() + line = bytes.TrimSpace(line) + if bytes.Equal(line, bEmpty) { + continue + } + var bComment []byte + switch { + case bytes.HasPrefix(line, bNumComment): + bComment = bNumComment + case bytes.HasPrefix(line, bSemComment): + bComment = bSemComment + } + if bComment != nil { + line = bytes.TrimLeft(line, string(bComment)) + // Need append to a new line if multi-line comments. + if comment.Len() > 0 { + comment.WriteByte('\n') + } + comment.Write(line) + continue + } + + if bytes.HasPrefix(line, sectionStart) && bytes.HasSuffix(line, sectionEnd) { + section = strings.ToLower(string(line[1 : len(line)-1])) // section name case insensitive + if comment.Len() > 0 { + cfg.sectionComment[section] = comment.String() + comment.Reset() + } + if _, ok := cfg.data[section]; !ok { + cfg.data[section] = make(map[string]string) + } + continue + } + + if _, ok := cfg.data[section]; !ok { + cfg.data[section] = make(map[string]string) + } + keyValue := bytes.SplitN(line, bEqual, 2) + + key := string(bytes.TrimSpace(keyValue[0])) // key name case insensitive + key = strings.ToLower(key) + + // handle include "other.conf" + if len(keyValue) == 1 && strings.HasPrefix(key, "include") { + + includefiles := strings.Fields(key) + if includefiles[0] == "include" && len(includefiles) == 2 { + + otherfile := strings.Trim(includefiles[1], "\"") + if !filepath.IsAbs(otherfile) { + otherfile = filepath.Join(dir, otherfile) + } + + i, err := ini.parseFile(otherfile) + if err != nil { + return nil, err + } + + for sec, dt := range i.data { + if _, ok := cfg.data[sec]; !ok { + cfg.data[sec] = make(map[string]string) + } + for k, v := range dt { + cfg.data[sec][k] = v + } + } + + for sec, comm := range i.sectionComment { + cfg.sectionComment[sec] = comm + } + + for k, comm := range i.keyComment { + cfg.keyComment[k] = comm + } + + continue + } + } + + if len(keyValue) != 2 { + return nil, errors.New("read the content error: \"" + string(line) + "\", should key = val") + } + val := bytes.TrimSpace(keyValue[1]) + if bytes.HasPrefix(val, bDQuote) { + val = bytes.Trim(val, `"`) + } + + cfg.data[section][key] = ExpandValueEnv(string(val)) + if comment.Len() > 0 { + cfg.keyComment[section+"."+key] = comment.String() + comment.Reset() + } + + } + return cfg, nil +} + +// ParseData parse ini the data +// When include other.conf,other.conf is either absolute directory +// or under beego in default temporary directory(/tmp/beego[-username]). +func (ini *IniConfig) ParseData(data []byte) (Configer, error) { + dir := "beego" + currentUser, err := user.Current() + if err == nil { + dir = "beego-" + currentUser.Username + } + dir = filepath.Join(os.TempDir(), dir) + if err = os.MkdirAll(dir, os.ModePerm); err != nil { + return nil, err + } + + return ini.parseData(dir, data) +} + +// IniConfigContainer A Config represents the ini configuration. +// When set and get value, support key as section:name type. +type IniConfigContainer struct { + data map[string]map[string]string // section=> key:val + sectionComment map[string]string // section : comment + keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment. + sync.RWMutex +} + +// Bool returns the boolean value for a given key. +func (c *IniConfigContainer) Bool(key string) (bool, error) { + return ParseBool(c.getdata(key)) +} + +// DefaultBool returns the boolean value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +// Int returns the integer value for a given key. +func (c *IniConfigContainer) Int(key string) (int, error) { + return strconv.Atoi(c.getdata(key)) +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +// Int64 returns the int64 value for a given key. +func (c *IniConfigContainer) Int64(key string) (int64, error) { + return strconv.ParseInt(c.getdata(key), 10, 64) +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v +} + +// Float returns the float value for a given key. +func (c *IniConfigContainer) Float(key string) (float64, error) { + return strconv.ParseFloat(c.getdata(key), 64) +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +// String returns the string value for a given key. +func (c *IniConfigContainer) String(key string) string { + return c.getdata(key) +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +// Strings returns the []string value for a given key. +// Return nil if config value does not exist or is empty. +func (c *IniConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +// GetSection returns map for the given section +func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) { + if v, ok := c.data[section]; ok { + return v, nil + } + return nil, errors.New("not exist section") +} + +// SaveConfigFile save the config into file. +// +// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Function. +func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + + // Get section or key comments. Fixed #1607 + getCommentStr := func(section, key string) string { + var ( + comment string + ok bool + ) + if len(key) == 0 { + comment, ok = c.sectionComment[section] + } else { + comment, ok = c.keyComment[section+"."+key] + } + + if ok { + // Empty comment + if len(comment) == 0 || len(strings.TrimSpace(comment)) == 0 { + return string(bNumComment) + } + prefix := string(bNumComment) + // Add the line head character "#" + return prefix + strings.Replace(comment, lineBreak, lineBreak+prefix, -1) + } + return "" + } + + buf := bytes.NewBuffer(nil) + // Save default section at first place + if dt, ok := c.data[defaultSection]; ok { + for key, val := range dt { + if key != " " { + // Write key comments. + if v := getCommentStr(defaultSection, key); len(v) > 0 { + if _, err = buf.WriteString(v + lineBreak); err != nil { + return err + } + } + + // Write key and value. + if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil { + return err + } + } + } + + // Put a line between sections. + if _, err = buf.WriteString(lineBreak); err != nil { + return err + } + } + // Save named sections + for section, dt := range c.data { + if section != defaultSection { + // Write section comments. + if v := getCommentStr(section, ""); len(v) > 0 { + if _, err = buf.WriteString(v + lineBreak); err != nil { + return err + } + } + + // Write section name. + if _, err = buf.WriteString(string(sectionStart) + section + string(sectionEnd) + lineBreak); err != nil { + return err + } + + for key, val := range dt { + if key != " " { + // Write key comments. + if v := getCommentStr(section, key); len(v) > 0 { + if _, err = buf.WriteString(v + lineBreak); err != nil { + return err + } + } + + // Write key and value. + if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil { + return err + } + } + } + + // Put a line between sections. + if _, err = buf.WriteString(lineBreak); err != nil { + return err + } + } + } + _, err = buf.WriteTo(f) + return err +} + +// Set writes a new value for key. +// if write to one section, the key need be "section::key". +// if the section is not existed, it panics. +func (c *IniConfigContainer) Set(key, value string) error { + c.Lock() + defer c.Unlock() + if len(key) == 0 { + return errors.New("key is empty") + } + + var ( + section, k string + sectionKey = strings.Split(strings.ToLower(key), "::") + ) + + if len(sectionKey) >= 2 { + section = sectionKey[0] + k = sectionKey[1] + } else { + section = defaultSection + k = sectionKey[0] + } + + if _, ok := c.data[section]; !ok { + c.data[section] = make(map[string]string) + } + c.data[section][k] = value + return nil +} + +// DIY returns the raw value by a given key. +func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) { + if v, ok := c.data[strings.ToLower(key)]; ok { + return v, nil + } + return v, errors.New("key not find") +} + +// section.key or key +func (c *IniConfigContainer) getdata(key string) string { + if len(key) == 0 { + return "" + } + c.RLock() + defer c.RUnlock() + + var ( + section, k string + sectionKey = strings.Split(strings.ToLower(key), "::") + ) + if len(sectionKey) >= 2 { + section = sectionKey[0] + k = sectionKey[1] + } else { + section = defaultSection + k = sectionKey[0] + } + if v, ok := c.data[section]; ok { + if vv, ok := v[k]; ok { + return vv + } + } + return "" +} + +func init() { + Register("ini", &IniConfig{}) +} diff --git a/pkg/config/ini_test.go b/pkg/config/ini_test.go new file mode 100644 index 00000000..ffcdb294 --- /dev/null +++ b/pkg/config/ini_test.go @@ -0,0 +1,190 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" +) + +func TestIni(t *testing.T) { + + var ( + inicontext = ` +;comment one +#comment two +appname = beeapi +httpport = 8080 +mysqlport = 3600 +PI = 3.1415976 +runmode = "dev" +autorender = false +copyrequestbody = true +session= on +cookieon= off +newreg = OFF +needlogin = ON +enableSession = Y +enableCookie = N +flag = 1 +path1 = ${GOPATH} +path2 = ${GOPATH||/home/go} +[demo] +key1="asta" +key2 = "xie" +CaseInsensitive = true +peers = one;two;three +password = ${GOPATH} +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "pi": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "demo::key1": "asta", + "demo::key2": "xie", + "demo::CaseInsensitive": true, + "demo::peers": []string{"one", "two", "three"}, + "demo::password": os.Getenv("GOPATH"), + "null": "", + "demo2::key1": "", + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testini.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(inicontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testini.conf") + iniconf, err := NewConfig("ini", "testini.conf") + if err != nil { + t.Fatal(err) + } + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = iniconf.Int(k) + case int64: + value, err = iniconf.Int64(k) + case float64: + value, err = iniconf.Float(k) + case bool: + value, err = iniconf.Bool(k) + case []string: + value = iniconf.Strings(k) + case string: + value = iniconf.String(k) + default: + value, err = iniconf.DIY(k) + } + if err != nil { + t.Fatalf("get key %q value fail,err %s", k, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = iniconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if iniconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + +} + +func TestIniSave(t *testing.T) { + + const ( + inicontext = ` +app = app +;comment one +#comment two +# comment three +appname = beeapi +httpport = 8080 +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name = mysql +` + + saveResult = ` +app=app +#comment one +#comment two +# comment three +appname=beeapi +httpport=8080 + +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name=mysql +` + ) + cfg, err := NewConfigData("ini", []byte(inicontext)) + if err != nil { + t.Fatal(err) + } + name := "newIniConfig.ini" + if err := cfg.SaveConfigFile(name); err != nil { + t.Fatal(err) + } + defer os.Remove(name) + + if data, err := ioutil.ReadFile(name); err != nil { + t.Fatal(err) + } else { + cfgData := string(data) + datas := strings.Split(saveResult, "\n") + for _, line := range datas { + if !strings.Contains(cfgData, line+"\n") { + t.Fatalf("different after save ini config file. need contains %q", line) + } + } + + } +} diff --git a/pkg/config/json.go b/pkg/config/json.go new file mode 100644 index 00000000..c4ef25cd --- /dev/null +++ b/pkg/config/json.go @@ -0,0 +1,269 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "os" + "strconv" + "strings" + "sync" +) + +// JSONConfig is a json config parser and implements Config interface. +type JSONConfig struct { +} + +// Parse returns a ConfigContainer with parsed json config map. +func (js *JSONConfig) Parse(filename string) (Configer, error) { + file, err := os.Open(filename) + if err != nil { + return nil, err + } + defer file.Close() + content, err := ioutil.ReadAll(file) + if err != nil { + return nil, err + } + + return js.ParseData(content) +} + +// ParseData returns a ConfigContainer with json string +func (js *JSONConfig) ParseData(data []byte) (Configer, error) { + x := &JSONConfigContainer{ + data: make(map[string]interface{}), + } + err := json.Unmarshal(data, &x.data) + if err != nil { + var wrappingArray []interface{} + err2 := json.Unmarshal(data, &wrappingArray) + if err2 != nil { + return nil, err + } + x.data["rootArray"] = wrappingArray + } + + x.data = ExpandValueEnvForMap(x.data) + + return x, nil +} + +// JSONConfigContainer A Config represents the json configuration. +// Only when get value, support key as section:name type. +type JSONConfigContainer struct { + data map[string]interface{} + sync.RWMutex +} + +// Bool returns the boolean value for a given key. +func (c *JSONConfigContainer) Bool(key string) (bool, error) { + val := c.getData(key) + if val != nil { + return ParseBool(val) + } + return false, fmt.Errorf("not exist key: %q", key) +} + +// DefaultBool return the bool value if has no error +// otherwise return the defaultval +func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool { + if v, err := c.Bool(key); err == nil { + return v + } + return defaultval +} + +// Int returns the integer value for a given key. +func (c *JSONConfigContainer) Int(key string) (int, error) { + val := c.getData(key) + if val != nil { + if v, ok := val.(float64); ok { + return int(v), nil + } else if v, ok := val.(string); ok { + return strconv.Atoi(v) + } + return 0, errors.New("not valid value") + } + return 0, errors.New("not exist key:" + key) +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int { + if v, err := c.Int(key); err == nil { + return v + } + return defaultval +} + +// Int64 returns the int64 value for a given key. +func (c *JSONConfigContainer) Int64(key string) (int64, error) { + val := c.getData(key) + if val != nil { + if v, ok := val.(float64); ok { + return int64(v), nil + } + return 0, errors.New("not int64 value") + } + return 0, errors.New("not exist key:" + key) +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + if v, err := c.Int64(key); err == nil { + return v + } + return defaultval +} + +// Float returns the float value for a given key. +func (c *JSONConfigContainer) Float(key string) (float64, error) { + val := c.getData(key) + if val != nil { + if v, ok := val.(float64); ok { + return v, nil + } + return 0.0, errors.New("not float64 value") + } + return 0.0, errors.New("not exist key:" + key) +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + if v, err := c.Float(key); err == nil { + return v + } + return defaultval +} + +// String returns the string value for a given key. +func (c *JSONConfigContainer) String(key string) string { + val := c.getData(key) + if val != nil { + if v, ok := val.(string); ok { + return v + } + } + return "" +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { + // TODO FIXME should not use "" to replace non existence + if v := c.String(key); v != "" { + return v + } + return defaultval +} + +// Strings returns the []string value for a given key. +func (c *JSONConfigContainer) Strings(key string) []string { + stringVal := c.String(key) + if stringVal == "" { + return nil + } + return strings.Split(c.String(key), ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { + if v := c.Strings(key); v != nil { + return v + } + return defaultval +} + +// GetSection returns map for the given section +func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) { + if v, ok := c.data[section]; ok { + return v.(map[string]string), nil + } + return nil, errors.New("nonexist section " + section) +} + +// SaveConfigFile save the config into file +func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + b, err := json.MarshalIndent(c.data, "", " ") + if err != nil { + return err + } + _, err = f.Write(b) + return err +} + +// Set writes a new value for key. +func (c *JSONConfigContainer) Set(key, val string) error { + c.Lock() + defer c.Unlock() + c.data[key] = val + return nil +} + +// DIY returns the raw value by a given key. +func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) { + val := c.getData(key) + if val != nil { + return val, nil + } + return nil, errors.New("not exist key") +} + +// section.key or key +func (c *JSONConfigContainer) getData(key string) interface{} { + if len(key) == 0 { + return nil + } + + c.RLock() + defer c.RUnlock() + + sectionKeys := strings.Split(key, "::") + if len(sectionKeys) >= 2 { + curValue, ok := c.data[sectionKeys[0]] + if !ok { + return nil + } + for _, key := range sectionKeys[1:] { + if v, ok := curValue.(map[string]interface{}); ok { + if curValue, ok = v[key]; !ok { + return nil + } + } + } + return curValue + } + if v, ok := c.data[key]; ok { + return v + } + return nil +} + +func init() { + Register("json", &JSONConfig{}) +} diff --git a/pkg/config/json_test.go b/pkg/config/json_test.go new file mode 100644 index 00000000..16f42409 --- /dev/null +++ b/pkg/config/json_test.go @@ -0,0 +1,222 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "os" + "testing" +) + +func TestJsonStartsWithArray(t *testing.T) { + + const jsoncontextwitharray = `[ + { + "url": "user", + "serviceAPI": "http://www.test.com/user" + }, + { + "url": "employee", + "serviceAPI": "http://www.test.com/employee" + } +]` + f, err := os.Create("testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontextwitharray) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjsonWithArray.conf") + jsonconf, err := NewConfig("json", "testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + rootArray, err := jsonconf.DIY("rootArray") + if err != nil { + t.Error("array does not exist as element") + } + rootArrayCasted := rootArray.([]interface{}) + if rootArrayCasted == nil { + t.Error("array from root is nil") + } else { + elem := rootArrayCasted[0].(map[string]interface{}) + if elem["url"] != "user" || elem["serviceAPI"] != "http://www.test.com/user" { + t.Error("array[0] values are not valid") + } + + elem2 := rootArrayCasted[1].(map[string]interface{}) + if elem2["url"] != "employee" || elem2["serviceAPI"] != "http://www.test.com/employee" { + t.Error("array[1] values are not valid") + } + } +} + +func TestJson(t *testing.T) { + + var ( + jsoncontext = `{ +"appname": "beeapi", +"testnames": "foo;bar", +"httpport": 8080, +"mysqlport": 3600, +"PI": 3.1415976, +"runmode": "dev", +"autorender": false, +"copyrequestbody": true, +"session": "on", +"cookieon": "off", +"newreg": "OFF", +"needlogin": "ON", +"enableSession": "Y", +"enableCookie": "N", +"flag": 1, +"path1": "${GOPATH}", +"path2": "${GOPATH||/home/go}", +"database": { + "host": "host", + "port": "port", + "database": "database", + "username": "username", + "password": "${GOPATH}", + "conns":{ + "maxconnection":12, + "autoconnect":true, + "connectioninfo":"info", + "root": "${GOPATH}" + } + } +}` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "testnames": []string{"foo", "bar"}, + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "database::host": "host", + "database::port": "port", + "database::database": "database", + "database::password": os.Getenv("GOPATH"), + "database::conns::maxconnection": 12, + "database::conns::autoconnect": true, + "database::conns::connectioninfo": "info", + "database::conns::root": os.Getenv("GOPATH"), + "unknown": "", + } + ) + + f, err := os.Create("testjson.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjson.conf") + jsonconf, err := NewConfig("json", "testjson.conf") + if err != nil { + t.Fatal(err) + } + + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = jsonconf.Int(k) + case int64: + value, err = jsonconf.Int64(k) + case float64: + value, err = jsonconf.Float(k) + case bool: + value, err = jsonconf.Bool(k) + case []string: + value = jsonconf.Strings(k) + case string: + value = jsonconf.String(k) + default: + value, err = jsonconf.DIY(k) + } + if err != nil { + t.Fatalf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = jsonconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if jsonconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + + if db, err := jsonconf.DIY("database"); err != nil { + t.Fatal(err) + } else if m, ok := db.(map[string]interface{}); !ok { + t.Log(db) + t.Fatal("db not map[string]interface{}") + } else { + if m["host"].(string) != "host" { + t.Fatal("get host err") + } + } + + if _, err := jsonconf.Int("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int") + } + + if _, err := jsonconf.Int64("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int64") + } + + if _, err := jsonconf.Float("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Float") + } + + if _, err := jsonconf.DIY("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an interface{}") + } + + if val := jsonconf.String("unknown"); val != "" { + t.Error("unknown keys should return an empty string when expecting a String") + } + + if _, err := jsonconf.Bool("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Bool") + } + + if !jsonconf.DefaultBool("unknown", true) { + t.Error("unknown keys with default value wrong") + } +} diff --git a/pkg/config/xml/xml.go b/pkg/config/xml/xml.go new file mode 100644 index 00000000..494242d3 --- /dev/null +++ b/pkg/config/xml/xml.go @@ -0,0 +1,228 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package xml for config provider. +// +// depend on github.com/beego/x2j. +// +// go install github.com/beego/x2j. +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/xml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("xml", "config.xml") +// +//More docs http://beego.me/docs/module/config.md +package xml + +import ( + "encoding/xml" + "errors" + "fmt" + "io/ioutil" + "os" + "strconv" + "strings" + "sync" + + "github.com/astaxie/beego/config" + "github.com/beego/x2j" +) + +// Config is a xml config parser and implements Config interface. +// xml configurations should be included in tag. +// only support key/value pair as value as each item. +type Config struct{} + +// Parse returns a ConfigContainer with parsed xml config map. +func (xc *Config) Parse(filename string) (config.Configer, error) { + context, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + + return xc.ParseData(context) +} + +// ParseData xml data +func (xc *Config) ParseData(data []byte) (config.Configer, error) { + x := &ConfigContainer{data: make(map[string]interface{})} + + d, err := x2j.DocToMap(string(data)) + if err != nil { + return nil, err + } + + x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{})) + + return x, nil +} + +// ConfigContainer A Config represents the xml configuration. +type ConfigContainer struct { + data map[string]interface{} + sync.Mutex +} + +// Bool returns the boolean value for a given key. +func (c *ConfigContainer) Bool(key string) (bool, error) { + if v := c.data[key]; v != nil { + return config.ParseBool(v) + } + return false, fmt.Errorf("not exist key: %q", key) +} + +// DefaultBool return the bool value if has no error +// otherwise return the defaultval +func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +// Int returns the integer value for a given key. +func (c *ConfigContainer) Int(key string) (int, error) { + return strconv.Atoi(c.data[key].(string)) +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +// Int64 returns the int64 value for a given key. +func (c *ConfigContainer) Int64(key string) (int64, error) { + return strconv.ParseInt(c.data[key].(string), 10, 64) +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v + +} + +// Float returns the float value for a given key. +func (c *ConfigContainer) Float(key string) (float64, error) { + return strconv.ParseFloat(c.data[key].(string), 64) +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +// String returns the string value for a given key. +func (c *ConfigContainer) String(key string) string { + if v, ok := c.data[key].(string); ok { + return v + } + return "" +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +// Strings returns the []string value for a given key. +func (c *ConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +// GetSection returns map for the given section +func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { + if v, ok := c.data[section].(map[string]interface{}); ok { + mapstr := make(map[string]string) + for k, val := range v { + mapstr[k] = config.ToString(val) + } + return mapstr, nil + } + return nil, fmt.Errorf("section '%s' not found", section) +} + +// SaveConfigFile save the config into file +func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + b, err := xml.MarshalIndent(c.data, " ", " ") + if err != nil { + return err + } + _, err = f.Write(b) + return err +} + +// Set writes a new value for key. +func (c *ConfigContainer) Set(key, val string) error { + c.Lock() + defer c.Unlock() + c.data[key] = val + return nil +} + +// DIY returns the raw value by a given key. +func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { + if v, ok := c.data[key]; ok { + return v, nil + } + return nil, errors.New("not exist key") +} + +func init() { + config.Register("xml", &Config{}) +} diff --git a/pkg/config/xml/xml_test.go b/pkg/config/xml/xml_test.go new file mode 100644 index 00000000..346c866e --- /dev/null +++ b/pkg/config/xml/xml_test.go @@ -0,0 +1,125 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xml + +import ( + "fmt" + "os" + "testing" + + "github.com/astaxie/beego/config" +) + +func TestXML(t *testing.T) { + + var ( + //xml parse should incluce in tags + xmlcontext = ` + +beeapi +8080 +3600 +3.1415976 +dev +false +true +${GOPATH} +${GOPATH||/home/go} + +1 +MySection + + +` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testxml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(xmlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testxml.conf") + + xmlconf, err := config.NewConfig("xml", "testxml.conf") + if err != nil { + t.Fatal(err) + } + + var xmlsection map[string]string + xmlsection, err = xmlconf.GetSection("mysection") + if err != nil { + t.Fatal(err) + } + + if len(xmlsection) == 0 { + t.Error("section should not be empty") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = xmlconf.Int(k) + case int64: + value, err = xmlconf.Int64(k) + case float64: + value, err = xmlconf.Float(k) + case bool: + value, err = xmlconf.Bool(k) + case []string: + value = xmlconf.Strings(k) + case string: + value = xmlconf.String(k) + default: + value, err = xmlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = xmlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if xmlconf.String("name") != "astaxie" { + t.Fatal("get name error") + } +} diff --git a/pkg/config/yaml/yaml.go b/pkg/config/yaml/yaml.go new file mode 100644 index 00000000..5def2da3 --- /dev/null +++ b/pkg/config/yaml/yaml.go @@ -0,0 +1,316 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package yaml for config provider +// +// depend on github.com/beego/goyaml2 +// +// go install github.com/beego/goyaml2 +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/yaml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("yaml", "config.yaml") +// +//More docs http://beego.me/docs/module/config.md +package yaml + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "log" + "os" + "strings" + "sync" + + "github.com/astaxie/beego/config" + "github.com/beego/goyaml2" +) + +// Config is a yaml config parser and implements Config interface. +type Config struct{} + +// Parse returns a ConfigContainer with parsed yaml config map. +func (yaml *Config) Parse(filename string) (y config.Configer, err error) { + cnf, err := ReadYmlReader(filename) + if err != nil { + return + } + y = &ConfigContainer{ + data: cnf, + } + return +} + +// ParseData parse yaml data +func (yaml *Config) ParseData(data []byte) (config.Configer, error) { + cnf, err := parseYML(data) + if err != nil { + return nil, err + } + + return &ConfigContainer{ + data: cnf, + }, nil +} + +// ReadYmlReader Read yaml file to map. +// if json like, use json package, unless goyaml2 package. +func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { + buf, err := ioutil.ReadFile(path) + if err != nil { + return + } + + return parseYML(buf) +} + +// parseYML parse yaml formatted []byte to map. +func parseYML(buf []byte) (cnf map[string]interface{}, err error) { + if len(buf) < 3 { + return + } + + if string(buf[0:1]) == "{" { + log.Println("Look like a Json, try json umarshal") + err = json.Unmarshal(buf, &cnf) + if err == nil { + log.Println("It is Json Map") + return + } + } + + data, err := goyaml2.Read(bytes.NewReader(buf)) + if err != nil { + log.Println("Goyaml2 ERR>", string(buf), err) + return + } + + if data == nil { + log.Println("Goyaml2 output nil? Pls report bug\n" + string(buf)) + return + } + cnf, ok := data.(map[string]interface{}) + if !ok { + log.Println("Not a Map? >> ", string(buf), data) + cnf = nil + } + cnf = config.ExpandValueEnvForMap(cnf) + return +} + +// ConfigContainer A Config represents the yaml configuration. +type ConfigContainer struct { + data map[string]interface{} + sync.RWMutex +} + +// Bool returns the boolean value for a given key. +func (c *ConfigContainer) Bool(key string) (bool, error) { + v, err := c.getData(key) + if err != nil { + return false, err + } + return config.ParseBool(v) +} + +// DefaultBool return the bool value if has no error +// otherwise return the defaultval +func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +// Int returns the integer value for a given key. +func (c *ConfigContainer) Int(key string) (int, error) { + if v, err := c.getData(key); err != nil { + return 0, err + } else if vv, ok := v.(int); ok { + return vv, nil + } else if vv, ok := v.(int64); ok { + return int(vv), nil + } + return 0, errors.New("not int value") +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +// Int64 returns the int64 value for a given key. +func (c *ConfigContainer) Int64(key string) (int64, error) { + if v, err := c.getData(key); err != nil { + return 0, err + } else if vv, ok := v.(int64); ok { + return vv, nil + } + return 0, errors.New("not bool value") +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v +} + +// Float returns the float value for a given key. +func (c *ConfigContainer) Float(key string) (float64, error) { + if v, err := c.getData(key); err != nil { + return 0.0, err + } else if vv, ok := v.(float64); ok { + return vv, nil + } else if vv, ok := v.(int); ok { + return float64(vv), nil + } else if vv, ok := v.(int64); ok { + return float64(vv), nil + } + return 0.0, errors.New("not float64 value") +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +// String returns the string value for a given key. +func (c *ConfigContainer) String(key string) string { + if v, err := c.getData(key); err == nil { + if vv, ok := v.(string); ok { + return vv + } + } + return "" +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +// Strings returns the []string value for a given key. +func (c *ConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +// GetSection returns map for the given section +func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { + + if v, ok := c.data[section]; ok { + return v.(map[string]string), nil + } + return nil, errors.New("not exist section") +} + +// SaveConfigFile save the config into file +func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + err = goyaml2.Write(f, c.data) + return err +} + +// Set writes a new value for key. +func (c *ConfigContainer) Set(key, val string) error { + c.Lock() + defer c.Unlock() + c.data[key] = val + return nil +} + +// DIY returns the raw value by a given key. +func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { + return c.getData(key) +} + +func (c *ConfigContainer) getData(key string) (interface{}, error) { + + if len(key) == 0 { + return nil, errors.New("key is empty") + } + c.RLock() + defer c.RUnlock() + + keys := strings.Split(key, ".") + tmpData := c.data + for idx, k := range keys { + if v, ok := tmpData[k]; ok { + switch v.(type) { + case map[string]interface{}: + { + tmpData = v.(map[string]interface{}) + if idx == len(keys) - 1 { + return tmpData, nil + } + } + default: + { + return v, nil + } + + } + } + } + return nil, fmt.Errorf("not exist key %q", key) +} + +func init() { + config.Register("yaml", &Config{}) +} diff --git a/pkg/config/yaml/yaml_test.go b/pkg/config/yaml/yaml_test.go new file mode 100644 index 00000000..49cc1d1e --- /dev/null +++ b/pkg/config/yaml/yaml_test.go @@ -0,0 +1,115 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package yaml + +import ( + "fmt" + "os" + "testing" + + "github.com/astaxie/beego/config" +) + +func TestYaml(t *testing.T) { + + var ( + yamlcontext = ` +"appname": beeapi +"httpport": 8080 +"mysqlport": 3600 +"PI": 3.1415976 +"runmode": dev +"autorender": false +"copyrequestbody": true +"PATH": GOPATH +"path1": ${GOPATH} +"path2": ${GOPATH||/home/go} +"empty": "" +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "PATH": "GOPATH", + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + f, err := os.Create("testyaml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(yamlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testyaml.conf") + yamlconf, err := config.NewConfig("yaml", "testyaml.conf") + if err != nil { + t.Fatal(err) + } + + if yamlconf.String("appname") != "beeapi" { + t.Fatal("appname not equal to beeapi") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = yamlconf.Int(k) + case int64: + value, err = yamlconf.Int64(k) + case float64: + value, err = yamlconf.Float(k) + case bool: + value, err = yamlconf.Bool(k) + case []string: + value = yamlconf.Strings(k) + case string: + value = yamlconf.String(k) + default: + value, err = yamlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = yamlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if yamlconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + +} diff --git a/pkg/config_test.go b/pkg/config_test.go new file mode 100644 index 00000000..5f71f1c3 --- /dev/null +++ b/pkg/config_test.go @@ -0,0 +1,146 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/astaxie/beego/config" +) + +func TestDefaults(t *testing.T) { + if BConfig.WebConfig.FlashName != "BEEGO_FLASH" { + t.Errorf("FlashName was not set to default.") + } + + if BConfig.WebConfig.FlashSeparator != "BEEGOFLASH" { + t.Errorf("FlashName was not set to default.") + } +} + +func TestAssignConfig_01(t *testing.T) { + _BConfig := &Config{} + _BConfig.AppName = "beego_test" + jcf := &config.JSONConfig{} + ac, _ := jcf.ParseData([]byte(`{"AppName":"beego_json"}`)) + assignSingleConfig(_BConfig, ac) + if _BConfig.AppName != "beego_json" { + t.Log(_BConfig) + t.FailNow() + } +} + +func TestAssignConfig_02(t *testing.T) { + _BConfig := &Config{} + bs, _ := json.Marshal(newBConfig()) + + jsonMap := M{} + json.Unmarshal(bs, &jsonMap) + + configMap := M{} + for k, v := range jsonMap { + if reflect.TypeOf(v).Kind() == reflect.Map { + for k1, v1 := range v.(M) { + if reflect.TypeOf(v1).Kind() == reflect.Map { + for k2, v2 := range v1.(M) { + configMap[k2] = v2 + } + } else { + configMap[k1] = v1 + } + } + } else { + configMap[k] = v + } + } + configMap["MaxMemory"] = 1024 + configMap["Graceful"] = true + configMap["XSRFExpire"] = 32 + configMap["SessionProviderConfig"] = "file" + configMap["FileLineNum"] = true + + jcf := &config.JSONConfig{} + bs, _ = json.Marshal(configMap) + ac, _ := jcf.ParseData(bs) + + for _, i := range []interface{}{_BConfig, &_BConfig.Listen, &_BConfig.WebConfig, &_BConfig.Log, &_BConfig.WebConfig.Session} { + assignSingleConfig(i, ac) + } + + if _BConfig.MaxMemory != 1024 { + t.Log(_BConfig.MaxMemory) + t.FailNow() + } + + if !_BConfig.Listen.Graceful { + t.Log(_BConfig.Listen.Graceful) + t.FailNow() + } + + if _BConfig.WebConfig.XSRFExpire != 32 { + t.Log(_BConfig.WebConfig.XSRFExpire) + t.FailNow() + } + + if _BConfig.WebConfig.Session.SessionProviderConfig != "file" { + t.Log(_BConfig.WebConfig.Session.SessionProviderConfig) + t.FailNow() + } + + if !_BConfig.Log.FileLineNum { + t.Log(_BConfig.Log.FileLineNum) + t.FailNow() + } + +} + +func TestAssignConfig_03(t *testing.T) { + jcf := &config.JSONConfig{} + ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`)) + ac.Set("AppName", "test_app") + ac.Set("RunMode", "online") + ac.Set("StaticDir", "download:down download2:down2") + ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") + ac.Set("StaticCacheFileSize", "87456") + ac.Set("StaticCacheFileNum", "1254") + assignConfig(ac) + + t.Logf("%#v", BConfig) + + if BConfig.AppName != "test_app" { + t.FailNow() + } + + if BConfig.RunMode != "online" { + t.FailNow() + } + if BConfig.WebConfig.StaticDir["/download"] != "down" { + t.FailNow() + } + if BConfig.WebConfig.StaticDir["/download2"] != "down2" { + t.FailNow() + } + if BConfig.WebConfig.StaticCacheFileSize != 87456 { + t.FailNow() + } + if BConfig.WebConfig.StaticCacheFileNum != 1254 { + t.FailNow() + } + if len(BConfig.WebConfig.StaticExtensionsToGzip) != 5 { + t.FailNow() + } +} diff --git a/pkg/context/acceptencoder.go b/pkg/context/acceptencoder.go new file mode 100644 index 00000000..b4e2492c --- /dev/null +++ b/pkg/context/acceptencoder.go @@ -0,0 +1,232 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "compress/zlib" + "io" + "net/http" + "os" + "strconv" + "strings" + "sync" +) + +var ( + //Default size==20B same as nginx + defaultGzipMinLength = 20 + //Content will only be compressed if content length is either unknown or greater than gzipMinLength. + gzipMinLength = defaultGzipMinLength + //The compression level used for deflate compression. (0-9). + gzipCompressLevel int + //List of HTTP methods to compress. If not set, only GET requests are compressed. + includedMethods map[string]bool + getMethodOnly bool +) + +// InitGzip init the gzipcompress +func InitGzip(minLength, compressLevel int, methods []string) { + if minLength >= 0 { + gzipMinLength = minLength + } + gzipCompressLevel = compressLevel + if gzipCompressLevel < flate.NoCompression || gzipCompressLevel > flate.BestCompression { + gzipCompressLevel = flate.BestSpeed + } + getMethodOnly = (len(methods) == 0) || (len(methods) == 1 && strings.ToUpper(methods[0]) == "GET") + includedMethods = make(map[string]bool, len(methods)) + for _, v := range methods { + includedMethods[strings.ToUpper(v)] = true + } +} + +type resetWriter interface { + io.Writer + Reset(w io.Writer) +} + +type nopResetWriter struct { + io.Writer +} + +func (n nopResetWriter) Reset(w io.Writer) { + //do nothing +} + +type acceptEncoder struct { + name string + levelEncode func(int) resetWriter + customCompressLevelPool *sync.Pool + bestCompressionPool *sync.Pool +} + +func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter { + if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { + return nopResetWriter{wr} + } + var rwr resetWriter + switch level { + case flate.BestSpeed: + rwr = ac.customCompressLevelPool.Get().(resetWriter) + case flate.BestCompression: + rwr = ac.bestCompressionPool.Get().(resetWriter) + default: + rwr = ac.levelEncode(level) + } + rwr.Reset(wr) + return rwr +} + +func (ac acceptEncoder) put(wr resetWriter, level int) { + if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { + return + } + wr.Reset(nil) + + //notice + //compressionLevel==BestCompression DOES NOT MATTER + //sync.Pool will not memory leak + + switch level { + case gzipCompressLevel: + ac.customCompressLevelPool.Put(wr) + case flate.BestCompression: + ac.bestCompressionPool.Put(wr) + } +} + +var ( + noneCompressEncoder = acceptEncoder{"", nil, nil, nil} + gzipCompressEncoder = acceptEncoder{ + name: "gzip", + levelEncode: func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr }, + customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, gzipCompressLevel); return wr }}, + bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }}, + } + + //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed + //deflate + //The "zlib" format defined in RFC 1950 [31] in combination with + //the "deflate" compression mechanism described in RFC 1951 [29]. + deflateCompressEncoder = acceptEncoder{ + name: "deflate", + levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, + customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, gzipCompressLevel); return wr }}, + bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }}, + } +) + +var ( + encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore + "gzip": gzipCompressEncoder, + "deflate": deflateCompressEncoder, + "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip + "identity": noneCompressEncoder, // identity means none-compress + } +) + +// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) +func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { + return writeLevel(encoding, writer, file, flate.BestCompression) +} + +// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) +func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { + if encoding == "" || len(content) < gzipMinLength { + _, err := writer.Write(content) + return false, "", err + } + return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel) +} + +// writeLevel reads from reader,writes to writer by specific encoding and compress level +// the compress level is defined by deflate package +func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) { + var outputWriter resetWriter + var err error + var ce = noneCompressEncoder + + if cf, ok := encoderMap[encoding]; ok { + ce = cf + } + encoding = ce.name + outputWriter = ce.encode(writer, level) + defer ce.put(outputWriter, level) + + _, err = io.Copy(outputWriter, reader) + if err != nil { + return false, "", err + } + + switch outputWriter.(type) { + case io.WriteCloser: + outputWriter.(io.WriteCloser).Close() + } + return encoding != "", encoding, nil +} + +// ParseEncoding will extract the right encoding for response +// the Accept-Encoding's sec is here: +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3 +func ParseEncoding(r *http.Request) string { + if r == nil { + return "" + } + if (getMethodOnly && r.Method == "GET") || includedMethods[r.Method] { + return parseEncoding(r) + } + return "" +} + +type q struct { + name string + value float64 +} + +func parseEncoding(r *http.Request) string { + acceptEncoding := r.Header.Get("Accept-Encoding") + if acceptEncoding == "" { + return "" + } + var lastQ q + for _, v := range strings.Split(acceptEncoding, ",") { + v = strings.TrimSpace(v) + if v == "" { + continue + } + vs := strings.Split(v, ";") + var cf acceptEncoder + var ok bool + if cf, ok = encoderMap[vs[0]]; !ok { + continue + } + if len(vs) == 1 { + return cf.name + } + if len(vs) == 2 { + f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64) + if f == 0 { + continue + } + if f > lastQ.value { + lastQ = q{cf.name, f} + } + } + } + return lastQ.name +} diff --git a/pkg/context/acceptencoder_test.go b/pkg/context/acceptencoder_test.go new file mode 100644 index 00000000..e3d61e27 --- /dev/null +++ b/pkg/context/acceptencoder_test.go @@ -0,0 +1,59 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "testing" +) + +func Test_ExtractEncoding(t *testing.T) { + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,deflate"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate,gzip"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate;q=0.3"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate;q=0.5,gzip;q=0.5,identity"}}}) != "" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"*"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x,gzip,deflate"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,x,deflate"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x"}}}) != "" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x;q=0.8"}}}) != "gzip" { + t.Fail() + } +} diff --git a/pkg/context/context.go b/pkg/context/context.go new file mode 100644 index 00000000..de248ed2 --- /dev/null +++ b/pkg/context/context.go @@ -0,0 +1,263 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package context provide the context utils +// Usage: +// +// import "github.com/astaxie/beego/context" +// +// ctx := context.Context{Request:req,ResponseWriter:rw} +// +// more docs http://beego.me/docs/module/context.md +package context + +import ( + "bufio" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "net" + "net/http" + "strconv" + "strings" + "time" + + "github.com/astaxie/beego/utils" +) + +//commonly used mime-types +const ( + ApplicationJSON = "application/json" + ApplicationXML = "application/xml" + ApplicationYAML = "application/x-yaml" + TextXML = "text/xml" +) + +// NewContext return the Context with Input and Output +func NewContext() *Context { + return &Context{ + Input: NewInput(), + Output: NewOutput(), + } +} + +// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. +// BeegoInput and BeegoOutput provides some api to operate request and response more easily. +type Context struct { + Input *BeegoInput + Output *BeegoOutput + Request *http.Request + ResponseWriter *Response + _xsrfToken string +} + +// Reset init Context, BeegoInput and BeegoOutput +func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { + ctx.Request = r + if ctx.ResponseWriter == nil { + ctx.ResponseWriter = &Response{} + } + ctx.ResponseWriter.reset(rw) + ctx.Input.Reset(ctx) + ctx.Output.Reset(ctx) + ctx._xsrfToken = "" +} + +// Redirect does redirection to localurl with http header status code. +func (ctx *Context) Redirect(status int, localurl string) { + http.Redirect(ctx.ResponseWriter, ctx.Request, localurl, status) +} + +// Abort stops this request. +// if beego.ErrorMaps exists, panic body. +func (ctx *Context) Abort(status int, body string) { + ctx.Output.SetStatus(status) + panic(body) +} + +// WriteString Write string to response body. +// it sends response body. +func (ctx *Context) WriteString(content string) { + ctx.ResponseWriter.Write([]byte(content)) +} + +// GetCookie Get cookie from request by a given key. +// It's alias of BeegoInput.Cookie. +func (ctx *Context) GetCookie(key string) string { + return ctx.Input.Cookie(key) +} + +// SetCookie Set cookie for response. +// It's alias of BeegoOutput.Cookie. +func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { + ctx.Output.Cookie(name, value, others...) +} + +// GetSecureCookie Get secure cookie from request by a given key. +func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { + val := ctx.Input.Cookie(key) + if val == "" { + return "", false + } + + parts := strings.SplitN(val, "|", 3) + + if len(parts) != 3 { + return "", false + } + + vs := parts[0] + timestamp := parts[1] + sig := parts[2] + + h := hmac.New(sha256.New, []byte(Secret)) + fmt.Fprintf(h, "%s%s", vs, timestamp) + + if fmt.Sprintf("%02x", h.Sum(nil)) != sig { + return "", false + } + res, _ := base64.URLEncoding.DecodeString(vs) + return string(res), true +} + +// SetSecureCookie Set Secure cookie for response. +func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { + vs := base64.URLEncoding.EncodeToString([]byte(value)) + timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) + h := hmac.New(sha256.New, []byte(Secret)) + fmt.Fprintf(h, "%s%s", vs, timestamp) + sig := fmt.Sprintf("%02x", h.Sum(nil)) + cookie := strings.Join([]string{vs, timestamp, sig}, "|") + ctx.Output.Cookie(name, cookie, others...) +} + +// XSRFToken creates a xsrf token string and returns. +func (ctx *Context) XSRFToken(key string, expire int64) string { + if ctx._xsrfToken == "" { + token, ok := ctx.GetSecureCookie(key, "_xsrf") + if !ok { + token = string(utils.RandomCreateBytes(32)) + ctx.SetSecureCookie(key, "_xsrf", token, expire) + } + ctx._xsrfToken = token + } + return ctx._xsrfToken +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (ctx *Context) CheckXSRFCookie() bool { + token := ctx.Input.Query("_xsrf") + if token == "" { + token = ctx.Request.Header.Get("X-Xsrftoken") + } + if token == "" { + token = ctx.Request.Header.Get("X-Csrftoken") + } + if token == "" { + ctx.Abort(422, "422") + return false + } + if ctx._xsrfToken != token { + ctx.Abort(417, "417") + return false + } + return true +} + +// RenderMethodResult renders the return value of a controller method to the output +func (ctx *Context) RenderMethodResult(result interface{}) { + if result != nil { + renderer, ok := result.(Renderer) + if !ok { + err, ok := result.(error) + if ok { + renderer = errorRenderer(err) + } else { + renderer = jsonRenderer(result) + } + } + renderer.Render(ctx) + } +} + +//Response is a wrapper for the http.ResponseWriter +//started set to true if response was written to then don't execute other handler +type Response struct { + http.ResponseWriter + Started bool + Status int + Elapsed time.Duration +} + +func (r *Response) reset(rw http.ResponseWriter) { + r.ResponseWriter = rw + r.Status = 0 + r.Started = false +} + +// Write writes the data to the connection as part of an HTTP reply, +// and sets `started` to true. +// started means the response has sent out. +func (r *Response) Write(p []byte) (int, error) { + r.Started = true + return r.ResponseWriter.Write(p) +} + +// WriteHeader sends an HTTP response header with status code, +// and sets `started` to true. +func (r *Response) WriteHeader(code int) { + if r.Status > 0 { + //prevent multiple response.WriteHeader calls + return + } + r.Status = code + r.Started = true + r.ResponseWriter.WriteHeader(code) +} + +// Hijack hijacker for http +func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj, ok := r.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("webserver doesn't support hijacking") + } + return hj.Hijack() +} + +// Flush http.Flusher +func (r *Response) Flush() { + if f, ok := r.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// CloseNotify http.CloseNotifier +func (r *Response) CloseNotify() <-chan bool { + if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { + return cn.CloseNotify() + } + return nil +} + +// Pusher http.Pusher +func (r *Response) Pusher() (pusher http.Pusher) { + if pusher, ok := r.ResponseWriter.(http.Pusher); ok { + return pusher + } + return nil +} diff --git a/pkg/context/context_test.go b/pkg/context/context_test.go new file mode 100644 index 00000000..7c0535e0 --- /dev/null +++ b/pkg/context/context_test.go @@ -0,0 +1,47 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestXsrfReset_01(t *testing.T) { + r := &http.Request{} + c := NewContext() + c.Request = r + c.ResponseWriter = &Response{} + c.ResponseWriter.reset(httptest.NewRecorder()) + c.Output.Reset(c) + c.Input.Reset(c) + c.XSRFToken("key", 16) + if c._xsrfToken == "" { + t.FailNow() + } + token := c._xsrfToken + c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r) + if c._xsrfToken != "" { + t.FailNow() + } + c.XSRFToken("key", 16) + if c._xsrfToken == "" { + t.FailNow() + } + if token == c._xsrfToken { + t.FailNow() + } +} diff --git a/pkg/context/input.go b/pkg/context/input.go new file mode 100644 index 00000000..385549c1 --- /dev/null +++ b/pkg/context/input.go @@ -0,0 +1,689 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "compress/gzip" + "errors" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "reflect" + "regexp" + "strconv" + "strings" + "sync" + + "github.com/astaxie/beego/session" +) + +// Regexes for checking the accept headers +// TODO make sure these are correct +var ( + acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`) + acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`) + acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`) + acceptsYAMLRegex = regexp.MustCompile(`(application/x-yaml)(?:,|$)`) + maxParam = 50 +) + +// BeegoInput operates the http request header, data, cookie and body. +// it also contains router params and current session. +type BeegoInput struct { + Context *Context + CruSession session.Store + pnames []string + pvalues []string + data map[interface{}]interface{} // store some values in this context when calling context in filter or controller. + dataLock sync.RWMutex + RequestBody []byte + RunMethod string + RunController reflect.Type +} + +// NewInput return BeegoInput generated by Context. +func NewInput() *BeegoInput { + return &BeegoInput{ + pnames: make([]string, 0, maxParam), + pvalues: make([]string, 0, maxParam), + data: make(map[interface{}]interface{}), + } +} + +// Reset init the BeegoInput +func (input *BeegoInput) Reset(ctx *Context) { + input.Context = ctx + input.CruSession = nil + input.pnames = input.pnames[:0] + input.pvalues = input.pvalues[:0] + input.dataLock.Lock() + input.data = nil + input.dataLock.Unlock() + input.RequestBody = []byte{} +} + +// Protocol returns request protocol name, such as HTTP/1.1 . +func (input *BeegoInput) Protocol() string { + return input.Context.Request.Proto +} + +// URI returns full request url with query string, fragment. +func (input *BeegoInput) URI() string { + return input.Context.Request.RequestURI +} + +// URL returns request url path (without query string, fragment). +func (input *BeegoInput) URL() string { + return input.Context.Request.URL.EscapedPath() +} + +// Site returns base site url as scheme://domain type. +func (input *BeegoInput) Site() string { + return input.Scheme() + "://" + input.Domain() +} + +// Scheme returns request scheme as "http" or "https". +func (input *BeegoInput) Scheme() string { + if scheme := input.Header("X-Forwarded-Proto"); scheme != "" { + return scheme + } + if input.Context.Request.URL.Scheme != "" { + return input.Context.Request.URL.Scheme + } + if input.Context.Request.TLS == nil { + return "http" + } + return "https" +} + +// Domain returns host name. +// Alias of Host method. +func (input *BeegoInput) Domain() string { + return input.Host() +} + +// Host returns host name. +// if no host info in request, return localhost. +func (input *BeegoInput) Host() string { + if input.Context.Request.Host != "" { + if hostPart, _, err := net.SplitHostPort(input.Context.Request.Host); err == nil { + return hostPart + } + return input.Context.Request.Host + } + return "localhost" +} + +// Method returns http request method. +func (input *BeegoInput) Method() string { + return input.Context.Request.Method +} + +// Is returns boolean of this request is on given method, such as Is("POST"). +func (input *BeegoInput) Is(method string) bool { + return input.Method() == method +} + +// IsGet Is this a GET method request? +func (input *BeegoInput) IsGet() bool { + return input.Is("GET") +} + +// IsPost Is this a POST method request? +func (input *BeegoInput) IsPost() bool { + return input.Is("POST") +} + +// IsHead Is this a Head method request? +func (input *BeegoInput) IsHead() bool { + return input.Is("HEAD") +} + +// IsOptions Is this a OPTIONS method request? +func (input *BeegoInput) IsOptions() bool { + return input.Is("OPTIONS") +} + +// IsPut Is this a PUT method request? +func (input *BeegoInput) IsPut() bool { + return input.Is("PUT") +} + +// IsDelete Is this a DELETE method request? +func (input *BeegoInput) IsDelete() bool { + return input.Is("DELETE") +} + +// IsPatch Is this a PATCH method request? +func (input *BeegoInput) IsPatch() bool { + return input.Is("PATCH") +} + +// IsAjax returns boolean of this request is generated by ajax. +func (input *BeegoInput) IsAjax() bool { + return input.Header("X-Requested-With") == "XMLHttpRequest" +} + +// IsSecure returns boolean of this request is in https. +func (input *BeegoInput) IsSecure() bool { + return input.Scheme() == "https" +} + +// IsWebsocket returns boolean of this request is in webSocket. +func (input *BeegoInput) IsWebsocket() bool { + return input.Header("Upgrade") == "websocket" +} + +// IsUpload returns boolean of whether file uploads in this request or not.. +func (input *BeegoInput) IsUpload() bool { + return strings.Contains(input.Header("Content-Type"), "multipart/form-data") +} + +// AcceptsHTML Checks if request accepts html response +func (input *BeegoInput) AcceptsHTML() bool { + return acceptsHTMLRegex.MatchString(input.Header("Accept")) +} + +// AcceptsXML Checks if request accepts xml response +func (input *BeegoInput) AcceptsXML() bool { + return acceptsXMLRegex.MatchString(input.Header("Accept")) +} + +// AcceptsJSON Checks if request accepts json response +func (input *BeegoInput) AcceptsJSON() bool { + return acceptsJSONRegex.MatchString(input.Header("Accept")) +} + +// AcceptsYAML Checks if request accepts json response +func (input *BeegoInput) AcceptsYAML() bool { + return acceptsYAMLRegex.MatchString(input.Header("Accept")) +} + +// IP returns request client ip. +// if in proxy, return first proxy id. +// if error, return RemoteAddr. +func (input *BeegoInput) IP() string { + ips := input.Proxy() + if len(ips) > 0 && ips[0] != "" { + rip, _, err := net.SplitHostPort(ips[0]) + if err != nil { + rip = ips[0] + } + return rip + } + if ip, _, err := net.SplitHostPort(input.Context.Request.RemoteAddr); err == nil { + return ip + } + return input.Context.Request.RemoteAddr +} + +// Proxy returns proxy client ips slice. +func (input *BeegoInput) Proxy() []string { + if ips := input.Header("X-Forwarded-For"); ips != "" { + return strings.Split(ips, ",") + } + return []string{} +} + +// Referer returns http referer header. +func (input *BeegoInput) Referer() string { + return input.Header("Referer") +} + +// Refer returns http referer header. +func (input *BeegoInput) Refer() string { + return input.Referer() +} + +// SubDomains returns sub domain string. +// if aa.bb.domain.com, returns aa.bb . +func (input *BeegoInput) SubDomains() string { + parts := strings.Split(input.Host(), ".") + if len(parts) >= 3 { + return strings.Join(parts[:len(parts)-2], ".") + } + return "" +} + +// Port returns request client port. +// when error or empty, return 80. +func (input *BeegoInput) Port() int { + if _, portPart, err := net.SplitHostPort(input.Context.Request.Host); err == nil { + port, _ := strconv.Atoi(portPart) + return port + } + return 80 +} + +// UserAgent returns request client user agent string. +func (input *BeegoInput) UserAgent() string { + return input.Header("User-Agent") +} + +// ParamsLen return the length of the params +func (input *BeegoInput) ParamsLen() int { + return len(input.pnames) +} + +// Param returns router param by a given key. +func (input *BeegoInput) Param(key string) string { + for i, v := range input.pnames { + if v == key && i <= len(input.pvalues) { + // we cannot use url.PathEscape(input.pvalues[i]) + // for example, if the value is /a/b + // after url.PathEscape(input.pvalues[i]), the value is %2Fa%2Fb + // However, the value is used in ControllerRegister.ServeHTTP + // and split by "/", so function crash... + return input.pvalues[i] + } + } + return "" +} + +// Params returns the map[key]value. +func (input *BeegoInput) Params() map[string]string { + m := make(map[string]string) + for i, v := range input.pnames { + if i <= len(input.pvalues) { + m[v] = input.pvalues[i] + } + } + return m +} + +// SetParam will set the param with key and value +func (input *BeegoInput) SetParam(key, val string) { + // check if already exists + for i, v := range input.pnames { + if v == key && i <= len(input.pvalues) { + input.pvalues[i] = val + return + } + } + input.pvalues = append(input.pvalues, val) + input.pnames = append(input.pnames, key) +} + +// ResetParams clears any of the input's Params +// This function is used to clear parameters so they may be reset between filter +// passes. +func (input *BeegoInput) ResetParams() { + input.pnames = input.pnames[:0] + input.pvalues = input.pvalues[:0] +} + +// Query returns input data item string by a given string. +func (input *BeegoInput) Query(key string) string { + if val := input.Param(key); val != "" { + return val + } + if input.Context.Request.Form == nil { + input.dataLock.Lock() + if input.Context.Request.Form == nil { + input.Context.Request.ParseForm() + } + input.dataLock.Unlock() + } + input.dataLock.RLock() + defer input.dataLock.RUnlock() + return input.Context.Request.Form.Get(key) +} + +// Header returns request header item string by a given string. +// if non-existed, return empty string. +func (input *BeegoInput) Header(key string) string { + return input.Context.Request.Header.Get(key) +} + +// Cookie returns request cookie item string by a given key. +// if non-existed, return empty string. +func (input *BeegoInput) Cookie(key string) string { + ck, err := input.Context.Request.Cookie(key) + if err != nil { + return "" + } + return ck.Value +} + +// Session returns current session item value by a given key. +// if non-existed, return nil. +func (input *BeegoInput) Session(key interface{}) interface{} { + return input.CruSession.Get(key) +} + +// CopyBody returns the raw request body data as bytes. +func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { + if input.Context.Request.Body == nil { + return []byte{} + } + + var requestbody []byte + safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory} + if input.Header("Content-Encoding") == "gzip" { + reader, err := gzip.NewReader(safe) + if err != nil { + return nil + } + requestbody, _ = ioutil.ReadAll(reader) + } else { + requestbody, _ = ioutil.ReadAll(safe) + } + + input.Context.Request.Body.Close() + bf := bytes.NewBuffer(requestbody) + input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, ioutil.NopCloser(bf), MaxMemory) + input.RequestBody = requestbody + return requestbody +} + +// Data return the implicit data in the input +func (input *BeegoInput) Data() map[interface{}]interface{} { + input.dataLock.Lock() + defer input.dataLock.Unlock() + if input.data == nil { + input.data = make(map[interface{}]interface{}) + } + return input.data +} + +// GetData returns the stored data in this context. +func (input *BeegoInput) GetData(key interface{}) interface{} { + input.dataLock.Lock() + defer input.dataLock.Unlock() + if v, ok := input.data[key]; ok { + return v + } + return nil +} + +// SetData stores data with given key in this context. +// This data are only available in this context. +func (input *BeegoInput) SetData(key, val interface{}) { + input.dataLock.Lock() + defer input.dataLock.Unlock() + if input.data == nil { + input.data = make(map[interface{}]interface{}) + } + input.data[key] = val +} + +// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type +func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { + // Parse the body depending on the content type. + if strings.Contains(input.Header("Content-Type"), "multipart/form-data") { + if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil { + return errors.New("Error parsing request body:" + err.Error()) + } + } else if err := input.Context.Request.ParseForm(); err != nil { + return errors.New("Error parsing request body:" + err.Error()) + } + return nil +} + +// Bind data from request.Form[key] to dest +// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie +// var id int beegoInput.Bind(&id, "id") id ==123 +// var isok bool beegoInput.Bind(&isok, "isok") isok ==true +// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2 +// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2] +// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array] +// user struct{Name} beegoInput.Bind(&user, "user") user == {Name:"astaxie"} +func (input *BeegoInput) Bind(dest interface{}, key string) error { + value := reflect.ValueOf(dest) + if value.Kind() != reflect.Ptr { + return errors.New("beego: non-pointer passed to Bind: " + key) + } + value = value.Elem() + if !value.CanSet() { + return errors.New("beego: non-settable variable passed to Bind: " + key) + } + typ := value.Type() + // Get real type if dest define with interface{}. + // e.g var dest interface{} dest=1.0 + if value.Kind() == reflect.Interface { + typ = value.Elem().Type() + } + rv := input.bind(key, typ) + if !rv.IsValid() { + return errors.New("beego: reflect value is empty") + } + value.Set(rv) + return nil +} + +func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value { + if input.Context.Request.Form == nil { + input.Context.Request.ParseForm() + } + rv := reflect.Zero(typ) + switch typ.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindInt(val, typ) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindUint(val, typ) + case reflect.Float32, reflect.Float64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindFloat(val, typ) + case reflect.String: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindString(val, typ) + case reflect.Bool: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindBool(val, typ) + case reflect.Slice: + rv = input.bindSlice(&input.Context.Request.Form, key, typ) + case reflect.Struct: + rv = input.bindStruct(&input.Context.Request.Form, key, typ) + case reflect.Ptr: + rv = input.bindPoint(key, typ) + case reflect.Map: + rv = input.bindMap(&input.Context.Request.Form, key, typ) + } + return rv +} + +func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value { + rv := reflect.Zero(typ) + switch typ.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + rv = input.bindInt(val, typ) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + rv = input.bindUint(val, typ) + case reflect.Float32, reflect.Float64: + rv = input.bindFloat(val, typ) + case reflect.String: + rv = input.bindString(val, typ) + case reflect.Bool: + rv = input.bindBool(val, typ) + case reflect.Slice: + rv = input.bindSlice(&url.Values{"": {val}}, "", typ) + case reflect.Struct: + rv = input.bindStruct(&url.Values{"": {val}}, "", typ) + case reflect.Ptr: + rv = input.bindPoint(val, typ) + case reflect.Map: + rv = input.bindMap(&url.Values{"": {val}}, "", typ) + } + return rv +} + +func (input *BeegoInput) bindInt(val string, typ reflect.Type) reflect.Value { + intValue, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetInt(intValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindUint(val string, typ reflect.Type) reflect.Value { + uintValue, err := strconv.ParseUint(val, 10, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetUint(uintValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindFloat(val string, typ reflect.Type) reflect.Value { + floatValue, err := strconv.ParseFloat(val, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetFloat(floatValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindString(val string, typ reflect.Type) reflect.Value { + return reflect.ValueOf(val) +} + +func (input *BeegoInput) bindBool(val string, typ reflect.Type) reflect.Value { + val = strings.TrimSpace(strings.ToLower(val)) + switch val { + case "true", "on", "1": + return reflect.ValueOf(true) + } + return reflect.ValueOf(false) +} + +type sliceValue struct { + index int // Index extracted from brackets. If -1, no index was provided. + value reflect.Value // the bound value for this slice element. +} + +func (input *BeegoInput) bindSlice(params *url.Values, key string, typ reflect.Type) reflect.Value { + maxIndex := -1 + numNoIndex := 0 + sliceValues := []sliceValue{} + for reqKey, vals := range *params { + if !strings.HasPrefix(reqKey, key+"[") { + continue + } + // Extract the index, and the index where a sub-key starts. (e.g. field[0].subkey) + index := -1 + leftBracket, rightBracket := len(key), strings.Index(reqKey[len(key):], "]")+len(key) + if rightBracket > leftBracket+1 { + index, _ = strconv.Atoi(reqKey[leftBracket+1 : rightBracket]) + } + subKeyIndex := rightBracket + 1 + + // Handle the indexed case. + if index > -1 { + if index > maxIndex { + maxIndex = index + } + sliceValues = append(sliceValues, sliceValue{ + index: index, + value: input.bind(reqKey[:subKeyIndex], typ.Elem()), + }) + continue + } + + // It's an un-indexed element. (e.g. element[]) + numNoIndex += len(vals) + for _, val := range vals { + // Unindexed values can only be direct-bound. + sliceValues = append(sliceValues, sliceValue{ + index: -1, + value: input.bindValue(val, typ.Elem()), + }) + } + } + resultArray := reflect.MakeSlice(typ, maxIndex+1, maxIndex+1+numNoIndex) + for _, sv := range sliceValues { + if sv.index != -1 { + resultArray.Index(sv.index).Set(sv.value) + } else { + resultArray = reflect.Append(resultArray, sv.value) + } + } + return resultArray +} + +func (input *BeegoInput) bindStruct(params *url.Values, key string, typ reflect.Type) reflect.Value { + result := reflect.New(typ).Elem() + fieldValues := make(map[string]reflect.Value) + for reqKey, val := range *params { + var fieldName string + if strings.HasPrefix(reqKey, key+".") { + fieldName = reqKey[len(key)+1:] + } else if strings.HasPrefix(reqKey, key+"[") && reqKey[len(reqKey)-1] == ']' { + fieldName = reqKey[len(key)+1 : len(reqKey)-1] + } else { + continue + } + + if _, ok := fieldValues[fieldName]; !ok { + // Time to bind this field. Get it and make sure we can set it. + fieldValue := result.FieldByName(fieldName) + if !fieldValue.IsValid() { + continue + } + if !fieldValue.CanSet() { + continue + } + boundVal := input.bindValue(val[0], fieldValue.Type()) + fieldValue.Set(boundVal) + fieldValues[fieldName] = boundVal + } + } + + return result +} + +func (input *BeegoInput) bindPoint(key string, typ reflect.Type) reflect.Value { + return input.bind(key, typ.Elem()).Addr() +} + +func (input *BeegoInput) bindMap(params *url.Values, key string, typ reflect.Type) reflect.Value { + var ( + result = reflect.MakeMap(typ) + keyType = typ.Key() + valueType = typ.Elem() + ) + for paramName, values := range *params { + if !strings.HasPrefix(paramName, key+"[") || paramName[len(paramName)-1] != ']' { + continue + } + + key := paramName[len(key)+1 : len(paramName)-1] + result.SetMapIndex(input.bindValue(key, keyType), input.bindValue(values[0], valueType)) + } + return result +} diff --git a/pkg/context/input_test.go b/pkg/context/input_test.go new file mode 100644 index 00000000..3a6c2e7b --- /dev/null +++ b/pkg/context/input_test.go @@ -0,0 +1,217 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func TestBind(t *testing.T) { + type testItem struct { + field string + empty interface{} + want interface{} + } + type Human struct { + ID int + Nick string + Pwd string + Ms bool + } + + cases := []struct { + request string + valueGp []testItem + }{ + {"/?p=str", []testItem{{"p", interface{}(""), interface{}("str")}}}, + + {"/?p=", []testItem{{"p", "", ""}}}, + {"/?p=str", []testItem{{"p", "", "str"}}}, + + {"/?p=123", []testItem{{"p", 0, 123}}}, + {"/?p=123", []testItem{{"p", uint(0), uint(123)}}}, + + {"/?p=1.0", []testItem{{"p", 0.0, 1.0}}}, + {"/?p=1", []testItem{{"p", false, true}}}, + + {"/?p=true", []testItem{{"p", false, true}}}, + {"/?p=ON", []testItem{{"p", false, true}}}, + {"/?p=on", []testItem{{"p", false, true}}}, + {"/?p=1", []testItem{{"p", false, true}}}, + {"/?p=2", []testItem{{"p", false, false}}}, + {"/?p=false", []testItem{{"p", false, false}}}, + + {"/?p[a]=1&p[b]=2&p[c]=3", []testItem{{"p", map[string]int{}, map[string]int{"a": 1, "b": 2, "c": 3}}}}, + {"/?p[a]=v1&p[b]=v2&p[c]=v3", []testItem{{"p", map[string]string{}, map[string]string{"a": "v1", "b": "v2", "c": "v3"}}}}, + + {"/?p[]=8&p[]=9&p[]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10&p[5]=14", []testItem{{"p", []int{}, []int{8, 9, 10, 0, 0, 14}}}}, + {"/?p[0]=8.0&p[1]=9.0&p[2]=10.0", []testItem{{"p", []float64{}, []float64{8.0, 9.0, 10.0}}}}, + + {"/?p[]=10&p[]=9&p[]=8", []testItem{{"p", []string{}, []string{"10", "9", "8"}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []string{}, []string{"8", "9", "10"}}}}, + + {"/?p[0]=true&p[1]=false&p[2]=true&p[5]=1&p[6]=ON&p[7]=other", []testItem{{"p", []bool{}, []bool{true, false, true, false, false, true, true, false}}}}, + + {"/?human.Nick=astaxie", []testItem{{"human", Human{}, Human{Nick: "astaxie"}}}}, + {"/?human.ID=888&human.Nick=astaxie&human.Ms=true&human[Pwd]=pass", []testItem{{"human", Human{}, Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass"}}}}, + {"/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02", + []testItem{{"human", []Human{}, []Human{ + {ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"}, + {ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"}, + }}}}, + + { + "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&human.Nick=astaxie", + []testItem{ + {"id", 0, 123}, + {"isok", false, true}, + {"ft", 0.0, 1.2}, + {"ol", []int{}, []int{1, 2}}, + {"ul", []string{}, []string{"str", "array"}}, + {"human", Human{}, Human{Nick: "astaxie"}}, + }, + }, + } + for _, c := range cases { + r, _ := http.NewRequest("GET", c.request, nil) + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Reset(httptest.NewRecorder(), r) + + for _, item := range c.valueGp { + got := item.empty + err := beegoInput.Bind(&got, item.field) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, item.want) { + t.Fatalf("Bind %q error,should be:\n%#v \ngot:\n%#v", item.field, item.want, got) + } + } + + } +} + +func TestSubDomain(t *testing.T) { + r, _ := http.NewRequest("GET", "http://www.example.com/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil) + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Reset(httptest.NewRecorder(), r) + + subdomain := beegoInput.SubDomains() + if subdomain != "www" { + t.Fatal("Subdomain parse error, got" + subdomain) + } + + r, _ = http.NewRequest("GET", "http://localhost/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "" { + t.Fatal("Subdomain parse error, should be empty, got " + beegoInput.SubDomains()) + } + + r, _ = http.NewRequest("GET", "http://aa.bb.example.com/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "aa.bb" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } + + /* TODO Fix this + r, _ = http.NewRequest("GET", "http://127.0.0.1/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } + */ + + r, _ = http.NewRequest("GET", "http://example.com/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } + + r, _ = http.NewRequest("GET", "http://aa.bb.cc.dd.example.com/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "aa.bb.cc.dd" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } +} + +func TestParams(t *testing.T) { + inp := NewInput() + + inp.SetParam("p1", "val1_ver1") + inp.SetParam("p2", "val2_ver1") + inp.SetParam("p3", "val3_ver1") + if l := inp.ParamsLen(); l != 3 { + t.Fatalf("Input.ParamsLen wrong value: %d, expected %d", l, 3) + } + + if val := inp.Param("p1"); val != "val1_ver1" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver1") + } + if val := inp.Param("p3"); val != "val3_ver1" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val3_ver1") + } + vals := inp.Params() + expected := map[string]string{ + "p1": "val1_ver1", + "p2": "val2_ver1", + "p3": "val3_ver1", + } + if !reflect.DeepEqual(vals, expected) { + t.Fatalf("Input.Params wrong value: %s, expected %s", vals, expected) + } + + // overwriting existing params + inp.SetParam("p1", "val1_ver2") + inp.SetParam("p2", "val2_ver2") + expected = map[string]string{ + "p1": "val1_ver2", + "p2": "val2_ver2", + "p3": "val3_ver1", + } + vals = inp.Params() + if !reflect.DeepEqual(vals, expected) { + t.Fatalf("Input.Params wrong value: %s, expected %s", vals, expected) + } + + if l := inp.ParamsLen(); l != 3 { + t.Fatalf("Input.ParamsLen wrong value: %d, expected %d", l, 3) + } + + if val := inp.Param("p1"); val != "val1_ver2" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2") + } + + if val := inp.Param("p2"); val != "val2_ver2" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2") + } + +} +func BenchmarkQuery(b *testing.B) { + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Request, _ = http.NewRequest("POST", "http://www.example.com/?q=foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + beegoInput.Query("q") + } + }) +} diff --git a/pkg/context/output.go b/pkg/context/output.go new file mode 100644 index 00000000..238dcf45 --- /dev/null +++ b/pkg/context/output.go @@ -0,0 +1,408 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "encoding/json" + "encoding/xml" + "errors" + "fmt" + "html/template" + "io" + "mime" + "net/http" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + yaml "gopkg.in/yaml.v2" +) + +// BeegoOutput does work for sending response header. +type BeegoOutput struct { + Context *Context + Status int + EnableGzip bool +} + +// NewOutput returns new BeegoOutput. +// it contains nothing now. +func NewOutput() *BeegoOutput { + return &BeegoOutput{} +} + +// Reset init BeegoOutput +func (output *BeegoOutput) Reset(ctx *Context) { + output.Context = ctx + output.Status = 0 +} + +// Header sets response header item string via given key. +func (output *BeegoOutput) Header(key, val string) { + output.Context.ResponseWriter.Header().Set(key, val) +} + +// Body sets response body content. +// if EnableGzip, compress content string. +// it sends out response body directly. +func (output *BeegoOutput) Body(content []byte) error { + var encoding string + var buf = &bytes.Buffer{} + if output.EnableGzip { + encoding = ParseEncoding(output.Context.Request) + } + if b, n, _ := WriteBody(encoding, buf, content); b { + output.Header("Content-Encoding", n) + output.Header("Content-Length", strconv.Itoa(buf.Len())) + } else { + output.Header("Content-Length", strconv.Itoa(len(content))) + } + // Write status code if it has been set manually + // Set it to 0 afterwards to prevent "multiple response.WriteHeader calls" + if output.Status != 0 { + output.Context.ResponseWriter.WriteHeader(output.Status) + output.Status = 0 + } else { + output.Context.ResponseWriter.Started = true + } + io.Copy(output.Context.ResponseWriter, buf) + return nil +} + +// Cookie sets cookie value via given key. +// others are ordered as cookie's max age time, path,domain, secure and httponly. +func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) { + var b bytes.Buffer + fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value)) + + //fix cookie not work in IE + if len(others) > 0 { + var maxAge int64 + + switch v := others[0].(type) { + case int: + maxAge = int64(v) + case int32: + maxAge = int64(v) + case int64: + maxAge = v + } + + switch { + case maxAge > 0: + fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(maxAge)*time.Second).UTC().Format(time.RFC1123), maxAge) + case maxAge < 0: + fmt.Fprintf(&b, "; Max-Age=0") + } + } + + // the settings below + // Path, Domain, Secure, HttpOnly + // can use nil skip set + + // default "/" + if len(others) > 1 { + if v, ok := others[1].(string); ok && len(v) > 0 { + fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v)) + } + } else { + fmt.Fprintf(&b, "; Path=%s", "/") + } + + // default empty + if len(others) > 2 { + if v, ok := others[2].(string); ok && len(v) > 0 { + fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(v)) + } + } + + // default empty + if len(others) > 3 { + var secure bool + switch v := others[3].(type) { + case bool: + secure = v + default: + if others[3] != nil { + secure = true + } + } + if secure { + fmt.Fprintf(&b, "; Secure") + } + } + + // default false. for session cookie default true + if len(others) > 4 { + if v, ok := others[4].(bool); ok && v { + fmt.Fprintf(&b, "; HttpOnly") + } + } + + output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String()) +} + +var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") + +func sanitizeName(n string) string { + return cookieNameSanitizer.Replace(n) +} + +var cookieValueSanitizer = strings.NewReplacer("\n", " ", "\r", " ", ";", " ") + +func sanitizeValue(v string) string { + return cookieValueSanitizer.Replace(v) +} + +func jsonRenderer(value interface{}) Renderer { + return rendererFunc(func(ctx *Context) { + ctx.Output.JSON(value, false, false) + }) +} + +func errorRenderer(err error) Renderer { + return rendererFunc(func(ctx *Context) { + ctx.Output.SetStatus(500) + ctx.Output.Body([]byte(err.Error())) + }) +} + +// JSON writes json to response body. +// if encoding is true, it converts utf-8 to \u0000 type. +func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error { + output.Header("Content-Type", "application/json; charset=utf-8") + var content []byte + var err error + if hasIndent { + content, err = json.MarshalIndent(data, "", " ") + } else { + content, err = json.Marshal(data) + } + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + if encoding { + content = []byte(stringsToJSON(string(content))) + } + return output.Body(content) +} + +// YAML writes yaml to response body. +func (output *BeegoOutput) YAML(data interface{}) error { + output.Header("Content-Type", "application/x-yaml; charset=utf-8") + var content []byte + var err error + content, err = yaml.Marshal(data) + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + return output.Body(content) +} + +// JSONP writes jsonp to response body. +func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { + output.Header("Content-Type", "application/javascript; charset=utf-8") + var content []byte + var err error + if hasIndent { + content, err = json.MarshalIndent(data, "", " ") + } else { + content, err = json.Marshal(data) + } + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + callback := output.Context.Input.Query("callback") + if callback == "" { + return errors.New(`"callback" parameter required`) + } + callback = template.JSEscapeString(callback) + callbackContent := bytes.NewBufferString(" if(window." + callback + ")" + callback) + callbackContent.WriteString("(") + callbackContent.Write(content) + callbackContent.WriteString(");\r\n") + return output.Body(callbackContent.Bytes()) +} + +// XML writes xml string to response body. +func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { + output.Header("Content-Type", "application/xml; charset=utf-8") + var content []byte + var err error + if hasIndent { + content, err = xml.MarshalIndent(data, "", " ") + } else { + content, err = xml.Marshal(data) + } + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + return output.Body(content) +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) { + accept := output.Context.Input.Header("Accept") + switch accept { + case ApplicationYAML: + output.YAML(data) + case ApplicationXML, TextXML: + output.XML(data, hasIndent) + default: + output.JSON(data, hasIndent, len(hasEncode) > 0 && hasEncode[0]) + } +} + +// Download forces response for download file. +// it prepares the download response header automatically. +func (output *BeegoOutput) Download(file string, filename ...string) { + // check get file error, file not found or other error. + if _, err := os.Stat(file); err != nil { + http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) + return + } + + var fName string + if len(filename) > 0 && filename[0] != "" { + fName = filename[0] + } else { + fName = filepath.Base(file) + } + //https://tools.ietf.org/html/rfc6266#section-4.3 + fn := url.PathEscape(fName) + if fName == fn { + fn = "filename=" + fn + } else { + /** + The parameters "filename" and "filename*" differ only in that + "filename*" uses the encoding defined in [RFC5987], allowing the use + of characters not present in the ISO-8859-1 character set + ([ISO-8859-1]). + */ + fn = "filename=" + fName + "; filename*=utf-8''" + fn + } + output.Header("Content-Disposition", "attachment; "+fn) + output.Header("Content-Description", "File Transfer") + output.Header("Content-Type", "application/octet-stream") + output.Header("Content-Transfer-Encoding", "binary") + output.Header("Expires", "0") + output.Header("Cache-Control", "must-revalidate") + output.Header("Pragma", "public") + http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) +} + +// ContentType sets the content type from ext string. +// MIME type is given in mime package. +func (output *BeegoOutput) ContentType(ext string) { + if !strings.HasPrefix(ext, ".") { + ext = "." + ext + } + ctype := mime.TypeByExtension(ext) + if ctype != "" { + output.Header("Content-Type", ctype) + } +} + +// SetStatus sets response status code. +// It writes response header directly. +func (output *BeegoOutput) SetStatus(status int) { + output.Status = status +} + +// IsCachable returns boolean of this request is cached. +// HTTP 304 means cached. +func (output *BeegoOutput) IsCachable() bool { + return output.Status >= 200 && output.Status < 300 || output.Status == 304 +} + +// IsEmpty returns boolean of this request is empty. +// HTTP 201,204 and 304 means empty. +func (output *BeegoOutput) IsEmpty() bool { + return output.Status == 201 || output.Status == 204 || output.Status == 304 +} + +// IsOk returns boolean of this request runs well. +// HTTP 200 means ok. +func (output *BeegoOutput) IsOk() bool { + return output.Status == 200 +} + +// IsSuccessful returns boolean of this request runs successfully. +// HTTP 2xx means ok. +func (output *BeegoOutput) IsSuccessful() bool { + return output.Status >= 200 && output.Status < 300 +} + +// IsRedirect returns boolean of this request is redirection header. +// HTTP 301,302,307 means redirection. +func (output *BeegoOutput) IsRedirect() bool { + return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307 +} + +// IsForbidden returns boolean of this request is forbidden. +// HTTP 403 means forbidden. +func (output *BeegoOutput) IsForbidden() bool { + return output.Status == 403 +} + +// IsNotFound returns boolean of this request is not found. +// HTTP 404 means not found. +func (output *BeegoOutput) IsNotFound() bool { + return output.Status == 404 +} + +// IsClientError returns boolean of this request client sends error data. +// HTTP 4xx means client error. +func (output *BeegoOutput) IsClientError() bool { + return output.Status >= 400 && output.Status < 500 +} + +// IsServerError returns boolean of this server handler errors. +// HTTP 5xx means server internal error. +func (output *BeegoOutput) IsServerError() bool { + return output.Status >= 500 && output.Status < 600 +} + +func stringsToJSON(str string) string { + var jsons bytes.Buffer + for _, r := range str { + rint := int(r) + if rint < 128 { + jsons.WriteRune(r) + } else { + jsons.WriteString("\\u") + if rint < 0x100 { + jsons.WriteString("00") + } else if rint < 0x1000 { + jsons.WriteString("0") + } + jsons.WriteString(strconv.FormatInt(int64(rint), 16)) + } + } + return jsons.String() +} + +// Session sets session item value with given key. +func (output *BeegoOutput) Session(name interface{}, value interface{}) { + output.Context.Input.CruSession.Set(name, value) +} diff --git a/pkg/context/param/conv.go b/pkg/context/param/conv.go new file mode 100644 index 00000000..c200e008 --- /dev/null +++ b/pkg/context/param/conv.go @@ -0,0 +1,78 @@ +package param + +import ( + "fmt" + "reflect" + + beecontext "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" +) + +// ConvertParams converts http method params to values that will be passed to the method controller as arguments +func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) { + result = make([]reflect.Value, 0, len(methodParams)) + for i := 0; i < len(methodParams); i++ { + reflectValue := convertParam(methodParams[i], methodType.In(i), ctx) + result = append(result, reflectValue) + } + return +} + +func convertParam(param *MethodParam, paramType reflect.Type, ctx *beecontext.Context) (result reflect.Value) { + paramValue := getParamValue(param, ctx) + if paramValue == "" { + if param.required { + ctx.Abort(400, fmt.Sprintf("Missing parameter %s", param.name)) + } else { + paramValue = param.defaultValue + } + } + + reflectValue, err := parseValue(param, paramValue, paramType) + if err != nil { + logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %v, Error: %s", param.name, paramType, paramValue, err)) + ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %v to type %s", param.name, paramValue, paramType)) + } + + return reflectValue +} + +func getParamValue(param *MethodParam, ctx *beecontext.Context) string { + switch param.in { + case body: + return string(ctx.Input.RequestBody) + case header: + return ctx.Input.Header(param.name) + case path: + return ctx.Input.Query(":" + param.name) + default: + return ctx.Input.Query(param.name) + } +} + +func parseValue(param *MethodParam, paramValue string, paramType reflect.Type) (result reflect.Value, err error) { + if paramValue == "" { + return reflect.Zero(paramType), nil + } + parser := getParser(param, paramType) + value, err := parser.parse(paramValue, paramType) + if err != nil { + return result, err + } + + return safeConvert(reflect.ValueOf(value), paramType) +} + +func safeConvert(value reflect.Value, t reflect.Type) (result reflect.Value, err error) { + defer func() { + if r := recover(); r != nil { + var ok bool + err, ok = r.(error) + if !ok { + err = fmt.Errorf("%v", r) + } + } + }() + result = value.Convert(t) + return +} diff --git a/pkg/context/param/methodparams.go b/pkg/context/param/methodparams.go new file mode 100644 index 00000000..cd6708a2 --- /dev/null +++ b/pkg/context/param/methodparams.go @@ -0,0 +1,69 @@ +package param + +import ( + "fmt" + "strings" +) + +//MethodParam keeps param information to be auto passed to controller methods +type MethodParam struct { + name string + in paramType + required bool + defaultValue string +} + +type paramType byte + +const ( + param paramType = iota + path + body + header +) + +//New creates a new MethodParam with name and specific options +func New(name string, opts ...MethodParamOption) *MethodParam { + return newParam(name, nil, opts) +} + +func newParam(name string, parser paramParser, opts []MethodParamOption) (param *MethodParam) { + param = &MethodParam{name: name} + for _, option := range opts { + option(param) + } + return +} + +//Make creates an array of MethodParmas or an empty array +func Make(list ...*MethodParam) []*MethodParam { + if len(list) > 0 { + return list + } + return nil +} + +func (mp *MethodParam) String() string { + options := []string{} + result := "param.New(\"" + mp.name + "\"" + if mp.required { + options = append(options, "param.IsRequired") + } + switch mp.in { + case path: + options = append(options, "param.InPath") + case body: + options = append(options, "param.InBody") + case header: + options = append(options, "param.InHeader") + } + if mp.defaultValue != "" { + options = append(options, fmt.Sprintf(`param.Default("%s")`, mp.defaultValue)) + } + if len(options) > 0 { + result += ", " + } + result += strings.Join(options, ", ") + result += ")" + return result +} diff --git a/pkg/context/param/options.go b/pkg/context/param/options.go new file mode 100644 index 00000000..3d5ba013 --- /dev/null +++ b/pkg/context/param/options.go @@ -0,0 +1,37 @@ +package param + +import ( + "fmt" +) + +// MethodParamOption defines a func which apply options on a MethodParam +type MethodParamOption func(*MethodParam) + +// IsRequired indicates that this param is required and can not be omitted from the http request +var IsRequired MethodParamOption = func(p *MethodParam) { + p.required = true +} + +// InHeader indicates that this param is passed via an http header +var InHeader MethodParamOption = func(p *MethodParam) { + p.in = header +} + +// InPath indicates that this param is part of the URL path +var InPath MethodParamOption = func(p *MethodParam) { + p.in = path +} + +// InBody indicates that this param is passed as an http request body +var InBody MethodParamOption = func(p *MethodParam) { + p.in = body +} + +// Default provides a default value for the http param +func Default(defaultValue interface{}) MethodParamOption { + return func(p *MethodParam) { + if defaultValue != nil { + p.defaultValue = fmt.Sprint(defaultValue) + } + } +} diff --git a/pkg/context/param/parsers.go b/pkg/context/param/parsers.go new file mode 100644 index 00000000..421aecf0 --- /dev/null +++ b/pkg/context/param/parsers.go @@ -0,0 +1,149 @@ +package param + +import ( + "encoding/json" + "reflect" + "strconv" + "strings" + "time" +) + +type paramParser interface { + parse(value string, toType reflect.Type) (interface{}, error) +} + +func getParser(param *MethodParam, t reflect.Type) paramParser { + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return intParser{} + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string + return stringParser{} + } + if param.in == body { + return jsonParser{} + } + elemParser := getParser(param, t.Elem()) + if elemParser == (jsonParser{}) { + return elemParser + } + return sliceParser(elemParser) + case reflect.Bool: + return boolParser{} + case reflect.String: + return stringParser{} + case reflect.Float32, reflect.Float64: + return floatParser{} + case reflect.Ptr: + elemParser := getParser(param, t.Elem()) + if elemParser == (jsonParser{}) { + return elemParser + } + return ptrParser(elemParser) + default: + if t.PkgPath() == "time" && t.Name() == "Time" { + return timeParser{} + } + return jsonParser{} + } +} + +type parserFunc func(value string, toType reflect.Type) (interface{}, error) + +func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error) { + return f(value, toType) +} + +type boolParser struct { +} + +func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) { + return strconv.ParseBool(value) +} + +type stringParser struct { +} + +func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) { + return value, nil +} + +type intParser struct { +} + +func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) { + return strconv.Atoi(value) +} + +type floatParser struct { +} + +func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) { + if toType.Kind() == reflect.Float32 { + res, err := strconv.ParseFloat(value, 32) + if err != nil { + return nil, err + } + return float32(res), nil + } + return strconv.ParseFloat(value, 64) +} + +type timeParser struct { +} + +func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) { + result, err = time.Parse(time.RFC3339, value) + if err != nil { + result, err = time.Parse("2006-01-02", value) + } + return +} + +type jsonParser struct { +} + +func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) { + pResult := reflect.New(toType) + v := pResult.Interface() + err := json.Unmarshal([]byte(value), v) + if err != nil { + return nil, err + } + return pResult.Elem().Interface(), nil +} + +func sliceParser(elemParser paramParser) paramParser { + return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { + values := strings.Split(value, ",") + result := reflect.MakeSlice(toType, 0, len(values)) + elemType := toType.Elem() + for _, v := range values { + parsedValue, err := elemParser.parse(v, elemType) + if err != nil { + return nil, err + } + result = reflect.Append(result, reflect.ValueOf(parsedValue)) + } + return result.Interface(), nil + }) +} + +func ptrParser(elemParser paramParser) paramParser { + return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { + parsedValue, err := elemParser.parse(value, toType.Elem()) + if err != nil { + return nil, err + } + newValPtr := reflect.New(toType.Elem()) + newVal := reflect.Indirect(newValPtr) + convertedVal, err := safeConvert(reflect.ValueOf(parsedValue), toType.Elem()) + if err != nil { + return nil, err + } + + newVal.Set(convertedVal) + return newValPtr.Interface(), nil + }) +} diff --git a/pkg/context/param/parsers_test.go b/pkg/context/param/parsers_test.go new file mode 100644 index 00000000..7065a28e --- /dev/null +++ b/pkg/context/param/parsers_test.go @@ -0,0 +1,84 @@ +package param + +import "testing" +import "reflect" +import "time" + +type testDefinition struct { + strValue string + expectedValue interface{} + expectedParser paramParser +} + +func Test_Parsers(t *testing.T) { + + //ints + checkParser(testDefinition{"1", 1, intParser{}}, t) + checkParser(testDefinition{"-1", int64(-1), intParser{}}, t) + checkParser(testDefinition{"1", uint64(1), intParser{}}, t) + + //floats + checkParser(testDefinition{"1.0", float32(1.0), floatParser{}}, t) + checkParser(testDefinition{"-1.0", float64(-1.0), floatParser{}}, t) + + //strings + checkParser(testDefinition{"AB", "AB", stringParser{}}, t) + checkParser(testDefinition{"AB", []byte{65, 66}, stringParser{}}, t) + + //bools + checkParser(testDefinition{"true", true, boolParser{}}, t) + checkParser(testDefinition{"0", false, boolParser{}}, t) + + //timeParser + checkParser(testDefinition{"2017-05-30T13:54:53Z", time.Date(2017, 5, 30, 13, 54, 53, 0, time.UTC), timeParser{}}, t) + checkParser(testDefinition{"2017-05-30", time.Date(2017, 5, 30, 0, 0, 0, 0, time.UTC), timeParser{}}, t) + + //json + checkParser(testDefinition{`{"X": 5, "Y":"Z"}`, struct { + X int + Y string + }{5, "Z"}, jsonParser{}}, t) + + //slice in query is parsed as comma delimited + checkParser(testDefinition{`1,2`, []int{1, 2}, sliceParser(intParser{})}, t) + + //slice in body is parsed as json + checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body}) + + //pointers + var someInt = 1 + checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t) + + var someStruct = struct{ X int }{5} + checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t) + +} + +func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) { + toType := reflect.TypeOf(def.expectedValue) + var mp MethodParam + if len(methodParam) == 0 { + mp = MethodParam{} + } else { + mp = methodParam[0] + } + parser := getParser(&mp, toType) + + if reflect.TypeOf(parser) != reflect.TypeOf(def.expectedParser) { + t.Errorf("Invalid parser for value %v. Expected: %v, actual: %v", def.strValue, reflect.TypeOf(def.expectedParser).Name(), reflect.TypeOf(parser).Name()) + return + } + result, err := parser.parse(def.strValue, toType) + if err != nil { + t.Errorf("Parsing error for value %v. Expected result: %v, error: %v", def.strValue, def.expectedValue, err) + return + } + convResult, err := safeConvert(reflect.ValueOf(result), toType) + if err != nil { + t.Errorf("Conversion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err) + return + } + if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) { + t.Errorf("Parsing error for value %v. Expected result: %v, actual: %v", def.strValue, def.expectedValue, result) + } +} diff --git a/pkg/context/renderer.go b/pkg/context/renderer.go new file mode 100644 index 00000000..36a7cb53 --- /dev/null +++ b/pkg/context/renderer.go @@ -0,0 +1,12 @@ +package context + +// Renderer defines an http response renderer +type Renderer interface { + Render(ctx *Context) +} + +type rendererFunc func(ctx *Context) + +func (f rendererFunc) Render(ctx *Context) { + f(ctx) +} diff --git a/pkg/context/response.go b/pkg/context/response.go new file mode 100644 index 00000000..9c3c715a --- /dev/null +++ b/pkg/context/response.go @@ -0,0 +1,27 @@ +package context + +import ( + "strconv" + + "net/http" +) + +const ( + //BadRequest indicates http error 400 + BadRequest StatusCode = http.StatusBadRequest + + //NotFound indicates http error 404 + NotFound StatusCode = http.StatusNotFound +) + +// StatusCode sets the http response status code +type StatusCode int + +func (s StatusCode) Error() string { + return strconv.Itoa(int(s)) +} + +// Render sets the http status code +func (s StatusCode) Render(ctx *Context) { + ctx.Output.SetStatus(int(s)) +} diff --git a/pkg/controller.go b/pkg/controller.go new file mode 100644 index 00000000..0e8853b3 --- /dev/null +++ b/pkg/controller.go @@ -0,0 +1,706 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "errors" + "fmt" + "html/template" + "io" + "mime/multipart" + "net/http" + "net/url" + "os" + "reflect" + "strconv" + "strings" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/context/param" + "github.com/astaxie/beego/session" +) + +var ( + // ErrAbort custom error when user stop request handler manually. + ErrAbort = errors.New("user stop run") + // GlobalControllerRouter store comments with controller. pkgpath+controller:comments + GlobalControllerRouter = make(map[string][]ControllerComments) +) + +// ControllerFilter store the filter for controller +type ControllerFilter struct { + Pattern string + Pos int + Filter FilterFunc + ReturnOnOutput bool + ResetParams bool +} + +// ControllerFilterComments store the comment for controller level filter +type ControllerFilterComments struct { + Pattern string + Pos int + Filter string // NOQA + ReturnOnOutput bool + ResetParams bool +} + +// ControllerImportComments store the import comment for controller needed +type ControllerImportComments struct { + ImportPath string + ImportAlias string +} + +// ControllerComments store the comment for the controller method +type ControllerComments struct { + Method string + Router string + Filters []*ControllerFilter + ImportComments []*ControllerImportComments + FilterComments []*ControllerFilterComments + AllowHTTPMethods []string + Params []map[string]string + MethodParams []*param.MethodParam +} + +// ControllerCommentsSlice implements the sort interface +type ControllerCommentsSlice []ControllerComments + +func (p ControllerCommentsSlice) Len() int { return len(p) } +func (p ControllerCommentsSlice) Less(i, j int) bool { return p[i].Router < p[j].Router } +func (p ControllerCommentsSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +// Controller defines some basic http request handler operations, such as +// http context, template and view, session and xsrf. +type Controller struct { + // context data + Ctx *context.Context + Data map[interface{}]interface{} + + // route controller info + controllerName string + actionName string + methodMapping map[string]func() //method:routertree + AppController interface{} + + // template data + TplName string + ViewPath string + Layout string + LayoutSections map[string]string // the key is the section name and the value is the template name + TplPrefix string + TplExt string + EnableRender bool + + // xsrf data + _xsrfToken string + XSRFExpire int + EnableXSRF bool + + // session + CruSession session.Store +} + +// ControllerInterface is an interface to uniform all controller handler. +type ControllerInterface interface { + Init(ct *context.Context, controllerName, actionName string, app interface{}) + Prepare() + Get() + Post() + Delete() + Put() + Head() + Patch() + Options() + Trace() + Finish() + Render() error + XSRFToken() string + CheckXSRFCookie() bool + HandlerFunc(fn string) bool + URLMapping() +} + +// Init generates default values of controller operations. +func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { + c.Layout = "" + c.TplName = "" + c.controllerName = controllerName + c.actionName = actionName + c.Ctx = ctx + c.TplExt = "tpl" + c.AppController = app + c.EnableRender = true + c.EnableXSRF = true + c.Data = ctx.Input.Data() + c.methodMapping = make(map[string]func()) +} + +// Prepare runs after Init before request function execution. +func (c *Controller) Prepare() {} + +// Finish runs after request function execution. +func (c *Controller) Finish() {} + +// Get adds a request function to handle GET request. +func (c *Controller) Get() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Post adds a request function to handle POST request. +func (c *Controller) Post() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Delete adds a request function to handle DELETE request. +func (c *Controller) Delete() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Put adds a request function to handle PUT request. +func (c *Controller) Put() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Head adds a request function to handle HEAD request. +func (c *Controller) Head() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Patch adds a request function to handle PATCH request. +func (c *Controller) Patch() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Options adds a request function to handle OPTIONS request. +func (c *Controller) Options() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Trace adds a request function to handle Trace request. +// this method SHOULD NOT be overridden. +// https://tools.ietf.org/html/rfc7231#section-4.3.8 +// The TRACE method requests a remote, application-level loop-back of +// the request message. The final recipient of the request SHOULD +// reflect the message received, excluding some fields described below, +// back to the client as the message body of a 200 (OK) response with a +// Content-Type of "message/http" (Section 8.3.1 of [RFC7230]). +func (c *Controller) Trace() { + ts := func(h http.Header) (hs string) { + for k, v := range h { + hs += fmt.Sprintf("\r\n%s: %s", k, v) + } + return + } + hs := fmt.Sprintf("\r\nTRACE %s %s%s\r\n", c.Ctx.Request.RequestURI, c.Ctx.Request.Proto, ts(c.Ctx.Request.Header)) + c.Ctx.Output.Header("Content-Type", "message/http") + c.Ctx.Output.Header("Content-Length", fmt.Sprint(len(hs))) + c.Ctx.Output.Header("Cache-Control", "no-cache, no-store, must-revalidate") + c.Ctx.WriteString(hs) +} + +// HandlerFunc call function with the name +func (c *Controller) HandlerFunc(fnname string) bool { + if v, ok := c.methodMapping[fnname]; ok { + v() + return true + } + return false +} + +// URLMapping register the internal Controller router. +func (c *Controller) URLMapping() {} + +// Mapping the method to function +func (c *Controller) Mapping(method string, fn func()) { + c.methodMapping[method] = fn +} + +// Render sends the response with rendered template bytes as text/html type. +func (c *Controller) Render() error { + if !c.EnableRender { + return nil + } + rb, err := c.RenderBytes() + if err != nil { + return err + } + + if c.Ctx.ResponseWriter.Header().Get("Content-Type") == "" { + c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8") + } + + return c.Ctx.Output.Body(rb) +} + +// RenderString returns the rendered template string. Do not send out response. +func (c *Controller) RenderString() (string, error) { + b, e := c.RenderBytes() + return string(b), e +} + +// RenderBytes returns the bytes of rendered template string. Do not send out response. +func (c *Controller) RenderBytes() ([]byte, error) { + buf, err := c.renderTemplate() + //if the controller has set layout, then first get the tplName's content set the content to the layout + if err == nil && c.Layout != "" { + c.Data["LayoutContent"] = template.HTML(buf.String()) + + if c.LayoutSections != nil { + for sectionName, sectionTpl := range c.LayoutSections { + if sectionTpl == "" { + c.Data[sectionName] = "" + continue + } + buf.Reset() + err = ExecuteViewPathTemplate(&buf, sectionTpl, c.viewPath(), c.Data) + if err != nil { + return nil, err + } + c.Data[sectionName] = template.HTML(buf.String()) + } + } + + buf.Reset() + ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath(), c.Data) + } + return buf.Bytes(), err +} + +func (c *Controller) renderTemplate() (bytes.Buffer, error) { + var buf bytes.Buffer + if c.TplName == "" { + c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt + } + if c.TplPrefix != "" { + c.TplName = c.TplPrefix + c.TplName + } + if BConfig.RunMode == DEV { + buildFiles := []string{c.TplName} + if c.Layout != "" { + buildFiles = append(buildFiles, c.Layout) + if c.LayoutSections != nil { + for _, sectionTpl := range c.LayoutSections { + if sectionTpl == "" { + continue + } + buildFiles = append(buildFiles, sectionTpl) + } + } + } + BuildTemplate(c.viewPath(), buildFiles...) + } + return buf, ExecuteViewPathTemplate(&buf, c.TplName, c.viewPath(), c.Data) +} + +func (c *Controller) viewPath() string { + if c.ViewPath == "" { + return BConfig.WebConfig.ViewsPath + } + return c.ViewPath +} + +// Redirect sends the redirection response to url with status code. +func (c *Controller) Redirect(url string, code int) { + LogAccess(c.Ctx, nil, code) + c.Ctx.Redirect(code, url) +} + +// SetData set the data depending on the accepted +func (c *Controller) SetData(data interface{}) { + accept := c.Ctx.Input.Header("Accept") + switch accept { + case context.ApplicationYAML: + c.Data["yaml"] = data + case context.ApplicationXML, context.TextXML: + c.Data["xml"] = data + default: + c.Data["json"] = data + } +} + +// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string. +func (c *Controller) Abort(code string) { + status, err := strconv.Atoi(code) + if err != nil { + status = 200 + } + c.CustomAbort(status, code) +} + +// CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. +func (c *Controller) CustomAbort(status int, body string) { + // first panic from ErrorMaps, it is user defined error functions. + if _, ok := ErrorMaps[body]; ok { + c.Ctx.Output.Status = status + panic(body) + } + // last panic user string + c.Ctx.ResponseWriter.WriteHeader(status) + c.Ctx.ResponseWriter.Write([]byte(body)) + panic(ErrAbort) +} + +// StopRun makes panic of USERSTOPRUN error and go to recover function if defined. +func (c *Controller) StopRun() { + panic(ErrAbort) +} + +// URLFor does another controller handler in this request function. +// it goes to this controller method if endpoint is not clear. +func (c *Controller) URLFor(endpoint string, values ...interface{}) string { + if len(endpoint) == 0 { + return "" + } + if endpoint[0] == '.' { + return URLFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...) + } + return URLFor(endpoint, values...) +} + +// ServeJSON sends a json response with encoding charset. +func (c *Controller) ServeJSON(encoding ...bool) { + var ( + hasIndent = BConfig.RunMode != PROD + hasEncoding = len(encoding) > 0 && encoding[0] + ) + + c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding) +} + +// ServeJSONP sends a jsonp response. +func (c *Controller) ServeJSONP() { + hasIndent := BConfig.RunMode != PROD + c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent) +} + +// ServeXML sends xml response. +func (c *Controller) ServeXML() { + hasIndent := BConfig.RunMode != PROD + c.Ctx.Output.XML(c.Data["xml"], hasIndent) +} + +// ServeYAML sends yaml response. +func (c *Controller) ServeYAML() { + c.Ctx.Output.YAML(c.Data["yaml"]) +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (c *Controller) ServeFormatted(encoding ...bool) { + hasIndent := BConfig.RunMode != PROD + hasEncoding := len(encoding) > 0 && encoding[0] + c.Ctx.Output.ServeFormatted(c.Data, hasIndent, hasEncoding) +} + +// Input returns the input data map from POST or PUT request body and query string. +func (c *Controller) Input() url.Values { + if c.Ctx.Request.Form == nil { + c.Ctx.Request.ParseForm() + } + return c.Ctx.Request.Form +} + +// ParseForm maps input data map to obj struct. +func (c *Controller) ParseForm(obj interface{}) error { + return ParseForm(c.Input(), obj) +} + +// GetString returns the input value by key string or the default value while it's present and input is blank +func (c *Controller) GetString(key string, def ...string) string { + if v := c.Ctx.Input.Query(key); v != "" { + return v + } + if len(def) > 0 { + return def[0] + } + return "" +} + +// GetStrings returns the input string slice by key string or the default value while it's present and input is blank +// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection. +func (c *Controller) GetStrings(key string, def ...[]string) []string { + var defv []string + if len(def) > 0 { + defv = def[0] + } + + if f := c.Input(); f == nil { + return defv + } else if vs := f[key]; len(vs) > 0 { + return vs + } + + return defv +} + +// GetInt returns input as an int or the default value while it's present and input is blank +func (c *Controller) GetInt(key string, def ...int) (int, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.Atoi(strv) +} + +// GetInt8 return input as an int8 or the default value while it's present and input is blank +func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + i64, err := strconv.ParseInt(strv, 10, 8) + return int8(i64), err +} + +// GetUint8 return input as an uint8 or the default value while it's present and input is blank +func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 8) + return uint8(u64), err +} + +// GetInt16 returns input as an int16 or the default value while it's present and input is blank +func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + i64, err := strconv.ParseInt(strv, 10, 16) + return int16(i64), err +} + +// GetUint16 returns input as an uint16 or the default value while it's present and input is blank +func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 16) + return uint16(u64), err +} + +// GetInt32 returns input as an int32 or the default value while it's present and input is blank +func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + i64, err := strconv.ParseInt(strv, 10, 32) + return int32(i64), err +} + +// GetUint32 returns input as an uint32 or the default value while it's present and input is blank +func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 32) + return uint32(u64), err +} + +// GetInt64 returns input value as int64 or the default value while it's present and input is blank. +func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseInt(strv, 10, 64) +} + +// GetUint64 returns input value as uint64 or the default value while it's present and input is blank. +func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseUint(strv, 10, 64) +} + +// GetBool returns input value as bool or the default value while it's present and input is blank. +func (c *Controller) GetBool(key string, def ...bool) (bool, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseBool(strv) +} + +// GetFloat returns input value as float64 or the default value while it's present and input is blank. +func (c *Controller) GetFloat(key string, def ...float64) (float64, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseFloat(strv, 64) +} + +// GetFile returns the file data in file upload field named as key. +// it returns the first one of multi-uploaded files. +func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, error) { + return c.Ctx.Request.FormFile(key) +} + +// GetFiles return multi-upload files +// files, err:=c.GetFiles("myfiles") +// if err != nil { +// http.Error(w, err.Error(), http.StatusNoContent) +// return +// } +// for i, _ := range files { +// //for each fileheader, get a handle to the actual file +// file, err := files[i].Open() +// defer file.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //create destination file making sure the path is writeable. +// dst, err := os.Create("upload/" + files[i].Filename) +// defer dst.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //copy the uploaded file to the destination file +// if _, err := io.Copy(dst, file); err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// } +func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { + if files, ok := c.Ctx.Request.MultipartForm.File[key]; ok { + return files, nil + } + return nil, http.ErrMissingFile +} + +// SaveToFile saves uploaded file to new path. +// it only operates the first one of mutil-upload form file field. +func (c *Controller) SaveToFile(fromfile, tofile string) error { + file, _, err := c.Ctx.Request.FormFile(fromfile) + if err != nil { + return err + } + defer file.Close() + f, err := os.OpenFile(tofile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) + if err != nil { + return err + } + defer f.Close() + io.Copy(f, file) + return nil +} + +// StartSession starts session and load old session data info this controller. +func (c *Controller) StartSession() session.Store { + if c.CruSession == nil { + c.CruSession = c.Ctx.Input.CruSession + } + return c.CruSession +} + +// SetSession puts value into session. +func (c *Controller) SetSession(name interface{}, value interface{}) { + if c.CruSession == nil { + c.StartSession() + } + c.CruSession.Set(name, value) +} + +// GetSession gets value from session. +func (c *Controller) GetSession(name interface{}) interface{} { + if c.CruSession == nil { + c.StartSession() + } + return c.CruSession.Get(name) +} + +// DelSession removes value from session. +func (c *Controller) DelSession(name interface{}) { + if c.CruSession == nil { + c.StartSession() + } + c.CruSession.Delete(name) +} + +// SessionRegenerateID regenerates session id for this session. +// the session data have no changes. +func (c *Controller) SessionRegenerateID() { + if c.CruSession != nil { + c.CruSession.SessionRelease(c.Ctx.ResponseWriter) + } + c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request) + c.Ctx.Input.CruSession = c.CruSession +} + +// DestroySession cleans session data and session cookie. +func (c *Controller) DestroySession() { + c.Ctx.Input.CruSession.Flush() + c.Ctx.Input.CruSession = nil + GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request) +} + +// IsAjax returns this request is ajax or not. +func (c *Controller) IsAjax() bool { + return c.Ctx.Input.IsAjax() +} + +// GetSecureCookie returns decoded cookie value from encoded browser cookie values. +func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) { + return c.Ctx.GetSecureCookie(Secret, key) +} + +// SetSecureCookie puts value into cookie after encoded the value. +func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) { + c.Ctx.SetSecureCookie(Secret, name, value, others...) +} + +// XSRFToken creates a CSRF token string and returns. +func (c *Controller) XSRFToken() string { + if c._xsrfToken == "" { + expire := int64(BConfig.WebConfig.XSRFExpire) + if c.XSRFExpire > 0 { + expire = int64(c.XSRFExpire) + } + c._xsrfToken = c.Ctx.XSRFToken(BConfig.WebConfig.XSRFKey, expire) + } + return c._xsrfToken +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (c *Controller) CheckXSRFCookie() bool { + if !c.EnableXSRF { + return true + } + return c.Ctx.CheckXSRFCookie() +} + +// XSRFFormHTML writes an input field contains xsrf token value. +func (c *Controller) XSRFFormHTML() string { + return `` +} + +// GetControllerAndAction gets the executing controller name and action name. +func (c *Controller) GetControllerAndAction() (string, string) { + return c.controllerName, c.actionName +} diff --git a/pkg/controller_test.go b/pkg/controller_test.go new file mode 100644 index 00000000..1e53416d --- /dev/null +++ b/pkg/controller_test.go @@ -0,0 +1,181 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "math" + "strconv" + "testing" + + "github.com/astaxie/beego/context" + "os" + "path/filepath" +) + +func TestGetInt(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt("age") + if val != 40 { + t.Errorf("TestGetInt expect 40,get %T,%v", val, val) + } +} + +func TestGetInt8(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt8("age") + if val != 40 { + t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val) + } + //Output: int8 +} + +func TestGetInt16(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt16("age") + if val != 40 { + t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val) + } +} + +func TestGetInt32(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt32("age") + if val != 40 { + t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val) + } +} + +func TestGetInt64(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt64("age") + if val != 40 { + t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val) + } +} + +func TestGetUint8(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint8, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint8("age") + if val != math.MaxUint8 { + t.Errorf("TestGetUint8 expect %v,get %T,%v", math.MaxUint8, val, val) + } +} + +func TestGetUint16(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint16, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint16("age") + if val != math.MaxUint16 { + t.Errorf("TestGetUint16 expect %v,get %T,%v", math.MaxUint16, val, val) + } +} + +func TestGetUint32(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint32, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint32("age") + if val != math.MaxUint32 { + t.Errorf("TestGetUint32 expect %v,get %T,%v", math.MaxUint32, val, val) + } +} + +func TestGetUint64(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint64, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint64("age") + if val != math.MaxUint64 { + t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val) + } +} + +func TestAdditionalViewPaths(t *testing.T) { + dir1 := "_beeTmp" + dir2 := "_beeTmp2" + defer os.RemoveAll(dir1) + defer os.RemoveAll(dir2) + + dir1file := "file1.tpl" + dir2file := "file2.tpl" + + genFile := func(dir string, name string, content string) { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + defer f.Close() + f.WriteString(content) + f.Close() + } + + } + genFile(dir1, dir1file, `
{{.Content}}
`) + genFile(dir2, dir2file, `{{.Content}}`) + + AddViewPath(dir1) + AddViewPath(dir2) + + ctrl := Controller{ + TplName: "file1.tpl", + ViewPath: dir1, + } + ctrl.Data = map[interface{}]interface{}{ + "Content": "value2", + } + if result, err := ctrl.RenderString(); err != nil { + t.Fatal(err) + } else { + if result != "
value2
" { + t.Fatalf("TestAdditionalViewPaths expect %s got %s", "
value2
", result) + } + } + + func() { + ctrl.TplName = "file2.tpl" + defer func() { + if r := recover(); r == nil { + t.Fatal("TestAdditionalViewPaths expected error") + } + }() + ctrl.RenderString() + }() + + ctrl.TplName = "file2.tpl" + ctrl.ViewPath = dir2 + ctrl.RenderString() +} diff --git a/pkg/doc.go b/pkg/doc.go new file mode 100644 index 00000000..8825bd29 --- /dev/null +++ b/pkg/doc.go @@ -0,0 +1,17 @@ +/* +Package beego provide a MVC framework +beego: an open-source, high-performance, modular, full-stack web framework + +It is used for rapid development of RESTful APIs, web apps and backend services in Go. +beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding. + + package main + import "github.com/astaxie/beego" + + func main() { + beego.Run() + } + +more information: http://beego.me +*/ +package beego diff --git a/pkg/error.go b/pkg/error.go new file mode 100644 index 00000000..f268f723 --- /dev/null +++ b/pkg/error.go @@ -0,0 +1,488 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "fmt" + "html/template" + "net/http" + "reflect" + "runtime" + "strconv" + "strings" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/utils" +) + +const ( + errorTypeHandler = iota + errorTypeController +) + +var tpl = ` + + + + + beego application error + + + + + +
+ + + + + + + + + + +
Request Method: {{.RequestMethod}}
Request URL: {{.RequestURL}}
RemoteAddr: {{.RemoteAddr }}
+
+ Stack +
{{.Stack}}
+
+
+ + + +` + +// render default application error page with error and stack string. +func showErr(err interface{}, ctx *context.Context, stack string) { + t, _ := template.New("beegoerrortemp").Parse(tpl) + data := map[string]string{ + "AppError": fmt.Sprintf("%s:%v", BConfig.AppName, err), + "RequestMethod": ctx.Input.Method(), + "RequestURL": ctx.Input.URI(), + "RemoteAddr": ctx.Input.IP(), + "Stack": stack, + "BeegoVersion": VERSION, + "GoVersion": runtime.Version(), + } + t.Execute(ctx.ResponseWriter, data) +} + +var errtpl = ` + + + + + {{.Title}} + + + +
+
+ +
+ {{.Content}} + Go Home
+ +
Powered by beego {{.BeegoVersion}} +
+
+
+ + +` + +type errorInfo struct { + controllerType reflect.Type + handler http.HandlerFunc + method string + errorType int +} + +// ErrorMaps holds map of http handlers for each error string. +// there is 10 kinds default error(40x and 50x) +var ErrorMaps = make(map[string]*errorInfo, 10) + +// show 401 unauthorized error. +func unauthorized(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 401, + "
The page you have requested can't be authorized."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The credentials you supplied are incorrect"+ + "
    There are errors in the website address"+ + "
", + ) +} + +// show 402 Payment Required +func paymentRequired(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 402, + "
The page you have requested Payment Required."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The credentials you supplied are incorrect"+ + "
    There are errors in the website address"+ + "
", + ) +} + +// show 403 forbidden error. +func forbidden(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 403, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    Your address may be blocked"+ + "
    The site may be disabled"+ + "
    You need to log in"+ + "
", + ) +} + +// show 422 missing xsrf token +func missingxsrf(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 422, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    '_xsrf' argument missing from POST"+ + "
", + ) +} + +// show 417 invalid xsrf token +func invalidxsrf(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 417, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    expected XSRF not found"+ + "
", + ) +} + +// show 404 not found error. +func notFound(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 404, + "
The page you have requested has flown the coop."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The page has moved"+ + "
    The page no longer exists"+ + "
    You were looking for your puppy and got lost"+ + "
    You like 404 pages"+ + "
", + ) +} + +// show 405 Method Not Allowed +func methodNotAllowed(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 405, + "
The method you have requested Not Allowed."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The method specified in the Request-Line is not allowed for the resource identified by the Request-URI"+ + "
    The response MUST include an Allow header containing a list of valid methods for the requested resource."+ + "
", + ) +} + +// show 500 internal server error. +func internalServerError(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 500, + "
The page you have requested is down right now."+ + "

    "+ + "
    Please try again later and report the error to the website administrator"+ + "
", + ) +} + +// show 501 Not Implemented. +func notImplemented(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 501, + "
The page you have requested is Not Implemented."+ + "

    "+ + "
    Please try again later and report the error to the website administrator"+ + "
", + ) +} + +// show 502 Bad Gateway. +func badGateway(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 502, + "
The page you have requested is down right now."+ + "

    "+ + "
    The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request."+ + "
    Please try again later and report the error to the website administrator"+ + "
", + ) +} + +// show 503 service unavailable error. +func serviceUnavailable(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 503, + "
The page you have requested is unavailable."+ + "
Perhaps you are here because:"+ + "

    "+ + "

    The page is overloaded"+ + "
    Please try again later."+ + "
", + ) +} + +// show 504 Gateway Timeout. +func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 504, + "
The page you have requested is unavailable"+ + "
Perhaps you are here because:"+ + "

    "+ + "

    The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI."+ + "
    Please try again later."+ + "
", + ) +} + +// show 413 Payload Too Large +func payloadTooLarge(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 413, + `
The page you have requested is unavailable. +
Perhaps you are here because:

+
    +
    The request entity is larger than limits defined by server. +
    Please change the request entity and try again. +
+ `, + ) +} + +func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) { + t, _ := template.New("beegoerrortemp").Parse(errtpl) + data := M{ + "Title": http.StatusText(errCode), + "BeegoVersion": VERSION, + "Content": template.HTML(errContent), + } + t.Execute(rw, data) +} + +// ErrorHandler registers http.HandlerFunc to each http err code string. +// usage: +// beego.ErrorHandler("404",NotFound) +// beego.ErrorHandler("500",InternalServerError) +func ErrorHandler(code string, h http.HandlerFunc) *App { + ErrorMaps[code] = &errorInfo{ + errorType: errorTypeHandler, + handler: h, + method: code, + } + return BeeApp +} + +// ErrorController registers ControllerInterface to each http err code string. +// usage: +// beego.ErrorController(&controllers.ErrorController{}) +func ErrorController(c ControllerInterface) *App { + reflectVal := reflect.ValueOf(c) + rt := reflectVal.Type() + ct := reflect.Indirect(reflectVal).Type() + for i := 0; i < rt.NumMethod(); i++ { + methodName := rt.Method(i).Name + if !utils.InSlice(methodName, exceptMethod) && strings.HasPrefix(methodName, "Error") { + errName := strings.TrimPrefix(methodName, "Error") + ErrorMaps[errName] = &errorInfo{ + errorType: errorTypeController, + controllerType: ct, + method: methodName, + } + } + } + return BeeApp +} + +// Exception Write HttpStatus with errCode and Exec error handler if exist. +func Exception(errCode uint64, ctx *context.Context) { + exception(strconv.FormatUint(errCode, 10), ctx) +} + +// show error string as simple text message. +// if error string is empty, show 503 or 500 error as default. +func exception(errCode string, ctx *context.Context) { + atoi := func(code string) int { + v, err := strconv.Atoi(code) + if err == nil { + return v + } + if ctx.Output.Status == 0 { + return 503 + } + return ctx.Output.Status + } + + for _, ec := range []string{errCode, "503", "500"} { + if h, ok := ErrorMaps[ec]; ok { + executeError(h, ctx, atoi(ec)) + return + } + } + //if 50x error has been removed from errorMap + ctx.ResponseWriter.WriteHeader(atoi(errCode)) + ctx.WriteString(errCode) +} + +func executeError(err *errorInfo, ctx *context.Context, code int) { + //make sure to log the error in the access log + LogAccess(ctx, nil, code) + + if err.errorType == errorTypeHandler { + ctx.ResponseWriter.WriteHeader(code) + err.handler(ctx.ResponseWriter, ctx.Request) + return + } + if err.errorType == errorTypeController { + ctx.Output.SetStatus(code) + //Invoke the request handler + vc := reflect.New(err.controllerType) + execController, ok := vc.Interface().(ControllerInterface) + if !ok { + panic("controller is not ControllerInterface") + } + //call the controller init function + execController.Init(ctx, err.controllerType.Name(), err.method, vc.Interface()) + + //call prepare function + execController.Prepare() + + execController.URLMapping() + + method := vc.MethodByName(err.method) + method.Call([]reflect.Value{}) + + //render template + if BConfig.WebConfig.AutoRender { + if err := execController.Render(); err != nil { + panic(err) + } + } + + // finish all runrouter. release resource + execController.Finish() + } +} diff --git a/pkg/error_test.go b/pkg/error_test.go new file mode 100644 index 00000000..378aa953 --- /dev/null +++ b/pkg/error_test.go @@ -0,0 +1,88 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" +) + +type errorTestController struct { + Controller +} + +const parseCodeError = "parse code error" + +func (ec *errorTestController) Get() { + errorCode, err := ec.GetInt("code") + if err != nil { + ec.Abort(parseCodeError) + } + if errorCode != 0 { + ec.CustomAbort(errorCode, ec.GetString("code")) + } + ec.Abort("404") +} + +func TestErrorCode_01(t *testing.T) { + registerDefaultErrorHandler() + for k := range ErrorMaps { + r, _ := http.NewRequest("GET", "/error?code="+k, nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + code, _ := strconv.Atoi(k) + if w.Code != code { + t.Fail() + } + if !strings.Contains(w.Body.String(), http.StatusText(code)) { + t.Fail() + } + } +} + +func TestErrorCode_02(t *testing.T) { + registerDefaultErrorHandler() + r, _ := http.NewRequest("GET", "/error?code=0", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + if w.Code != 404 { + t.Fail() + } +} + +func TestErrorCode_03(t *testing.T) { + registerDefaultErrorHandler() + r, _ := http.NewRequest("GET", "/error?code=panic", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + if w.Code != 200 { + t.Fail() + } + if w.Body.String() != parseCodeError { + t.Fail() + } +} diff --git a/pkg/filter.go b/pkg/filter.go new file mode 100644 index 00000000..9cc6e913 --- /dev/null +++ b/pkg/filter.go @@ -0,0 +1,44 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import "github.com/astaxie/beego/context" + +// FilterFunc defines a filter function which is invoked before the controller handler is executed. +type FilterFunc func(*context.Context) + +// FilterRouter defines a filter operation which is invoked before the controller handler is executed. +// It can match the URL against a pattern, and execute a filter function +// when a request with a matching URL arrives. +type FilterRouter struct { + filterFunc FilterFunc + tree *Tree + pattern string + returnOnOutput bool + resetParams bool +} + +// ValidRouter checks if the current request is matched by this filter. +// If the request is matched, the values of the URL parameters defined +// by the filter pattern are also returned. +func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { + isOk := f.tree.Match(url, ctx) + if isOk != nil { + if b, ok := isOk.(bool); ok { + return b + } + } + return false +} diff --git a/pkg/filter_test.go b/pkg/filter_test.go new file mode 100644 index 00000000..4ca4d2b8 --- /dev/null +++ b/pkg/filter_test.go @@ -0,0 +1,68 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/astaxie/beego/context" +) + +var FilterUser = func(ctx *context.Context) { + ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first"))) +} + +func TestFilter(t *testing.T) { + r, _ := http.NewRequest("GET", "/person/asta/Xie", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.InsertFilter("/person/:last/:first", BeforeRouter, FilterUser) + handler.Add("/person/:last/:first", &TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am astaXie" { + t.Errorf("user define func can't run") + } +} + +var FilterAdminUser = func(ctx *context.Context) { + ctx.Output.Body([]byte("i am admin")) +} + +// Filter pattern /admin/:all +// all url like /admin/ /admin/xie will all get filter + +func TestPatternTwo(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.InsertFilter("/admin/?:all", BeforeRouter, FilterAdminUser) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am admin" { + t.Errorf("filter /admin/ can't run") + } +} + +func TestPatternThree(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/astaxie", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.InsertFilter("/admin/:all", BeforeRouter, FilterAdminUser) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am admin" { + t.Errorf("filter /admin/astaxie can't run") + } +} diff --git a/pkg/flash.go b/pkg/flash.go new file mode 100644 index 00000000..a6485a17 --- /dev/null +++ b/pkg/flash.go @@ -0,0 +1,110 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "fmt" + "net/url" + "strings" +) + +// FlashData is a tools to maintain data when using across request. +type FlashData struct { + Data map[string]string +} + +// NewFlash return a new empty FlashData struct. +func NewFlash() *FlashData { + return &FlashData{ + Data: make(map[string]string), + } +} + +// Set message to flash +func (fd *FlashData) Set(key string, msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data[key] = msg + } else { + fd.Data[key] = fmt.Sprintf(msg, args...) + } +} + +// Success writes success message to flash. +func (fd *FlashData) Success(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["success"] = msg + } else { + fd.Data["success"] = fmt.Sprintf(msg, args...) + } +} + +// Notice writes notice message to flash. +func (fd *FlashData) Notice(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["notice"] = msg + } else { + fd.Data["notice"] = fmt.Sprintf(msg, args...) + } +} + +// Warning writes warning message to flash. +func (fd *FlashData) Warning(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["warning"] = msg + } else { + fd.Data["warning"] = fmt.Sprintf(msg, args...) + } +} + +// Error writes error message to flash. +func (fd *FlashData) Error(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["error"] = msg + } else { + fd.Data["error"] = fmt.Sprintf(msg, args...) + } +} + +// Store does the saving operation of flash data. +// the data are encoded and saved in cookie. +func (fd *FlashData) Store(c *Controller) { + c.Data["flash"] = fd.Data + var flashValue string + for key, value := range fd.Data { + flashValue += "\x00" + key + "\x23" + BConfig.WebConfig.FlashSeparator + "\x23" + value + "\x00" + } + c.Ctx.SetCookie(BConfig.WebConfig.FlashName, url.QueryEscape(flashValue), 0, "/") +} + +// ReadFromRequest parsed flash data from encoded values in cookie. +func ReadFromRequest(c *Controller) *FlashData { + flash := NewFlash() + if cookie, err := c.Ctx.Request.Cookie(BConfig.WebConfig.FlashName); err == nil { + v, _ := url.QueryUnescape(cookie.Value) + vals := strings.Split(v, "\x00") + for _, v := range vals { + if len(v) > 0 { + kv := strings.Split(v, "\x23"+BConfig.WebConfig.FlashSeparator+"\x23") + if len(kv) == 2 { + flash.Data[kv[0]] = kv[1] + } + } + } + //read one time then delete it + c.Ctx.SetCookie(BConfig.WebConfig.FlashName, "", -1, "/") + } + c.Data["flash"] = flash.Data + return flash +} diff --git a/pkg/flash_test.go b/pkg/flash_test.go new file mode 100644 index 00000000..d5e9608d --- /dev/null +++ b/pkg/flash_test.go @@ -0,0 +1,54 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +type TestFlashController struct { + Controller +} + +func (t *TestFlashController) TestWriteFlash() { + flash := NewFlash() + flash.Notice("TestFlashString") + flash.Store(&t.Controller) + // we choose to serve json because we don't want to load a template html file + t.ServeJSON(true) +} + +func TestFlashHeader(t *testing.T) { + // create fake GET request + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + // setup the handler + handler := NewControllerRegister() + handler.Add("/", &TestFlashController{}, "get:TestWriteFlash") + handler.ServeHTTP(w, r) + + // get the Set-Cookie value + sc := w.Header().Get("Set-Cookie") + // match for the expected header + res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00") + // validate the assertion + if !res { + t.Errorf("TestFlashHeader() unable to validate flash message") + } +} diff --git a/pkg/fs.go b/pkg/fs.go new file mode 100644 index 00000000..41cc6f6e --- /dev/null +++ b/pkg/fs.go @@ -0,0 +1,74 @@ +package beego + +import ( + "net/http" + "os" + "path/filepath" +) + +type FileSystem struct { +} + +func (d FileSystem) Open(name string) (http.File, error) { + return os.Open(name) +} + +// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or +// directory in the tree, including root. All errors that arise visiting files +// and directories are filtered by walkFn. +func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error { + + f, err := fs.Open(root) + if err != nil { + return err + } + info, err := f.Stat() + if err != nil { + err = walkFn(root, nil, err) + } else { + err = walk(fs, root, info, walkFn) + } + if err == filepath.SkipDir { + return nil + } + return err +} + +// walk recursively descends path, calling walkFn. +func walk(fs http.FileSystem, path string, info os.FileInfo, walkFn filepath.WalkFunc) error { + var err error + if !info.IsDir() { + return walkFn(path, info, nil) + } + + dir, err := fs.Open(path) + if err != nil { + if err1 := walkFn(path, info, err); err1 != nil { + return err1 + } + return err + } + defer dir.Close() + dirs, err := dir.Readdir(-1) + err1 := walkFn(path, info, err) + // If err != nil, walk can't walk into this directory. + // err1 != nil means walkFn want walk to skip this directory or stop walking. + // Therefore, if one of err and err1 isn't nil, walk will return. + if err != nil || err1 != nil { + // The caller's behavior is controlled by the return value, which is decided + // by walkFn. walkFn may ignore err and return nil. + // If walkFn returns SkipDir, it will be handled by the caller. + // So walk should return whatever walkFn returns. + return err1 + } + + for _, fileInfo := range dirs { + filename := filepath.Join(path, fileInfo.Name()) + if err = walk(fs, filename, fileInfo, walkFn); err != nil { + if !fileInfo.IsDir() || err != filepath.SkipDir { + return err + } + } + } + return nil +} diff --git a/pkg/grace/grace.go b/pkg/grace/grace.go new file mode 100644 index 00000000..fb0cb7bb --- /dev/null +++ b/pkg/grace/grace.go @@ -0,0 +1,166 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package grace use to hot reload +// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ +// +// Usage: +// +// import( +// "log" +// "net/http" +// "os" +// +// "github.com/astaxie/beego/grace" +// ) +// +// func handler(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("WORLD!")) +// } +// +// func main() { +// mux := http.NewServeMux() +// mux.HandleFunc("/hello", handler) +// +// err := grace.ListenAndServe("localhost:8080", mux) +// if err != nil { +// log.Println(err) +// } +// log.Println("Server on 8080 stopped") +// os.Exit(0) +// } +package grace + +import ( + "flag" + "net/http" + "os" + "strings" + "sync" + "syscall" + "time" +) + +const ( + // PreSignal is the position to add filter before signal + PreSignal = iota + // PostSignal is the position to add filter after signal + PostSignal + // StateInit represent the application inited + StateInit + // StateRunning represent the application is running + StateRunning + // StateShuttingDown represent the application is shutting down + StateShuttingDown + // StateTerminate represent the application is killed + StateTerminate +) + +var ( + regLock *sync.Mutex + runningServers map[string]*Server + runningServersOrder []string + socketPtrOffsetMap map[string]uint + runningServersForked bool + + // DefaultReadTimeOut is the HTTP read timeout + DefaultReadTimeOut time.Duration + // DefaultWriteTimeOut is the HTTP Write timeout + DefaultWriteTimeOut time.Duration + // DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit + DefaultMaxHeaderBytes int + // DefaultTimeout is the shutdown server's timeout. default is 60s + DefaultTimeout = 60 * time.Second + + isChild bool + socketOrder string + + hookableSignals []os.Signal +) + +func init() { + flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)") + flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started") + + regLock = &sync.Mutex{} + runningServers = make(map[string]*Server) + runningServersOrder = []string{} + socketPtrOffsetMap = make(map[string]uint) + + hookableSignals = []os.Signal{ + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + } +} + +// NewServer returns a new graceServer. +func NewServer(addr string, handler http.Handler) (srv *Server) { + regLock.Lock() + defer regLock.Unlock() + + if !flag.Parsed() { + flag.Parse() + } + if len(socketOrder) > 0 { + for i, addr := range strings.Split(socketOrder, ",") { + socketPtrOffsetMap[addr] = uint(i) + } + } else { + socketPtrOffsetMap[addr] = uint(len(runningServersOrder)) + } + + srv = &Server{ + sigChan: make(chan os.Signal), + isChild: isChild, + SignalHooks: map[int]map[os.Signal][]func(){ + PreSignal: { + syscall.SIGHUP: {}, + syscall.SIGINT: {}, + syscall.SIGTERM: {}, + }, + PostSignal: { + syscall.SIGHUP: {}, + syscall.SIGINT: {}, + syscall.SIGTERM: {}, + }, + }, + state: StateInit, + Network: "tcp", + terminalChan: make(chan error), //no cache channel + } + srv.Server = &http.Server{ + Addr: addr, + ReadTimeout: DefaultReadTimeOut, + WriteTimeout: DefaultWriteTimeOut, + MaxHeaderBytes: DefaultMaxHeaderBytes, + Handler: handler, + } + + runningServersOrder = append(runningServersOrder, addr) + runningServers[addr] = srv + return srv +} + +// ListenAndServe refer http.ListenAndServe +func ListenAndServe(addr string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServe() +} + +// ListenAndServeTLS refer http.ListenAndServeTLS +func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServeTLS(certFile, keyFile) +} diff --git a/pkg/grace/server.go b/pkg/grace/server.go new file mode 100644 index 00000000..008a6171 --- /dev/null +++ b/pkg/grace/server.go @@ -0,0 +1,356 @@ +package grace + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + "os" + "os/exec" + "os/signal" + "strings" + "syscall" + "time" +) + +// Server embedded http.Server +type Server struct { + *http.Server + ln net.Listener + SignalHooks map[int]map[os.Signal][]func() + sigChan chan os.Signal + isChild bool + state uint8 + Network string + terminalChan chan error +} + +// Serve accepts incoming connections on the Listener l, +// creating a new service goroutine for each. +// The service goroutines read requests and then call srv.Handler to reply to them. +func (srv *Server) Serve() (err error) { + srv.state = StateRunning + defer func() { srv.state = StateTerminate }() + + // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS + // immediately return ErrServerClosed. Make sure the program doesn't exit + // and waits instead for Shutdown to return. + if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed { + log.Println(syscall.Getpid(), "Server.Serve() error:", err) + return err + } + + log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.") + // wait for Shutdown to return + if shutdownErr := <-srv.terminalChan; shutdownErr != nil { + return shutdownErr + } + return +} + +// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve +// to handle requests on incoming connections. If srv.Addr is blank, ":http" is +// used. +func (srv *Server) ListenAndServe() (err error) { + addr := srv.Addr + if addr == "" { + addr = ":http" + } + + go srv.handleSignals() + + srv.ln, err = srv.getListener(addr) + if err != nil { + log.Println(err) + return err + } + + if srv.isChild { + process, err := os.FindProcess(os.Getppid()) + if err != nil { + log.Println(err) + return err + } + err = process.Signal(syscall.SIGTERM) + if err != nil { + return err + } + } + + log.Println(os.Getpid(), srv.Addr) + return srv.Serve() +} + +// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming TLS connections. +// +// Filenames containing a certificate and matching private key for the server must +// be provided. If the certificate is signed by a certificate authority, the +// certFile should be the concatenation of the server's certificate followed by the +// CA's certificate. +// +// If srv.Addr is blank, ":https" is used. +func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + + if srv.TLSConfig == nil { + srv.TLSConfig = &tls.Config{} + } + if srv.TLSConfig.NextProtos == nil { + srv.TLSConfig.NextProtos = []string{"http/1.1"} + } + + srv.TLSConfig.Certificates = make([]tls.Certificate, 1) + srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return + } + + go srv.handleSignals() + + ln, err := srv.getListener(addr) + if err != nil { + log.Println(err) + return err + } + srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + + if srv.isChild { + process, err := os.FindProcess(os.Getppid()) + if err != nil { + log.Println(err) + return err + } + err = process.Signal(syscall.SIGTERM) + if err != nil { + return err + } + } + + log.Println(os.Getpid(), srv.Addr) + return srv.Serve() +} + +// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming mutual TLS connections. +func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + + if srv.TLSConfig == nil { + srv.TLSConfig = &tls.Config{} + } + if srv.TLSConfig.NextProtos == nil { + srv.TLSConfig.NextProtos = []string{"http/1.1"} + } + + srv.TLSConfig.Certificates = make([]tls.Certificate, 1) + srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return + } + srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + pool := x509.NewCertPool() + data, err := ioutil.ReadFile(trustFile) + if err != nil { + log.Println(err) + return err + } + pool.AppendCertsFromPEM(data) + srv.TLSConfig.ClientCAs = pool + log.Println("Mutual HTTPS") + go srv.handleSignals() + + ln, err := srv.getListener(addr) + if err != nil { + log.Println(err) + return err + } + srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + + if srv.isChild { + process, err := os.FindProcess(os.Getppid()) + if err != nil { + log.Println(err) + return err + } + err = process.Signal(syscall.SIGTERM) + if err != nil { + return err + } + } + + log.Println(os.Getpid(), srv.Addr) + return srv.Serve() +} + +// getListener either opens a new socket to listen on, or takes the acceptor socket +// it got passed when restarted. +func (srv *Server) getListener(laddr string) (l net.Listener, err error) { + if srv.isChild { + var ptrOffset uint + if len(socketPtrOffsetMap) > 0 { + ptrOffset = socketPtrOffsetMap[laddr] + log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr]) + } + + f := os.NewFile(uintptr(3+ptrOffset), "") + l, err = net.FileListener(f) + if err != nil { + err = fmt.Errorf("net.FileListener error: %v", err) + return + } + } else { + l, err = net.Listen(srv.Network, laddr) + if err != nil { + err = fmt.Errorf("net.Listen error: %v", err) + return + } + } + return +} + +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + +// handleSignals listens for os Signals and calls any hooked in function that the +// user had registered with the signal. +func (srv *Server) handleSignals() { + var sig os.Signal + + signal.Notify( + srv.sigChan, + hookableSignals..., + ) + + pid := syscall.Getpid() + for { + sig = <-srv.sigChan + srv.signalHooks(PreSignal, sig) + switch sig { + case syscall.SIGHUP: + log.Println(pid, "Received SIGHUP. forking.") + err := srv.fork() + if err != nil { + log.Println("Fork err:", err) + } + case syscall.SIGINT: + log.Println(pid, "Received SIGINT.") + srv.shutdown() + case syscall.SIGTERM: + log.Println(pid, "Received SIGTERM.") + srv.shutdown() + default: + log.Printf("Received %v: nothing i care about...\n", sig) + } + srv.signalHooks(PostSignal, sig) + } +} + +func (srv *Server) signalHooks(ppFlag int, sig os.Signal) { + if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet { + return + } + for _, f := range srv.SignalHooks[ppFlag][sig] { + f() + } +} + +// shutdown closes the listener so that no new connections are accepted. it also +// starts a goroutine that will serverTimeout (stop all running requests) the server +// after DefaultTimeout. +func (srv *Server) shutdown() { + if srv.state != StateRunning { + return + } + + srv.state = StateShuttingDown + log.Println(syscall.Getpid(), "Waiting for connections to finish...") + ctx := context.Background() + if DefaultTimeout >= 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) + defer cancel() + } + srv.terminalChan <- srv.Server.Shutdown(ctx) +} + +func (srv *Server) fork() (err error) { + regLock.Lock() + defer regLock.Unlock() + if runningServersForked { + return + } + runningServersForked = true + + var files = make([]*os.File, len(runningServers)) + var orderArgs = make([]string, len(runningServers)) + for _, srvPtr := range runningServers { + f, _ := srvPtr.ln.(*net.TCPListener).File() + files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f + orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr + } + + log.Println(files) + path := os.Args[0] + var args []string + if len(os.Args) > 1 { + for _, arg := range os.Args[1:] { + if arg == "-graceful" { + break + } + args = append(args, arg) + } + } + args = append(args, "-graceful") + if len(runningServers) > 1 { + args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ","))) + log.Println(args) + } + cmd := exec.Command(path, args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.ExtraFiles = files + err = cmd.Start() + if err != nil { + log.Fatalf("Restart: Failed to launch, error: %v", err) + } + + return +} + +// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. +func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) { + if ppFlag != PreSignal && ppFlag != PostSignal { + err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal") + return + } + for _, s := range hookableSignals { + if s == sig { + srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f) + return + } + } + err = fmt.Errorf("Signal '%v' is not supported", sig) + return +} diff --git a/pkg/hooks.go b/pkg/hooks.go new file mode 100644 index 00000000..49c42d5a --- /dev/null +++ b/pkg/hooks.go @@ -0,0 +1,104 @@ +package beego + +import ( + "encoding/json" + "mime" + "net/http" + "path/filepath" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/session" +) + +// register MIME type with content type +func registerMime() error { + for k, v := range mimemaps { + mime.AddExtensionType(k, v) + } + return nil +} + +// register default error http handlers, 404,401,403,500 and 503. +func registerDefaultErrorHandler() error { + m := map[string]func(http.ResponseWriter, *http.Request){ + "401": unauthorized, + "402": paymentRequired, + "403": forbidden, + "404": notFound, + "405": methodNotAllowed, + "500": internalServerError, + "501": notImplemented, + "502": badGateway, + "503": serviceUnavailable, + "504": gatewayTimeout, + "417": invalidxsrf, + "422": missingxsrf, + "413": payloadTooLarge, + } + for e, h := range m { + if _, ok := ErrorMaps[e]; !ok { + ErrorHandler(e, h) + } + } + return nil +} + +func registerSession() error { + if BConfig.WebConfig.Session.SessionOn { + var err error + sessionConfig := AppConfig.String("sessionConfig") + conf := new(session.ManagerConfig) + if sessionConfig == "" { + conf.CookieName = BConfig.WebConfig.Session.SessionName + conf.EnableSetCookie = BConfig.WebConfig.Session.SessionAutoSetCookie + conf.Gclifetime = BConfig.WebConfig.Session.SessionGCMaxLifetime + conf.Secure = BConfig.Listen.EnableHTTPS + conf.CookieLifeTime = BConfig.WebConfig.Session.SessionCookieLifeTime + conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig) + conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly + conf.Domain = BConfig.WebConfig.Session.SessionDomain + conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader + conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader + conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery + } else { + if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil { + return err + } + } + if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, conf); err != nil { + return err + } + go GlobalSessions.GC() + } + return nil +} + +func registerTemplate() error { + defer lockViewPaths() + if err := AddViewPath(BConfig.WebConfig.ViewsPath); err != nil { + if BConfig.RunMode == DEV { + logs.Warn(err) + } + return err + } + return nil +} + +func registerAdmin() error { + if BConfig.Listen.EnableAdmin { + go beeAdminApp.Run() + } + return nil +} + +func registerGzip() error { + if BConfig.EnableGzip { + context.InitGzip( + AppConfig.DefaultInt("gzipMinLength", -1), + AppConfig.DefaultInt("gzipCompressLevel", -1), + AppConfig.DefaultStrings("includedMethods", []string{"GET"}), + ) + } + return nil +} diff --git a/pkg/httplib/README.md b/pkg/httplib/README.md new file mode 100644 index 00000000..97df8e6b --- /dev/null +++ b/pkg/httplib/README.md @@ -0,0 +1,97 @@ +# httplib +httplib is an libs help you to curl remote url. + +# How to use? + +## GET +you can use Get to crawl data. + + import "github.com/astaxie/beego/httplib" + + str, err := httplib.Get("http://beego.me/").String() + if err != nil { + // error + } + fmt.Println(str) + +## POST +POST data to remote url + + req := httplib.Post("http://beego.me/") + req.Param("username","astaxie") + req.Param("password","123456") + str, err := req.String() + if err != nil { + // error + } + fmt.Println(str) + +## Set timeout + +The default timeout is `60` seconds, function prototype: + + SetTimeout(connectTimeout, readWriteTimeout time.Duration) + +Example: + + // GET + httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) + + // POST + httplib.Post("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) + + +## Debug + +If you want to debug the request info, set the debug on + + httplib.Get("http://beego.me/").Debug(true) + +## Set HTTP Basic Auth + + str, err := Get("http://beego.me/").SetBasicAuth("user", "passwd").String() + if err != nil { + // error + } + fmt.Println(str) + +## Set HTTPS + +If request url is https, You can set the client support TSL: + + httplib.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) + +More info about the `tls.Config` please visit http://golang.org/pkg/crypto/tls/#Config + +## Set HTTP Version + +some servers need to specify the protocol version of HTTP + + httplib.Get("http://beego.me/").SetProtocolVersion("HTTP/1.1") + +## Set Cookie + +some http request need setcookie. So set it like this: + + cookie := &http.Cookie{} + cookie.Name = "username" + cookie.Value = "astaxie" + httplib.Get("http://beego.me/").SetCookie(cookie) + +## Upload file + +httplib support mutil file upload, use `req.PostFile()` + + req := httplib.Post("http://beego.me/") + req.Param("username","astaxie") + req.PostFile("uploadfile1", "httplib.pdf") + str, err := req.String() + if err != nil { + // error + } + fmt.Println(str) + + +See godoc for further documentation and examples. + +* [godoc.org/github.com/astaxie/beego/httplib](https://godoc.org/github.com/astaxie/beego/httplib) diff --git a/pkg/httplib/httplib.go b/pkg/httplib/httplib.go new file mode 100644 index 00000000..60aa4e8b --- /dev/null +++ b/pkg/httplib/httplib.go @@ -0,0 +1,654 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package httplib is used as http.Client +// Usage: +// +// import "github.com/astaxie/beego/httplib" +// +// b := httplib.Post("http://beego.me/") +// b.Param("username","astaxie") +// b.Param("password","123456") +// b.PostFile("uploadfile1", "httplib.pdf") +// b.PostFile("uploadfile2", "httplib.txt") +// str, err := b.String() +// if err != nil { +// t.Fatal(err) +// } +// fmt.Println(str) +// +// more docs http://beego.me/docs/module/httplib.md +package httplib + +import ( + "bytes" + "compress/gzip" + "crypto/tls" + "encoding/json" + "encoding/xml" + "io" + "io/ioutil" + "log" + "mime/multipart" + "net" + "net/http" + "net/http/cookiejar" + "net/http/httputil" + "net/url" + "os" + "path" + "strings" + "sync" + "time" + + "gopkg.in/yaml.v2" +) + +var defaultSetting = BeegoHTTPSettings{ + UserAgent: "beegoServer", + ConnectTimeout: 60 * time.Second, + ReadWriteTimeout: 60 * time.Second, + Gzip: true, + DumpBody: true, +} + +var defaultCookieJar http.CookieJar +var settingMutex sync.Mutex + +// createDefaultCookie creates a global cookiejar to store cookies. +func createDefaultCookie() { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultCookieJar, _ = cookiejar.New(nil) +} + +// SetDefaultSetting Overwrite default settings +func SetDefaultSetting(setting BeegoHTTPSettings) { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultSetting = setting +} + +// NewBeegoRequest return *BeegoHttpRequest with specific method +func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { + var resp http.Response + u, err := url.Parse(rawurl) + if err != nil { + log.Println("Httplib:", err) + } + req := http.Request{ + URL: u, + Method: method, + Header: make(http.Header), + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + } + return &BeegoHTTPRequest{ + url: rawurl, + req: &req, + params: map[string][]string{}, + files: map[string]string{}, + setting: defaultSetting, + resp: &resp, + } +} + +// Get returns *BeegoHttpRequest with GET method. +func Get(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "GET") +} + +// Post returns *BeegoHttpRequest with POST method. +func Post(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "POST") +} + +// Put returns *BeegoHttpRequest with PUT method. +func Put(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "PUT") +} + +// Delete returns *BeegoHttpRequest DELETE method. +func Delete(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "DELETE") +} + +// Head returns *BeegoHttpRequest with HEAD method. +func Head(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "HEAD") +} + +// BeegoHTTPSettings is the http.Client setting +type BeegoHTTPSettings struct { + ShowDebug bool + UserAgent string + ConnectTimeout time.Duration + ReadWriteTimeout time.Duration + TLSClientConfig *tls.Config + Proxy func(*http.Request) (*url.URL, error) + Transport http.RoundTripper + CheckRedirect func(req *http.Request, via []*http.Request) error + EnableCookie bool + Gzip bool + DumpBody bool + Retries int // if set to -1 means will retry forever + RetryDelay time.Duration +} + +// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. +type BeegoHTTPRequest struct { + url string + req *http.Request + params map[string][]string + files map[string]string + setting BeegoHTTPSettings + resp *http.Response + body []byte + dump []byte +} + +// GetRequest return the request object +func (b *BeegoHTTPRequest) GetRequest() *http.Request { + return b.req +} + +// Setting Change request settings +func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest { + b.setting = setting + return b +} + +// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. +func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest { + b.req.SetBasicAuth(username, password) + return b +} + +// SetEnableCookie sets enable/disable cookiejar +func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest { + b.setting.EnableCookie = enable + return b +} + +// SetUserAgent sets User-Agent header field +func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { + b.setting.UserAgent = useragent + return b +} + +// Debug sets show debug or not when executing request. +func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { + b.setting.ShowDebug = isdebug + return b +} + +// Retries sets Retries times. +// default is 0 means no retried. +// -1 means retried forever. +// others means retried times. +func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { + b.setting.Retries = times + return b +} + +func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { + b.setting.RetryDelay = delay + return b +} + +// DumpBody setting whether need to Dump the Body. +func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { + b.setting.DumpBody = isdump + return b +} + +// DumpRequest return the DumpRequest +func (b *BeegoHTTPRequest) DumpRequest() []byte { + return b.dump +} + +// SetTimeout sets connect time out and read-write time out for BeegoRequest. +func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { + b.setting.ConnectTimeout = connectTimeout + b.setting.ReadWriteTimeout = readWriteTimeout + return b +} + +// SetTLSClientConfig sets tls connection configurations if visiting https url. +func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest { + b.setting.TLSClientConfig = config + return b +} + +// Header add header item string in request. +func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest { + b.req.Header.Set(key, value) + return b +} + +// SetHost set the request host +func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { + b.req.Host = host + return b +} + +// SetProtocolVersion Set the protocol version for incoming requests. +// Client requests always use HTTP/1.1. +func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { + if len(vers) == 0 { + vers = "HTTP/1.1" + } + + major, minor, ok := http.ParseHTTPVersion(vers) + if ok { + b.req.Proto = vers + b.req.ProtoMajor = major + b.req.ProtoMinor = minor + } + + return b +} + +// SetCookie add cookie into request. +func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest { + b.req.Header.Add("Cookie", cookie.String()) + return b +} + +// SetTransport set the setting transport +func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest { + b.setting.Transport = transport + return b +} + +// SetProxy set the http proxy +// example: +// +// func(req *http.Request) (*url.URL, error) { +// u, _ := url.ParseRequestURI("http://127.0.0.1:8118") +// return u, nil +// } +func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest { + b.setting.Proxy = proxy + return b +} + +// SetCheckRedirect specifies the policy for handling redirects. +// +// If CheckRedirect is nil, the Client uses its default policy, +// which is to stop after 10 consecutive requests. +func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest { + b.setting.CheckRedirect = redirect + return b +} + +// Param adds query param in to request. +// params build query string as ?key1=value1&key2=value2... +func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { + if param, ok := b.params[key]; ok { + b.params[key] = append(param, value) + } else { + b.params[key] = []string{value} + } + return b +} + +// PostFile add a post file to the request +func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest { + b.files[formname] = filename + return b +} + +// Body adds request raw body. +// it supports string and []byte. +func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { + switch t := data.(type) { + case string: + bf := bytes.NewBufferString(t) + b.req.Body = ioutil.NopCloser(bf) + b.req.ContentLength = int64(len(t)) + case []byte: + bf := bytes.NewBuffer(t) + b.req.Body = ioutil.NopCloser(bf) + b.req.ContentLength = int64(len(t)) + } + return b +} + +// XMLBody adds request raw body encoding by XML. +func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + if b.req.Body == nil && obj != nil { + byts, err := xml.Marshal(obj) + if err != nil { + return b, err + } + b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.ContentLength = int64(len(byts)) + b.req.Header.Set("Content-Type", "application/xml") + } + return b, nil +} + +// YAMLBody adds request raw body encoding by YAML. +func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + if b.req.Body == nil && obj != nil { + byts, err := yaml.Marshal(obj) + if err != nil { + return b, err + } + b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.ContentLength = int64(len(byts)) + b.req.Header.Set("Content-Type", "application/x+yaml") + } + return b, nil +} + +// JSONBody adds request raw body encoding by JSON. +func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { + if b.req.Body == nil && obj != nil { + byts, err := json.Marshal(obj) + if err != nil { + return b, err + } + b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.ContentLength = int64(len(byts)) + b.req.Header.Set("Content-Type", "application/json") + } + return b, nil +} + +func (b *BeegoHTTPRequest) buildURL(paramBody string) { + // build GET url with query string + if b.req.Method == "GET" && len(paramBody) > 0 { + if strings.Contains(b.url, "?") { + b.url += "&" + paramBody + } else { + b.url = b.url + "?" + paramBody + } + return + } + + // build POST/PUT/PATCH url and body + if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil { + // with files + if len(b.files) > 0 { + pr, pw := io.Pipe() + bodyWriter := multipart.NewWriter(pw) + go func() { + for formname, filename := range b.files { + fileWriter, err := bodyWriter.CreateFormFile(formname, filename) + if err != nil { + log.Println("Httplib:", err) + } + fh, err := os.Open(filename) + if err != nil { + log.Println("Httplib:", err) + } + //iocopy + _, err = io.Copy(fileWriter, fh) + fh.Close() + if err != nil { + log.Println("Httplib:", err) + } + } + for k, v := range b.params { + for _, vv := range v { + bodyWriter.WriteField(k, vv) + } + } + bodyWriter.Close() + pw.Close() + }() + b.Header("Content-Type", bodyWriter.FormDataContentType()) + b.req.Body = ioutil.NopCloser(pr) + b.Header("Transfer-Encoding", "chunked") + return + } + + // with params + if len(paramBody) > 0 { + b.Header("Content-Type", "application/x-www-form-urlencoded") + b.Body(paramBody) + } + } +} + +func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { + if b.resp.StatusCode != 0 { + return b.resp, nil + } + resp, err := b.DoRequest() + if err != nil { + return nil, err + } + b.resp = resp + return resp, nil +} + +// DoRequest will do the client.Do +func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { + var paramBody string + if len(b.params) > 0 { + var buf bytes.Buffer + for k, v := range b.params { + for _, vv := range v { + buf.WriteString(url.QueryEscape(k)) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(vv)) + buf.WriteByte('&') + } + } + paramBody = buf.String() + paramBody = paramBody[0 : len(paramBody)-1] + } + + b.buildURL(paramBody) + urlParsed, err := url.Parse(b.url) + if err != nil { + return nil, err + } + + b.req.URL = urlParsed + + trans := b.setting.Transport + + if trans == nil { + // create default transport + trans = &http.Transport{ + TLSClientConfig: b.setting.TLSClientConfig, + Proxy: b.setting.Proxy, + Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), + MaxIdleConnsPerHost: 100, + } + } else { + // if b.transport is *http.Transport then set the settings. + if t, ok := trans.(*http.Transport); ok { + if t.TLSClientConfig == nil { + t.TLSClientConfig = b.setting.TLSClientConfig + } + if t.Proxy == nil { + t.Proxy = b.setting.Proxy + } + if t.Dial == nil { + t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) + } + } + } + + var jar http.CookieJar + if b.setting.EnableCookie { + if defaultCookieJar == nil { + createDefaultCookie() + } + jar = defaultCookieJar + } + + client := &http.Client{ + Transport: trans, + Jar: jar, + } + + if b.setting.UserAgent != "" && b.req.Header.Get("User-Agent") == "" { + b.req.Header.Set("User-Agent", b.setting.UserAgent) + } + + if b.setting.CheckRedirect != nil { + client.CheckRedirect = b.setting.CheckRedirect + } + + if b.setting.ShowDebug { + dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody) + if err != nil { + log.Println(err.Error()) + } + b.dump = dump + } + // retries default value is 0, it will run once. + // retries equal to -1, it will run forever until success + // retries is setted, it will retries fixed times. + // Sleeps for a 400ms inbetween calls to reduce spam + for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ { + resp, err = client.Do(b.req) + if err == nil { + break + } + time.Sleep(b.setting.RetryDelay) + } + return resp, err +} + +// String returns the body string in response. +// it calls Response inner. +func (b *BeegoHTTPRequest) String() (string, error) { + data, err := b.Bytes() + if err != nil { + return "", err + } + + return string(data), nil +} + +// Bytes returns the body []byte in response. +// it calls Response inner. +func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { + if b.body != nil { + return b.body, nil + } + resp, err := b.getResponse() + if err != nil { + return nil, err + } + if resp.Body == nil { + return nil, nil + } + defer resp.Body.Close() + if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" { + reader, err := gzip.NewReader(resp.Body) + if err != nil { + return nil, err + } + b.body, err = ioutil.ReadAll(reader) + return b.body, err + } + b.body, err = ioutil.ReadAll(resp.Body) + return b.body, err +} + +// ToFile saves the body data in response to one file. +// it calls Response inner. +func (b *BeegoHTTPRequest) ToFile(filename string) error { + resp, err := b.getResponse() + if err != nil { + return err + } + if resp.Body == nil { + return nil + } + defer resp.Body.Close() + err = pathExistAndMkdir(filename) + if err != nil { + return err + } + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + _, err = io.Copy(f, resp.Body) + return err +} + +//Check that the file directory exists, there is no automatically created +func pathExistAndMkdir(filename string) (err error) { + filename = path.Dir(filename) + _, err = os.Stat(filename) + if err == nil { + return nil + } + if os.IsNotExist(err) { + err = os.MkdirAll(filename, os.ModePerm) + if err == nil { + return nil + } + } + return err +} + +// ToJSON returns the map that marshals from the body bytes as json in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { + data, err := b.Bytes() + if err != nil { + return err + } + return json.Unmarshal(data, v) +} + +// ToXML returns the map that marshals from the body bytes as xml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToXML(v interface{}) error { + data, err := b.Bytes() + if err != nil { + return err + } + return xml.Unmarshal(data, v) +} + +// ToYAML returns the map that marshals from the body bytes as yaml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { + data, err := b.Bytes() + if err != nil { + return err + } + return yaml.Unmarshal(data, v) +} + +// Response executes request client gets response mannually. +func (b *BeegoHTTPRequest) Response() (*http.Response, error) { + return b.getResponse() +} + +// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. +func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { + return func(netw, addr string) (net.Conn, error) { + conn, err := net.DialTimeout(netw, addr, cTimeout) + if err != nil { + return nil, err + } + err = conn.SetDeadline(time.Now().Add(rwTimeout)) + return conn, err + } +} diff --git a/pkg/httplib/httplib_test.go b/pkg/httplib/httplib_test.go new file mode 100644 index 00000000..f6be8571 --- /dev/null +++ b/pkg/httplib/httplib_test.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "errors" + "io/ioutil" + "net" + "net/http" + "os" + "strings" + "testing" + "time" +) + +func TestResponse(t *testing.T) { + req := Get("http://httpbin.org/get") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) +} + +func TestDoRequest(t *testing.T) { + req := Get("https://goolnk.com/33BD2j") + retryAmount := 1 + req.Retries(1) + req.RetryDelay(1400 * time.Millisecond) + retryDelay := 1400 * time.Millisecond + + req.setting.CheckRedirect = func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + } + + startTime := time.Now().UnixNano() / int64(time.Millisecond) + + _, err := req.Response() + if err == nil { + t.Fatal("Response should have yielded an error") + } + + endTime := time.Now().UnixNano() / int64(time.Millisecond) + elapsedTime := endTime - startTime + delayedTime := int64(retryAmount) * retryDelay.Milliseconds() + + if elapsedTime < delayedTime { + t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) + } + +} + +func TestGet(t *testing.T) { + req := Get("http://httpbin.org/get") + b, err := req.Bytes() + if err != nil { + t.Fatal(err) + } + t.Log(b) + + s, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(s) + + if string(b) != s { + t.Fatal("request data not match") + } +} + +func TestSimplePost(t *testing.T) { + v := "smallfish" + req := Post("http://httpbin.org/post") + req.Param("username", v) + + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in post") + } +} + +//func TestPostFile(t *testing.T) { +// v := "smallfish" +// req := Post("http://httpbin.org/post") +// req.Debug(true) +// req.Param("username", v) +// req.PostFile("uploadfile", "httplib_test.go") + +// str, err := req.String() +// if err != nil { +// t.Fatal(err) +// } +// t.Log(str) + +// n := strings.Index(str, v) +// if n == -1 { +// t.Fatal(v + " not found in post") +// } +//} + +func TestSimplePut(t *testing.T) { + str, err := Put("http://httpbin.org/put").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDelete(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDeleteParam(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").Param("key", "val").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestWithCookie(t *testing.T) { + v := "smallfish" + str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + str, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in cookie") + } +} + +func TestWithBasicAuth(t *testing.T) { + str, err := Get("http://httpbin.org/basic-auth/user/passwd").SetBasicAuth("user", "passwd").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + n := strings.Index(str, "authenticated") + if n == -1 { + t.Fatal("authenticated not found in response") + } +} + +func TestWithUserAgent(t *testing.T) { + v := "beego" + str, err := Get("http://httpbin.org/headers").SetUserAgent(v).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestWithSetting(t *testing.T) { + v := "beego" + var setting BeegoHTTPSettings + setting.EnableCookie = true + setting.UserAgent = v + setting.Transport = &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 50, + IdleConnTimeout: 90 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + setting.ReadWriteTimeout = 5 * time.Second + SetDefaultSetting(setting) + + str, err := Get("http://httpbin.org/get").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestToJson(t *testing.T) { + req := Get("http://httpbin.org/ip") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) + + // httpbin will return http remote addr + type IP struct { + Origin string `json:"origin"` + } + var ip IP + err = req.ToJSON(&ip) + if err != nil { + t.Fatal(err) + } + t.Log(ip.Origin) + ips := strings.Split(ip.Origin, ",") + if len(ips) == 0 { + t.Fatal("response is not valid ip") + } + for i := range ips { + if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil { + t.Fatal("response is not valid ip") + } + } + +} + +func TestToFile(t *testing.T) { + f := "beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.Remove(f) + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestToFileDir(t *testing.T) { + f := "./files/beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll("./files") + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestHeader(t *testing.T) { + req := Get("http://httpbin.org/headers") + req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36") + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} diff --git a/pkg/log.go b/pkg/log.go new file mode 100644 index 00000000..cc4c0f81 --- /dev/null +++ b/pkg/log.go @@ -0,0 +1,127 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "strings" + + "github.com/astaxie/beego/logs" +) + +// Log levels to control the logging output. +// Deprecated: use github.com/astaxie/beego/logs instead. +const ( + LevelEmergency = iota + LevelAlert + LevelCritical + LevelError + LevelWarning + LevelNotice + LevelInformational + LevelDebug +) + +// BeeLogger references the used application logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +var BeeLogger = logs.GetBeeLogger() + +// SetLevel sets the global log level used by the simple logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLevel(l int) { + logs.SetLevel(l) +} + +// SetLogFuncCall set the CallDepth, default is 3 +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLogFuncCall(b bool) { + logs.SetLogFuncCall(b) +} + +// SetLogger sets a new logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLogger(adaptername string, config string) error { + return logs.SetLogger(adaptername, config) +} + +// Emergency logs a message at emergency level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Emergency(v ...interface{}) { + logs.Emergency(generateFmtStr(len(v)), v...) +} + +// Alert logs a message at alert level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Alert(v ...interface{}) { + logs.Alert(generateFmtStr(len(v)), v...) +} + +// Critical logs a message at critical level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Critical(v ...interface{}) { + logs.Critical(generateFmtStr(len(v)), v...) +} + +// Error logs a message at error level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Error(v ...interface{}) { + logs.Error(generateFmtStr(len(v)), v...) +} + +// Warning logs a message at warning level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Warning(v ...interface{}) { + logs.Warning(generateFmtStr(len(v)), v...) +} + +// Warn compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Warn(v ...interface{}) { + logs.Warn(generateFmtStr(len(v)), v...) +} + +// Notice logs a message at notice level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Notice(v ...interface{}) { + logs.Notice(generateFmtStr(len(v)), v...) +} + +// Informational logs a message at info level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Informational(v ...interface{}) { + logs.Informational(generateFmtStr(len(v)), v...) +} + +// Info compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Info(v ...interface{}) { + logs.Info(generateFmtStr(len(v)), v...) +} + +// Debug logs a message at debug level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Debug(v ...interface{}) { + logs.Debug(generateFmtStr(len(v)), v...) +} + +// Trace logs a message at trace level. +// compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Trace(v ...interface{}) { + logs.Trace(generateFmtStr(len(v)), v...) +} + +func generateFmtStr(n int) string { + return strings.Repeat("%v ", n) +} diff --git a/pkg/logs/README.md b/pkg/logs/README.md new file mode 100644 index 00000000..c05bcc04 --- /dev/null +++ b/pkg/logs/README.md @@ -0,0 +1,72 @@ +## logs +logs is a Go logs manager. It can use many logs adapters. The repo is inspired by `database/sql` . + + +## How to install? + + go get github.com/astaxie/beego/logs + + +## What adapters are supported? + +As of now this logs support console, file,smtp and conn. + + +## How to use it? + +First you must import it + +```golang +import ( + "github.com/astaxie/beego/logs" +) +``` + +Then init a Log (example with console adapter) + +```golang +log := logs.NewLogger(10000) +log.SetLogger("console", "") +``` + +> the first params stand for how many channel + +Use it like this: + +```golang +log.Trace("trace") +log.Info("info") +log.Warn("warning") +log.Debug("debug") +log.Critical("critical") +``` + +## File adapter + +Configure file adapter like this: + +```golang +log := NewLogger(10000) +log.SetLogger("file", `{"filename":"test.log"}`) +``` + +## Conn adapter + +Configure like this: + +```golang +log := NewLogger(1000) +log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) +log.Info("info") +``` + +## Smtp adapter + +Configure like this: + +```golang +log := NewLogger(10000) +log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`) +log.Critical("sendmail critical") +time.Sleep(time.Second * 30) +``` diff --git a/pkg/logs/accesslog.go b/pkg/logs/accesslog.go new file mode 100644 index 00000000..3ff9e20f --- /dev/null +++ b/pkg/logs/accesslog.go @@ -0,0 +1,83 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "bytes" + "strings" + "encoding/json" + "fmt" + "time" +) + +const ( + apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s" + apacheFormat = "APACHE_FORMAT" + jsonFormat = "JSON_FORMAT" +) + +// AccessLogRecord struct for holding access log data. +type AccessLogRecord struct { + RemoteAddr string `json:"remote_addr"` + RequestTime time.Time `json:"request_time"` + RequestMethod string `json:"request_method"` + Request string `json:"request"` + ServerProtocol string `json:"server_protocol"` + Host string `json:"host"` + Status int `json:"status"` + BodyBytesSent int64 `json:"body_bytes_sent"` + ElapsedTime time.Duration `json:"elapsed_time"` + HTTPReferrer string `json:"http_referrer"` + HTTPUserAgent string `json:"http_user_agent"` + RemoteUser string `json:"remote_user"` +} + +func (r *AccessLogRecord) json() ([]byte, error) { + buffer := &bytes.Buffer{} + encoder := json.NewEncoder(buffer) + disableEscapeHTML(encoder) + + err := encoder.Encode(r) + return buffer.Bytes(), err +} + +func disableEscapeHTML(i interface{}) { + if e, ok := i.(interface { + SetEscapeHTML(bool) + }); ok { + e.SetEscapeHTML(false) + } +} + +// AccessLog - Format and print access log. +func AccessLog(r *AccessLogRecord, format string) { + var msg string + switch format { + case apacheFormat: + timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05") + msg = fmt.Sprintf(apacheFormatPattern, r.RemoteAddr, timeFormatted, r.Request, r.Status, r.BodyBytesSent, + r.ElapsedTime.Seconds(), r.HTTPReferrer, r.HTTPUserAgent) + case jsonFormat: + fallthrough + default: + jsonData, err := r.json() + if err != nil { + msg = fmt.Sprintf(`{"Error": "%s"}`, err) + } else { + msg = string(jsonData) + } + } + beeLogger.writeMsg(levelLoggerImpl, strings.TrimSpace(msg)) +} diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go new file mode 100644 index 00000000..867ff4cb --- /dev/null +++ b/pkg/logs/alils/alils.go @@ -0,0 +1,186 @@ +package alils + +import ( + "encoding/json" + "strings" + "sync" + "time" + + "github.com/astaxie/beego/logs" + "github.com/gogo/protobuf/proto" +) + +const ( + // CacheSize set the flush size + CacheSize int = 64 + // Delimiter define the topic delimiter + Delimiter string = "##" +) + +// Config is the Config for Ali Log +type Config struct { + Project string `json:"project"` + Endpoint string `json:"endpoint"` + KeyID string `json:"key_id"` + KeySecret string `json:"key_secret"` + LogStore string `json:"log_store"` + Topics []string `json:"topics"` + Source string `json:"source"` + Level int `json:"level"` + FlushWhen int `json:"flush_when"` +} + +// aliLSWriter implements LoggerInterface. +// it writes messages in keep-live tcp connection. +type aliLSWriter struct { + store *LogStore + group []*LogGroup + withMap bool + groupMap map[string]*LogGroup + lock *sync.Mutex + Config +} + +// NewAliLS create a new Logger +func NewAliLS() logs.Logger { + alils := new(aliLSWriter) + alils.Level = logs.LevelTrace + return alils +} + +// Init parse config and init struct +func (c *aliLSWriter) Init(jsonConfig string) (err error) { + + json.Unmarshal([]byte(jsonConfig), c) + + if c.FlushWhen > CacheSize { + c.FlushWhen = CacheSize + } + + prj := &LogProject{ + Name: c.Project, + Endpoint: c.Endpoint, + AccessKeyID: c.KeyID, + AccessKeySecret: c.KeySecret, + } + + c.store, err = prj.GetLogStore(c.LogStore) + if err != nil { + return err + } + + // Create default Log Group + c.group = append(c.group, &LogGroup{ + Topic: proto.String(""), + Source: proto.String(c.Source), + Logs: make([]*Log, 0, c.FlushWhen), + }) + + // Create other Log Group + c.groupMap = make(map[string]*LogGroup) + for _, topic := range c.Topics { + + lg := &LogGroup{ + Topic: proto.String(topic), + Source: proto.String(c.Source), + Logs: make([]*Log, 0, c.FlushWhen), + } + + c.group = append(c.group, lg) + c.groupMap[topic] = lg + } + + if len(c.group) == 1 { + c.withMap = false + } else { + c.withMap = true + } + + c.lock = &sync.Mutex{} + + return nil +} + +// WriteMsg write message in connection. +// if connection is down, try to re-connect. +func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) { + + if level > c.Level { + return nil + } + + var topic string + var content string + var lg *LogGroup + if c.withMap { + + // Topic,LogGroup + strs := strings.SplitN(msg, Delimiter, 2) + if len(strs) == 2 { + pos := strings.LastIndex(strs[0], " ") + topic = strs[0][pos+1 : len(strs[0])] + content = strs[0][0:pos] + strs[1] + lg = c.groupMap[topic] + } + + // send to empty Topic + if lg == nil { + content = msg + lg = c.group[0] + } + } else { + content = msg + lg = c.group[0] + } + + c1 := &LogContent{ + Key: proto.String("msg"), + Value: proto.String(content), + } + + l := &Log{ + Time: proto.Uint32(uint32(when.Unix())), + Contents: []*LogContent{ + c1, + }, + } + + c.lock.Lock() + lg.Logs = append(lg.Logs, l) + c.lock.Unlock() + + if len(lg.Logs) >= c.FlushWhen { + c.flush(lg) + } + + return nil +} + +// Flush implementing method. empty. +func (c *aliLSWriter) Flush() { + + // flush all group + for _, lg := range c.group { + c.flush(lg) + } +} + +// Destroy destroy connection writer and close tcp listener. +func (c *aliLSWriter) Destroy() { +} + +func (c *aliLSWriter) flush(lg *LogGroup) { + + c.lock.Lock() + defer c.lock.Unlock() + err := c.store.PutLogs(lg) + if err != nil { + return + } + + lg.Logs = make([]*Log, 0, c.FlushWhen) +} + +func init() { + logs.Register(logs.AdapterAliLS, NewAliLS) +} diff --git a/pkg/logs/alils/config.go b/pkg/logs/alils/config.go new file mode 100755 index 00000000..e8c24448 --- /dev/null +++ b/pkg/logs/alils/config.go @@ -0,0 +1,13 @@ +package alils + +const ( + version = "0.5.0" // SDK version + signatureMethod = "hmac-sha1" // Signature method + + // OffsetNewest stands for the log head offset, i.e. the offset that will be + // assigned to the next message that will be produced to the shard. + OffsetNewest = "end" + // OffsetOldest stands for the oldest offset available on the logstore for a + // shard. + OffsetOldest = "begin" +) diff --git a/pkg/logs/alils/log.pb.go b/pkg/logs/alils/log.pb.go new file mode 100755 index 00000000..601b0d78 --- /dev/null +++ b/pkg/logs/alils/log.pb.go @@ -0,0 +1,1038 @@ +package alils + +import ( + "fmt" + "io" + "math" + + "github.com/gogo/protobuf/proto" + github_com_gogo_protobuf_proto "github.com/gogo/protobuf/proto" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +var ( + // ErrInvalidLengthLog invalid proto + ErrInvalidLengthLog = fmt.Errorf("proto: negative length found during unmarshaling") + // ErrIntOverflowLog overflow + ErrIntOverflowLog = fmt.Errorf("proto: integer overflow") +) + +// Log define the proto Log +type Log struct { + Time *uint32 `protobuf:"varint,1,req,name=Time" json:"Time,omitempty"` + Contents []*LogContent `protobuf:"bytes,2,rep,name=Contents" json:"Contents,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset the Log +func (m *Log) Reset() { *m = Log{} } + +// String return the Compact Log +func (m *Log) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*Log) ProtoMessage() {} + +// GetTime return the Log's Time +func (m *Log) GetTime() uint32 { + if m != nil && m.Time != nil { + return *m.Time + } + return 0 +} + +// GetContents return the Log's Contents +func (m *Log) GetContents() []*LogContent { + if m != nil { + return m.Contents + } + return nil +} + +// LogContent define the Log content struct +type LogContent struct { + Key *string `protobuf:"bytes,1,req,name=Key" json:"Key,omitempty"` + Value *string `protobuf:"bytes,2,req,name=Value" json:"Value,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset LogContent +func (m *LogContent) Reset() { *m = LogContent{} } + +// String return the compact text +func (m *LogContent) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogContent) ProtoMessage() {} + +// GetKey return the Key +func (m *LogContent) GetKey() string { + if m != nil && m.Key != nil { + return *m.Key + } + return "" +} + +// GetValue return the Value +func (m *LogContent) GetValue() string { + if m != nil && m.Value != nil { + return *m.Value + } + return "" +} + +// LogGroup define the logs struct +type LogGroup struct { + Logs []*Log `protobuf:"bytes,1,rep,name=Logs" json:"Logs,omitempty"` + Reserved *string `protobuf:"bytes,2,opt,name=Reserved" json:"Reserved,omitempty"` + Topic *string `protobuf:"bytes,3,opt,name=Topic" json:"Topic,omitempty"` + Source *string `protobuf:"bytes,4,opt,name=Source" json:"Source,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset LogGroup +func (m *LogGroup) Reset() { *m = LogGroup{} } + +// String return the compact text +func (m *LogGroup) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogGroup) ProtoMessage() {} + +// GetLogs return the loggroup logs +func (m *LogGroup) GetLogs() []*Log { + if m != nil { + return m.Logs + } + return nil +} + +// GetReserved return Reserved +func (m *LogGroup) GetReserved() string { + if m != nil && m.Reserved != nil { + return *m.Reserved + } + return "" +} + +// GetTopic return Topic +func (m *LogGroup) GetTopic() string { + if m != nil && m.Topic != nil { + return *m.Topic + } + return "" +} + +// GetSource return Source +func (m *LogGroup) GetSource() string { + if m != nil && m.Source != nil { + return *m.Source + } + return "" +} + +// LogGroupList define the LogGroups +type LogGroupList struct { + LogGroups []*LogGroup `protobuf:"bytes,1,rep,name=logGroups" json:"logGroups,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset LogGroupList +func (m *LogGroupList) Reset() { *m = LogGroupList{} } + +// String return compact text +func (m *LogGroupList) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogGroupList) ProtoMessage() {} + +// GetLogGroups return the LogGroups +func (m *LogGroupList) GetLogGroups() []*LogGroup { + if m != nil { + return m.LogGroups + } + return nil +} + +// Marshal the logs to byte slice +func (m *Log) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo data +func (m *Log) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.Time == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") + } + data[i] = 0x8 + i++ + i = encodeVarintLog(data, i, uint64(*m.Time)) + if len(m.Contents) > 0 { + for _, msg := range m.Contents { + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +// Marshal LogContent +func (m *LogContent) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo logcontent to data +func (m *LogContent) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.Key == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") + } + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Key))) + i += copy(data[i:], *m.Key) + + if m.Value == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") + } + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Value))) + i += copy(data[i:], *m.Value) + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +// Marshal LogGroup +func (m *LogGroup) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo LogGroup to data +func (m *LogGroup) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Logs) > 0 { + for _, msg := range m.Logs { + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.Reserved != nil { + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Reserved))) + i += copy(data[i:], *m.Reserved) + } + if m.Topic != nil { + data[i] = 0x1a + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Topic))) + i += copy(data[i:], *m.Topic) + } + if m.Source != nil { + data[i] = 0x22 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Source))) + i += copy(data[i:], *m.Source) + } + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +// Marshal LogGroupList +func (m *LogGroupList) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo LogGroupList to data +func (m *LogGroupList) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.LogGroups) > 0 { + for _, msg := range m.LogGroups { + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +func encodeFixed64Log(data []byte, offset int, v uint64) int { + data[offset] = uint8(v) + data[offset+1] = uint8(v >> 8) + data[offset+2] = uint8(v >> 16) + data[offset+3] = uint8(v >> 24) + data[offset+4] = uint8(v >> 32) + data[offset+5] = uint8(v >> 40) + data[offset+6] = uint8(v >> 48) + data[offset+7] = uint8(v >> 56) + return offset + 8 +} +func encodeFixed32Log(data []byte, offset int, v uint32) int { + data[offset] = uint8(v) + data[offset+1] = uint8(v >> 8) + data[offset+2] = uint8(v >> 16) + data[offset+3] = uint8(v >> 24) + return offset + 4 +} +func encodeVarintLog(data []byte, offset int, v uint64) int { + for v >= 1<<7 { + data[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + data[offset] = uint8(v) + return offset + 1 +} + +// Size return the log's size +func (m *Log) Size() (n int) { + var l int + _ = l + if m.Time != nil { + n += 1 + sovLog(uint64(*m.Time)) + } + if len(m.Contents) > 0 { + for _, e := range m.Contents { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +// Size return LogContent size based on Key and Value +func (m *LogContent) Size() (n int) { + var l int + _ = l + if m.Key != nil { + l = len(*m.Key) + n += 1 + l + sovLog(uint64(l)) + } + if m.Value != nil { + l = len(*m.Value) + n += 1 + l + sovLog(uint64(l)) + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +// Size return LogGroup size based on Logs +func (m *LogGroup) Size() (n int) { + var l int + _ = l + if len(m.Logs) > 0 { + for _, e := range m.Logs { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.Reserved != nil { + l = len(*m.Reserved) + n += 1 + l + sovLog(uint64(l)) + } + if m.Topic != nil { + l = len(*m.Topic) + n += 1 + l + sovLog(uint64(l)) + } + if m.Source != nil { + l = len(*m.Source) + n += 1 + l + sovLog(uint64(l)) + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +// Size return LogGroupList size +func (m *LogGroupList) Size() (n int) { + var l int + _ = l + if len(m.LogGroups) > 0 { + for _, e := range m.LogGroups { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +func sovLog(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozLog(x uint64) (n int) { + return sovLog((x << 1) ^ (x >> 63)) +} + +// Unmarshal data to log +func (m *Log) Unmarshal(data []byte) error { + var hasFields [1]uint64 + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Log: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Log: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType) + } + var v uint32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + v |= (uint32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Time = &v + hasFields[0] |= uint64(0x00000001) + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Contents", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Contents = append(m.Contents, &LogContent{}) + if err := m.Contents[len(m.Contents)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + if hasFields[0]&uint64(0x00000001) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +// Unmarshal data to LogContent +func (m *LogContent) Unmarshal(data []byte) error { + var hasFields [1]uint64 + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Content: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Content: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Key", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Key = &s + iNdEx = postIndex + hasFields[0] |= uint64(0x00000001) + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Value", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Value = &s + iNdEx = postIndex + hasFields[0] |= uint64(0x00000002) + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + if hasFields[0]&uint64(0x00000001) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") + } + if hasFields[0]&uint64(0x00000002) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +// Unmarshal data to LogGroup +func (m *LogGroup) Unmarshal(data []byte) error { + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: LogGroup: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: LogGroup: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Logs", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Logs = append(m.Logs, &Log{}) + if err := m.Logs[len(m.Logs)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Reserved", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Reserved = &s + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Topic", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Topic = &s + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Source", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Source = &s + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +// Unmarshal data to LogGroupList +func (m *LogGroupList) Unmarshal(data []byte) error { + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: LogGroupList: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: LogGroupList: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field LogGroups", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.LogGroups = append(m.LogGroups, &LogGroup{}) + if err := m.LogGroups[len(m.LogGroups)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +func skipLog(data []byte) (n int, err error) { + l := len(data) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if data[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLengthLog + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipLog(data[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} diff --git a/pkg/logs/alils/log_config.go b/pkg/logs/alils/log_config.go new file mode 100755 index 00000000..e8564efb --- /dev/null +++ b/pkg/logs/alils/log_config.go @@ -0,0 +1,42 @@ +package alils + +// InputDetail define log detail +type InputDetail struct { + LogType string `json:"logType"` + LogPath string `json:"logPath"` + FilePattern string `json:"filePattern"` + LocalStorage bool `json:"localStorage"` + TimeFormat string `json:"timeFormat"` + LogBeginRegex string `json:"logBeginRegex"` + Regex string `json:"regex"` + Keys []string `json:"key"` + FilterKeys []string `json:"filterKey"` + FilterRegex []string `json:"filterRegex"` + TopicFormat string `json:"topicFormat"` +} + +// OutputDetail define the output detail +type OutputDetail struct { + Endpoint string `json:"endpoint"` + LogStoreName string `json:"logstoreName"` +} + +// LogConfig define Log Config +type LogConfig struct { + Name string `json:"configName"` + InputType string `json:"inputType"` + InputDetail InputDetail `json:"inputDetail"` + OutputType string `json:"outputType"` + OutputDetail OutputDetail `json:"outputDetail"` + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +// GetAppliedMachineGroup returns applied machine group of this config. +func (c *LogConfig) GetAppliedMachineGroup(confName string) (groupNames []string, err error) { + groupNames, err = c.project.GetAppliedMachineGroups(c.Name) + return +} diff --git a/pkg/logs/alils/log_project.go b/pkg/logs/alils/log_project.go new file mode 100755 index 00000000..59db8cbf --- /dev/null +++ b/pkg/logs/alils/log_project.go @@ -0,0 +1,819 @@ +/* +Package alils implements the SDK(v0.5.0) of Simple Log Service(abbr. SLS). + +For more description about SLS, please read this article: +http://gitlab.alibaba-inc.com/sls/doc. +*/ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" +) + +// Error message in SLS HTTP response. +type errorMessage struct { + Code string `json:"errorCode"` + Message string `json:"errorMessage"` +} + +// LogProject Define the Ali Project detail +type LogProject struct { + Name string // Project name + Endpoint string // IP or hostname of SLS endpoint + AccessKeyID string + AccessKeySecret string +} + +// NewLogProject creates a new SLS project. +func NewLogProject(name, endpoint, AccessKeyID, accessKeySecret string) (p *LogProject, err error) { + p = &LogProject{ + Name: name, + Endpoint: endpoint, + AccessKeyID: AccessKeyID, + AccessKeySecret: accessKeySecret, + } + return p, nil +} + +// ListLogStore returns all logstore names of project p. +func (p *LogProject) ListLogStore() (storeNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores") + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Count int + LogStores []string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + storeNames = body.LogStores + + return +} + +// GetLogStore returns logstore according by logstore name. +func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/logstores/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + s = &LogStore{} + err = json.Unmarshal(buf, s) + if err != nil { + return + } + s.project = p + return +} + +// CreateLogStore creates a new logstore in SLS, +// where name is logstore name, +// and ttl is time-to-live(in day) of logs, +// and shardCnt is the number of shards. +func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) { + + type Body struct { + Name string `json:"logstoreName"` + TTL int `json:"ttl"` + ShardCount int `json:"shardCount"` + } + + store := &Body{ + Name: name, + TTL: ttl, + ShardCount: shardCnt, + } + + body, err := json.Marshal(store) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/logstores", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to create logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteLogStore deletes a logstore according by logstore name. +func (p *LogProject) DeleteLogStore(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/logstores/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// UpdateLogStore updates a logstore according by logstore name, +// obviously we can't modify the logstore name itself. +func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) { + + type Body struct { + Name string `json:"logstoreName"` + TTL int `json:"ttl"` + ShardCount int `json:"shardCount"` + } + + store := &Body{ + Name: name, + TTL: ttl, + ShardCount: shardCnt, + } + + body, err := json.Marshal(store) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/logstores", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// ListMachineGroup returns machine group name list and the total number of machine groups. +// The offset starts from 0 and the size is the max number of machine groups could be returned. +func (p *LogProject) ListMachineGroup(offset, size int) (m []string, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + if size <= 0 { + size = 500 + } + + uri := fmt.Sprintf("/machinegroups?offset=%v&size=%v", offset, size) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + MachineGroups []string + Count int + Total int + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + m = body.MachineGroups + total = body.Total + + return +} + +// GetMachineGroup retruns machine group according by machine group name. +func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/machinegroups/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get machine group:%v", name) + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + m = &MachineGroup{} + err = json.Unmarshal(buf, m) + if err != nil { + return + } + m.project = p + return +} + +// CreateMachineGroup creates a new machine group in SLS. +func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) { + + body, err := json.Marshal(m) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/machinegroups", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to create machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// UpdateMachineGroup updates a machine group. +func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) { + + body, err := json.Marshal(m) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/machinegroups/"+m.Name, h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteMachineGroup deletes machine group according machine group name. +func (p *LogProject) DeleteMachineGroup(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/machinegroups/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// ListConfig returns config names list and the total number of configs. +// The offset starts from 0 and the size is the max number of configs could be returned. +func (p *LogProject) ListConfig(offset, size int) (cfgNames []string, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + if size <= 0 { + size = 100 + } + + uri := fmt.Sprintf("/configs?offset=%v&size=%v", offset, size) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Total int + Configs []string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + cfgNames = body.Configs + total = body.Total + return +} + +// GetConfig returns config according by config name. +func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/configs/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + c = &LogConfig{} + err = json.Unmarshal(buf, c) + if err != nil { + return + } + c.project = p + return +} + +// UpdateConfig updates a config. +func (p *LogProject) UpdateConfig(c *LogConfig) (err error) { + + body, err := json.Marshal(c) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/configs/"+c.Name, h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// CreateConfig creates a new config in SLS. +func (p *LogProject) CreateConfig(c *LogConfig) (err error) { + + body, err := json.Marshal(c) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/configs", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteConfig deletes a config according by config name. +func (p *LogProject) DeleteConfig(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/configs/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// GetAppliedMachineGroups returns applied machine group names list according config name. +func (p *LogProject) GetAppliedMachineGroups(confName string) (groupNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/configs/%v/machinegroups", confName) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get applied machine groups") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Count int + Machinegroups []string + } + + body := &Body{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + groupNames = body.Machinegroups + return +} + +// GetAppliedConfigs returns applied config names list according machine group name groupName. +func (p *LogProject) GetAppliedConfigs(groupName string) (confNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs", groupName) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to applied configs") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Cfg struct { + Count int `json:"count"` + Configs []string `json:"configs"` + } + + body := &Cfg{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + confNames = body.Configs + return +} + +// ApplyConfigToMachineGroup applies config to machine group. +func (p *LogProject) ApplyConfigToMachineGroup(confName, groupName string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName) + r, err := request(p, "PUT", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to apply config to machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// RemoveConfigFromMachineGroup removes config from machine group. +func (p *LogProject) RemoveConfigFromMachineGroup(confName, groupName string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName) + r, err := request(p, "DELETE", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to remove config from machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} diff --git a/pkg/logs/alils/log_store.go b/pkg/logs/alils/log_store.go new file mode 100755 index 00000000..fa502736 --- /dev/null +++ b/pkg/logs/alils/log_store.go @@ -0,0 +1,271 @@ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" + "strconv" + + lz4 "github.com/cloudflare/golz4" + "github.com/gogo/protobuf/proto" +) + +// LogStore Store the logs +type LogStore struct { + Name string `json:"logstoreName"` + TTL int + ShardCount int + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +// Shard define the Log Shard +type Shard struct { + ShardID int `json:"shardID"` +} + +// ListShards returns shard id list of this logstore. +func (s *LogStore) ListShards() (shardIDs []int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores/%v/shards", s.Name) + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + var shards []*Shard + err = json.Unmarshal(buf, &shards) + if err != nil { + return + } + + for _, v := range shards { + shardIDs = append(shardIDs, v.ShardID) + } + return +} + +// PutLogs put logs into logstore. +// The callers should transform user logs into LogGroup. +func (s *LogStore) PutLogs(lg *LogGroup) (err error) { + body, err := proto.Marshal(lg) + if err != nil { + return + } + + // Compresse body with lz4 + out := make([]byte, lz4.CompressBound(body)) + n, err := lz4.Compress(body, out) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-compresstype": "lz4", + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/x-protobuf", + } + + uri := fmt.Sprintf("/logstores/%v", s.Name) + r, err := request(s.project, "POST", uri, h, out[:n]) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to put logs") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// GetCursor gets log cursor of one shard specified by shardID. +// The from can be in three form: a) unix timestamp in seccond, b) "begin", c) "end". +// For more detail please read: http://gitlab.alibaba-inc.com/sls/doc/blob/master/api/shard.md#logstore +func (s *LogStore) GetCursor(shardID int, from string) (cursor string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores/%v/shards/%v?type=cursor&from=%v", + s.Name, shardID, from) + + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get cursor") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Cursor string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + cursor = body.Cursor + return +} + +// GetLogsBytes gets logs binary data from shard specified by shardID according cursor. +// The logGroupMaxCount is the max number of logGroup could be returned. +// The nextCursor is the next curosr can be used to read logs at next time. +func (s *LogStore) GetLogsBytes(shardID int, cursor string, + logGroupMaxCount int) (out []byte, nextCursor string, err error) { + + h := map[string]string{ + "x-sls-bodyrawsize": "0", + "Accept": "application/x-protobuf", + "Accept-Encoding": "lz4", + } + + uri := fmt.Sprintf("/logstores/%v/shards/%v?type=logs&cursor=%v&count=%v", + s.Name, shardID, cursor, logGroupMaxCount) + + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get cursor") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + v, ok := r.Header["X-Sls-Compresstype"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-compresstype' header") + return + } + if v[0] != "lz4" { + err = fmt.Errorf("unexpected compress type:%v", v[0]) + return + } + + v, ok = r.Header["X-Sls-Cursor"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-cursor' header") + return + } + nextCursor = v[0] + + v, ok = r.Header["X-Sls-Bodyrawsize"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header") + return + } + bodyRawSize, err := strconv.Atoi(v[0]) + if err != nil { + return + } + + out = make([]byte, bodyRawSize) + err = lz4.Uncompress(buf, out) + if err != nil { + return + } + + return +} + +// LogsBytesDecode decodes logs binary data retruned by GetLogsBytes API +func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) { + + gl = &LogGroupList{} + err = proto.Unmarshal(data, gl) + if err != nil { + return + } + + return +} + +// GetLogs gets logs from shard specified by shardID according cursor. +// The logGroupMaxCount is the max number of logGroup could be returned. +// The nextCursor is the next curosr can be used to read logs at next time. +func (s *LogStore) GetLogs(shardID int, cursor string, + logGroupMaxCount int) (gl *LogGroupList, nextCursor string, err error) { + + out, nextCursor, err := s.GetLogsBytes(shardID, cursor, logGroupMaxCount) + if err != nil { + return + } + + gl, err = LogsBytesDecode(out) + if err != nil { + return + } + + return +} diff --git a/pkg/logs/alils/machine_group.go b/pkg/logs/alils/machine_group.go new file mode 100755 index 00000000..b6c69a14 --- /dev/null +++ b/pkg/logs/alils/machine_group.go @@ -0,0 +1,91 @@ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" +) + +// MachineGroupAttribute define the Attribute +type MachineGroupAttribute struct { + ExternalName string `json:"externalName"` + TopicName string `json:"groupTopic"` +} + +// MachineGroup define the machine Group +type MachineGroup struct { + Name string `json:"groupName"` + Type string `json:"groupType"` + MachineIDType string `json:"machineIdentifyType"` + MachineIDList []string `json:"machineList"` + + Attribute MachineGroupAttribute `json:"groupAttribute"` + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +// Machine define the Machine +type Machine struct { + IP string + UniqueID string `json:"machine-uniqueid"` + UserdefinedID string `json:"userdefined-id"` +} + +// MachineList define the Machine List +type MachineList struct { + Total int + Machines []*Machine +} + +// ListMachines returns machine list of this machine group. +func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/machines", m.Name) + r, err := request(m.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to remove config from machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + body := &MachineList{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + ms = body.Machines + total = body.Total + + return +} + +// GetAppliedConfigs returns applied configs of this machine group. +func (m *MachineGroup) GetAppliedConfigs() (confNames []string, err error) { + confNames, err = m.project.GetAppliedConfigs(m.Name) + return +} diff --git a/pkg/logs/alils/request.go b/pkg/logs/alils/request.go new file mode 100755 index 00000000..50d9c43c --- /dev/null +++ b/pkg/logs/alils/request.go @@ -0,0 +1,62 @@ +package alils + +import ( + "bytes" + "crypto/md5" + "fmt" + "net/http" +) + +// request sends a request to SLS. +func request(project *LogProject, method, uri string, headers map[string]string, + body []byte) (resp *http.Response, err error) { + + // The caller should provide 'x-sls-bodyrawsize' header + if _, ok := headers["x-sls-bodyrawsize"]; !ok { + err = fmt.Errorf("Can't find 'x-sls-bodyrawsize' header") + return + } + + // SLS public request headers + headers["Host"] = project.Name + "." + project.Endpoint + headers["Date"] = nowRFC1123() + headers["x-sls-apiversion"] = version + headers["x-sls-signaturemethod"] = signatureMethod + if body != nil { + bodyMD5 := fmt.Sprintf("%X", md5.Sum(body)) + headers["Content-MD5"] = bodyMD5 + + if _, ok := headers["Content-Type"]; !ok { + err = fmt.Errorf("Can't find 'Content-Type' header") + return + } + } + + // Calc Authorization + // Authorization = "SLS :" + digest, err := signature(project, method, uri, headers) + if err != nil { + return + } + auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyID, digest) + headers["Authorization"] = auth + + // Initialize http request + reader := bytes.NewReader(body) + urlStr := fmt.Sprintf("http://%v.%v%v", project.Name, project.Endpoint, uri) + req, err := http.NewRequest(method, urlStr, reader) + if err != nil { + return + } + for k, v := range headers { + req.Header.Add(k, v) + } + + // Get ready to do request + resp, err = http.DefaultClient.Do(req) + if err != nil { + return + } + + return +} diff --git a/pkg/logs/alils/signature.go b/pkg/logs/alils/signature.go new file mode 100755 index 00000000..2d611307 --- /dev/null +++ b/pkg/logs/alils/signature.go @@ -0,0 +1,111 @@ +package alils + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "fmt" + "net/url" + "sort" + "strings" + "time" +) + +// GMT location +var gmtLoc = time.FixedZone("GMT", 0) + +// NowRFC1123 returns now time in RFC1123 format with GMT timezone, +// eg. "Mon, 02 Jan 2006 15:04:05 GMT". +func nowRFC1123() string { + return time.Now().In(gmtLoc).Format(time.RFC1123) +} + +// signature calculates a request's signature digest. +func signature(project *LogProject, method, uri string, + headers map[string]string) (digest string, err error) { + var contentMD5, contentType, date, canoHeaders, canoResource string + var slsHeaderKeys sort.StringSlice + + // SignString = VERB + "\n" + // + CONTENT-MD5 + "\n" + // + CONTENT-TYPE + "\n" + // + DATE + "\n" + // + CanonicalizedSLSHeaders + "\n" + // + CanonicalizedResource + + if val, ok := headers["Content-MD5"]; ok { + contentMD5 = val + } + + if val, ok := headers["Content-Type"]; ok { + contentType = val + } + + date, ok := headers["Date"] + if !ok { + err = fmt.Errorf("Can't find 'Date' header") + return + } + + // Calc CanonicalizedSLSHeaders + slsHeaders := make(map[string]string, len(headers)) + for k, v := range headers { + l := strings.TrimSpace(strings.ToLower(k)) + if strings.HasPrefix(l, "x-sls-") { + slsHeaders[l] = strings.TrimSpace(v) + slsHeaderKeys = append(slsHeaderKeys, l) + } + } + + sort.Sort(slsHeaderKeys) + for i, k := range slsHeaderKeys { + canoHeaders += k + ":" + slsHeaders[k] + if i+1 < len(slsHeaderKeys) { + canoHeaders += "\n" + } + } + + // Calc CanonicalizedResource + u, err := url.Parse(uri) + if err != nil { + return + } + + canoResource += url.QueryEscape(u.Path) + if u.RawQuery != "" { + var keys sort.StringSlice + + vals := u.Query() + for k := range vals { + keys = append(keys, k) + } + + sort.Sort(keys) + canoResource += "?" + for i, k := range keys { + if i > 0 { + canoResource += "&" + } + + for _, v := range vals[k] { + canoResource += k + "=" + v + } + } + } + + signStr := method + "\n" + + contentMD5 + "\n" + + contentType + "\n" + + date + "\n" + + canoHeaders + "\n" + + canoResource + + // Signature = base64(hmac-sha1(UTF8-Encoding-Of(SignString),AccessKeySecret)) + mac := hmac.New(sha1.New, []byte(project.AccessKeySecret)) + _, err = mac.Write([]byte(signStr)) + if err != nil { + return + } + digest = base64.StdEncoding.EncodeToString(mac.Sum(nil)) + return +} diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go new file mode 100644 index 00000000..74c458ab --- /dev/null +++ b/pkg/logs/conn.go @@ -0,0 +1,119 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "encoding/json" + "io" + "net" + "time" +) + +// connWriter implements LoggerInterface. +// it writes messages in keep-live tcp connection. +type connWriter struct { + lg *logWriter + innerWriter io.WriteCloser + ReconnectOnMsg bool `json:"reconnectOnMsg"` + Reconnect bool `json:"reconnect"` + Net string `json:"net"` + Addr string `json:"addr"` + Level int `json:"level"` +} + +// NewConn create new ConnWrite returning as LoggerInterface. +func NewConn() Logger { + conn := new(connWriter) + conn.Level = LevelTrace + return conn +} + +// Init init connection writer with json config. +// json config only need key "level". +func (c *connWriter) Init(jsonConfig string) error { + return json.Unmarshal([]byte(jsonConfig), c) +} + +// WriteMsg write message in connection. +// if connection is down, try to re-connect. +func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > c.Level { + return nil + } + if c.needToConnectOnMsg() { + err := c.connect() + if err != nil { + return err + } + } + + if c.ReconnectOnMsg { + defer c.innerWriter.Close() + } + + _, err := c.lg.writeln(when, msg) + if err != nil { + return err + } + return nil +} + +// Flush implementing method. empty. +func (c *connWriter) Flush() { + +} + +// Destroy destroy connection writer and close tcp listener. +func (c *connWriter) Destroy() { + if c.innerWriter != nil { + c.innerWriter.Close() + } +} + +func (c *connWriter) connect() error { + if c.innerWriter != nil { + c.innerWriter.Close() + c.innerWriter = nil + } + + conn, err := net.Dial(c.Net, c.Addr) + if err != nil { + return err + } + + if tcpConn, ok := conn.(*net.TCPConn); ok { + tcpConn.SetKeepAlive(true) + } + + c.innerWriter = conn + c.lg = newLogWriter(conn) + return nil +} + +func (c *connWriter) needToConnectOnMsg() bool { + if c.Reconnect { + return true + } + + if c.innerWriter == nil { + return true + } + + return c.ReconnectOnMsg +} + +func init() { + Register(AdapterConn, NewConn) +} diff --git a/pkg/logs/conn_test.go b/pkg/logs/conn_test.go new file mode 100644 index 00000000..bb377d41 --- /dev/null +++ b/pkg/logs/conn_test.go @@ -0,0 +1,79 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "net" + "os" + "testing" +) + +// ConnTCPListener takes a TCP listener and accepts n TCP connections +// Returns connections using connChan +func connTCPListener(t *testing.T, n int, ln net.Listener, connChan chan<- net.Conn) { + + // Listen and accept n incoming connections + for i := 0; i < n; i++ { + conn, err := ln.Accept() + if err != nil { + t.Log("Error accepting connection: ", err.Error()) + os.Exit(1) + } + + // Send accepted connection to channel + connChan <- conn + } + ln.Close() + close(connChan) +} + +func TestConn(t *testing.T) { + log := NewLogger(1000) + log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) + log.Informational("informational") +} + +func TestReconnect(t *testing.T) { + // Setup connection listener + newConns := make(chan net.Conn) + connNum := 2 + ln, err := net.Listen("tcp", ":6002") + if err != nil { + t.Log("Error listening:", err.Error()) + os.Exit(1) + } + go connTCPListener(t, connNum, ln, newConns) + + // Setup logger + log := NewLogger(1000) + log.SetPrefix("test") + log.SetLogger(AdapterConn, `{"net":"tcp","reconnect":true,"level":6,"addr":":6002"}`) + log.Informational("informational 1") + + // Refuse first connection + first := <-newConns + first.Close() + + // Send another log after conn closed + log.Informational("informational 2") + + // Check if there was a second connection attempt + select { + case second := <-newConns: + second.Close() + default: + t.Error("Did not reconnect") + } +} diff --git a/pkg/logs/console.go b/pkg/logs/console.go new file mode 100644 index 00000000..3dcaee1d --- /dev/null +++ b/pkg/logs/console.go @@ -0,0 +1,99 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "encoding/json" + "os" + "strings" + "time" + + "github.com/shiena/ansicolor" +) + +// brush is a color join function +type brush func(string) string + +// newBrush return a fix color Brush +func newBrush(color string) brush { + pre := "\033[" + reset := "\033[0m" + return func(text string) string { + return pre + color + "m" + text + reset + } +} + +var colors = []brush{ + newBrush("1;37"), // Emergency white + newBrush("1;36"), // Alert cyan + newBrush("1;35"), // Critical magenta + newBrush("1;31"), // Error red + newBrush("1;33"), // Warning yellow + newBrush("1;32"), // Notice green + newBrush("1;34"), // Informational blue + newBrush("1;44"), // Debug Background blue +} + +// consoleWriter implements LoggerInterface and writes messages to terminal. +type consoleWriter struct { + lg *logWriter + Level int `json:"level"` + Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color +} + +// NewConsole create ConsoleWriter returning as LoggerInterface. +func NewConsole() Logger { + cw := &consoleWriter{ + lg: newLogWriter(ansicolor.NewAnsiColorWriter(os.Stdout)), + Level: LevelDebug, + Colorful: true, + } + return cw +} + +// Init init console logger. +// jsonConfig like '{"level":LevelTrace}'. +func (c *consoleWriter) Init(jsonConfig string) error { + if len(jsonConfig) == 0 { + return nil + } + return json.Unmarshal([]byte(jsonConfig), c) +} + +// WriteMsg write message in console. +func (c *consoleWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > c.Level { + return nil + } + if c.Colorful { + msg = strings.Replace(msg, levelPrefix[level], colors[level](levelPrefix[level]), 1) + } + c.lg.writeln(when, msg) + return nil +} + +// Destroy implementing method. empty. +func (c *consoleWriter) Destroy() { + +} + +// Flush implementing method. empty. +func (c *consoleWriter) Flush() { + +} + +func init() { + Register(AdapterConsole, NewConsole) +} diff --git a/pkg/logs/console_test.go b/pkg/logs/console_test.go new file mode 100644 index 00000000..4bc45f57 --- /dev/null +++ b/pkg/logs/console_test.go @@ -0,0 +1,64 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + "time" +) + +// Try each log level in decreasing order of priority. +func testConsoleCalls(bl *BeeLogger) { + bl.Emergency("emergency") + bl.Alert("alert") + bl.Critical("critical") + bl.Error("error") + bl.Warning("warning") + bl.Notice("notice") + bl.Informational("informational") + bl.Debug("debug") +} + +// Test console logging by visually comparing the lines being output with and +// without a log level specification. +func TestConsole(t *testing.T) { + log1 := NewLogger(10000) + log1.EnableFuncCallDepth(true) + log1.SetLogger("console", "") + testConsoleCalls(log1) + + log2 := NewLogger(100) + log2.SetLogger("console", `{"level":3}`) + testConsoleCalls(log2) +} + +// Test console without color +func TestConsoleNoColor(t *testing.T) { + log := NewLogger(100) + log.SetLogger("console", `{"color":false}`) + testConsoleCalls(log) +} + +// Test console async +func TestConsoleAsync(t *testing.T) { + log := NewLogger(100) + log.SetLogger("console") + log.Async() + //log.Close() + testConsoleCalls(log) + for len(log.msgChan) != 0 { + time.Sleep(1 * time.Millisecond) + } +} diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go new file mode 100644 index 00000000..2b7b1710 --- /dev/null +++ b/pkg/logs/es/es.go @@ -0,0 +1,102 @@ +package es + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/elastic/go-elasticsearch/v6" + "github.com/elastic/go-elasticsearch/v6/esapi" + + "github.com/astaxie/beego/logs" +) + +// NewES return a LoggerInterface +func NewES() logs.Logger { + cw := &esLogger{ + Level: logs.LevelDebug, + } + return cw +} + +// esLogger will log msg into ES +// before you using this implementation, +// please import this package +// usually means that you can import this package in your main package +// for example, anonymous: +// import _ "github.com/astaxie/beego/logs/es" +type esLogger struct { + *elasticsearch.Client + DSN string `json:"dsn"` + Level int `json:"level"` +} + +// {"dsn":"http://localhost:9200/","level":1} +func (el *esLogger) Init(jsonconfig string) error { + err := json.Unmarshal([]byte(jsonconfig), el) + if err != nil { + return err + } + if el.DSN == "" { + return errors.New("empty dsn") + } else if u, err := url.Parse(el.DSN); err != nil { + return err + } else if u.Path == "" { + return errors.New("missing prefix") + } else { + conn, err := elasticsearch.NewClient(elasticsearch.Config{ + Addresses: []string{el.DSN}, + }) + if err != nil { + return err + } + el.Client = conn + } + return nil +} + +// WriteMsg will write the msg and level into es +func (el *esLogger) WriteMsg(when time.Time, msg string, level int) error { + if level > el.Level { + return nil + } + + idx := LogDocument{ + Timestamp: when.Format(time.RFC3339), + Msg: msg, + } + + body, err := json.Marshal(idx) + if err != nil { + return err + } + req := esapi.IndexRequest{ + Index: fmt.Sprintf("%04d.%02d.%02d", when.Year(), when.Month(), when.Day()), + DocumentType: "logs", + Body: strings.NewReader(string(body)), + } + _, err = req.Do(context.Background(), el.Client) + return err +} + +// Destroy is a empty method +func (el *esLogger) Destroy() { +} + +// Flush is a empty method +func (el *esLogger) Flush() { + +} + +type LogDocument struct { + Timestamp string `json:"timestamp"` + Msg string `json:"msg"` +} + +func init() { + logs.Register(logs.AdapterEs, NewES) +} diff --git a/pkg/logs/file.go b/pkg/logs/file.go new file mode 100644 index 00000000..222db989 --- /dev/null +++ b/pkg/logs/file.go @@ -0,0 +1,409 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "sync" + "time" +) + +// fileLogWriter implements LoggerInterface. +// It writes messages by lines limit, file size limit, or time frequency. +type fileLogWriter struct { + sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize + // The opened file + Filename string `json:"filename"` + fileWriter *os.File + + // Rotate at line + MaxLines int `json:"maxlines"` + maxLinesCurLines int + + MaxFiles int `json:"maxfiles"` + MaxFilesCurFiles int + + // Rotate at size + MaxSize int `json:"maxsize"` + maxSizeCurSize int + + // Rotate daily + Daily bool `json:"daily"` + MaxDays int64 `json:"maxdays"` + dailyOpenDate int + dailyOpenTime time.Time + + // Rotate hourly + Hourly bool `json:"hourly"` + MaxHours int64 `json:"maxhours"` + hourlyOpenDate int + hourlyOpenTime time.Time + + Rotate bool `json:"rotate"` + + Level int `json:"level"` + + Perm string `json:"perm"` + + RotatePerm string `json:"rotateperm"` + + fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix +} + +// newFileWriter create a FileLogWriter returning as LoggerInterface. +func newFileWriter() Logger { + w := &fileLogWriter{ + Daily: true, + MaxDays: 7, + Hourly: false, + MaxHours: 168, + Rotate: true, + RotatePerm: "0440", + Level: LevelTrace, + Perm: "0660", + MaxLines: 10000000, + MaxFiles: 999, + MaxSize: 1 << 28, + } + return w +} + +// Init file logger with json config. +// jsonConfig like: +// { +// "filename":"logs/beego.log", +// "maxLines":10000, +// "maxsize":1024, +// "daily":true, +// "maxDays":15, +// "rotate":true, +// "perm":"0600" +// } +func (w *fileLogWriter) Init(jsonConfig string) error { + err := json.Unmarshal([]byte(jsonConfig), w) + if err != nil { + return err + } + if len(w.Filename) == 0 { + return errors.New("jsonconfig must have filename") + } + w.suffix = filepath.Ext(w.Filename) + w.fileNameOnly = strings.TrimSuffix(w.Filename, w.suffix) + if w.suffix == "" { + w.suffix = ".log" + } + err = w.startLogger() + return err +} + +// start file logger. create log file and set to locker-inside file writer. +func (w *fileLogWriter) startLogger() error { + file, err := w.createLogFile() + if err != nil { + return err + } + if w.fileWriter != nil { + w.fileWriter.Close() + } + w.fileWriter = file + return w.initFd() +} + +func (w *fileLogWriter) needRotateDaily(size int, day int) bool { + return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || + (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || + (w.Daily && day != w.dailyOpenDate) +} + +func (w *fileLogWriter) needRotateHourly(size int, hour int) bool { + return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || + (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || + (w.Hourly && hour != w.hourlyOpenDate) + +} + +// WriteMsg write logger message into file. +func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > w.Level { + return nil + } + hd, d, h := formatTimeHeader(when) + msg = string(hd) + msg + "\n" + if w.Rotate { + w.RLock() + if w.needRotateHourly(len(msg), h) { + w.RUnlock() + w.Lock() + if w.needRotateHourly(len(msg), h) { + if err := w.doRotate(when); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() + } else if w.needRotateDaily(len(msg), d) { + w.RUnlock() + w.Lock() + if w.needRotateDaily(len(msg), d) { + if err := w.doRotate(when); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() + } else { + w.RUnlock() + } + } + + w.Lock() + _, err := w.fileWriter.Write([]byte(msg)) + if err == nil { + w.maxLinesCurLines++ + w.maxSizeCurSize += len(msg) + } + w.Unlock() + return err +} + +func (w *fileLogWriter) createLogFile() (*os.File, error) { + // Open the log file + perm, err := strconv.ParseInt(w.Perm, 8, 64) + if err != nil { + return nil, err + } + + filepath := path.Dir(w.Filename) + os.MkdirAll(filepath, os.FileMode(perm)) + + fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm)) + if err == nil { + // Make sure file perm is user set perm cause of `os.OpenFile` will obey umask + os.Chmod(w.Filename, os.FileMode(perm)) + } + return fd, err +} + +func (w *fileLogWriter) initFd() error { + fd := w.fileWriter + fInfo, err := fd.Stat() + if err != nil { + return fmt.Errorf("get stat err: %s", err) + } + w.maxSizeCurSize = int(fInfo.Size()) + w.dailyOpenTime = time.Now() + w.dailyOpenDate = w.dailyOpenTime.Day() + w.hourlyOpenTime = time.Now() + w.hourlyOpenDate = w.hourlyOpenTime.Hour() + w.maxLinesCurLines = 0 + if w.Hourly { + go w.hourlyRotate(w.hourlyOpenTime) + } else if w.Daily { + go w.dailyRotate(w.dailyOpenTime) + } + if fInfo.Size() > 0 && w.MaxLines > 0 { + count, err := w.lines() + if err != nil { + return err + } + w.maxLinesCurLines = count + } + return nil +} + +func (w *fileLogWriter) dailyRotate(openTime time.Time) { + y, m, d := openTime.Add(24 * time.Hour).Date() + nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location()) + tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) + <-tm.C + w.Lock() + if w.needRotateDaily(0, time.Now().Day()) { + if err := w.doRotate(time.Now()); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() +} + +func (w *fileLogWriter) hourlyRotate(openTime time.Time) { + y, m, d := openTime.Add(1 * time.Hour).Date() + h, _, _ := openTime.Add(1 * time.Hour).Clock() + nextHour := time.Date(y, m, d, h, 0, 0, 0, openTime.Location()) + tm := time.NewTimer(time.Duration(nextHour.UnixNano() - openTime.UnixNano() + 100)) + <-tm.C + w.Lock() + if w.needRotateHourly(0, time.Now().Hour()) { + if err := w.doRotate(time.Now()); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() +} + +func (w *fileLogWriter) lines() (int, error) { + fd, err := os.Open(w.Filename) + if err != nil { + return 0, err + } + defer fd.Close() + + buf := make([]byte, 32768) // 32k + count := 0 + lineSep := []byte{'\n'} + + for { + c, err := fd.Read(buf) + if err != nil && err != io.EOF { + return count, err + } + + count += bytes.Count(buf[:c], lineSep) + + if err == io.EOF { + break + } + } + + return count, nil +} + +// DoRotate means it need to write file in new file. +// new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size) +func (w *fileLogWriter) doRotate(logTime time.Time) error { + // file exists + // Find the next available number + num := w.MaxFilesCurFiles + 1 + fName := "" + format := "" + var openTime time.Time + rotatePerm, err := strconv.ParseInt(w.RotatePerm, 8, 64) + if err != nil { + return err + } + + _, err = os.Lstat(w.Filename) + if err != nil { + //even if the file is not exist or other ,we should RESTART the logger + goto RESTART_LOGGER + } + + if w.Hourly { + format = "2006010215" + openTime = w.hourlyOpenTime + } else if w.Daily { + format = "2006-01-02" + openTime = w.dailyOpenTime + } + + // only when one of them be setted, then the file would be splited + if w.MaxLines > 0 || w.MaxSize > 0 { + for ; err == nil && num <= w.MaxFiles; num++ { + fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format(format), num, w.suffix) + _, err = os.Lstat(fName) + } + } else { + fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", openTime.Format(format), num, w.suffix) + _, err = os.Lstat(fName) + w.MaxFilesCurFiles = num + } + + // return error if the last file checked still existed + if err == nil { + return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename) + } + + // close fileWriter before rename + w.fileWriter.Close() + + // Rename the file to its new found name + // even if occurs error,we MUST guarantee to restart new logger + err = os.Rename(w.Filename, fName) + if err != nil { + goto RESTART_LOGGER + } + + err = os.Chmod(fName, os.FileMode